Merge branch 'main' into refactor/render-mermaid-in-shadow-dom

This commit is contained in:
one 2025-08-15 15:51:08 +08:00
commit a34e407aa2
31 changed files with 958 additions and 292 deletions

View File

@ -78,7 +78,7 @@
"node-stream-zip": "^1.15.0", "node-stream-zip": "^1.15.0",
"officeparser": "^4.2.0", "officeparser": "^4.2.0",
"os-proxy-config": "^1.1.2", "os-proxy-config": "^1.1.2",
"selection-hook": "^1.0.8", "selection-hook": "^1.0.9",
"turndown": "7.2.0" "turndown": "7.2.0"
}, },
"devDependencies": { "devDependencies": {

View File

@ -9,6 +9,7 @@ import { CancellationToken, UpdateInfo } from 'builder-util-runtime'
import { app, BrowserWindow, dialog } from 'electron' import { app, BrowserWindow, dialog } from 'electron'
import { AppUpdater as _AppUpdater, autoUpdater, Logger, NsisUpdater, UpdateCheckResult } from 'electron-updater' import { AppUpdater as _AppUpdater, autoUpdater, Logger, NsisUpdater, UpdateCheckResult } from 'electron-updater'
import path from 'path' import path from 'path'
import semver from 'semver'
import icon from '../../../build/icon.png?asset' import icon from '../../../build/icon.png?asset'
import { configManager } from './ConfigManager' import { configManager } from './ConfigManager'
@ -44,12 +45,6 @@ export default class AppUpdater {
// 检测到不需要更新时 // 检测到不需要更新时
autoUpdater.on('update-not-available', () => { autoUpdater.on('update-not-available', () => {
if (configManager.getTestPlan() && this.autoUpdater.channel !== UpgradeChannel.LATEST) {
logger.info('test plan is enabled, but update is not available, do not send update not available event')
// will not send update not available event, because will check for updates with latest channel
return
}
windowService.getMainWindow()?.webContents.send(IpcChannel.UpdateNotAvailable) windowService.getMainWindow()?.webContents.send(IpcChannel.UpdateNotAvailable)
}) })
@ -72,18 +67,24 @@ export default class AppUpdater {
this.autoUpdater = autoUpdater this.autoUpdater = autoUpdater
} }
private async _getPreReleaseVersionFromGithub(channel: UpgradeChannel) { private async _getReleaseVersionFromGithub(channel: UpgradeChannel) {
const headers = {
Accept: 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28',
'Accept-Language': 'en-US,en;q=0.9'
}
try { try {
logger.info(`get pre release version from github: ${channel}`) logger.info(`get release version from github: ${channel}`)
const responses = await fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', { const responses = await fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
headers: { headers
Accept: 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28',
'Accept-Language': 'en-US,en;q=0.9'
}
}) })
const data = (await responses.json()) as GithubReleaseInfo[] const data = (await responses.json()) as GithubReleaseInfo[]
let mightHaveLatest = false
const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => { const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => {
if (!item.draft && !item.prerelease) {
mightHaveLatest = true
}
return item.prerelease && item.tag_name.includes(`-${channel}.`) return item.prerelease && item.tag_name.includes(`-${channel}.`)
}) })
@ -91,8 +92,29 @@ export default class AppUpdater {
return null return null
} }
logger.info(`prerelease url is ${release.tag_name}, set channel to ${channel}`) // if the release version is the same as the current version, return null
if (release.tag_name === app.getVersion()) {
return null
}
if (mightHaveLatest) {
logger.info(`might have latest release, get latest release`)
const latestReleaseResponse = await fetch(
'https://api.github.com/repos/CherryHQ/cherry-studio/releases/latest',
{
headers
}
)
const latestRelease = (await latestReleaseResponse.json()) as GithubReleaseInfo
if (semver.gt(latestRelease.tag_name, release.tag_name)) {
logger.info(
`latest release version is ${latestRelease.tag_name}, prerelease version is ${release.tag_name}, return null`
)
return null
}
}
logger.info(`release url is ${release.tag_name}, set channel to ${channel}`)
return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}` return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}`
} catch (error) { } catch (error) {
logger.error('Failed to get latest not draft version from github:', error as Error) logger.error('Failed to get latest not draft version from github:', error as Error)
@ -151,14 +173,14 @@ export default class AppUpdater {
return return
} }
const preReleaseUrl = await this._getPreReleaseVersionFromGithub(channel) const releaseUrl = await this._getReleaseVersionFromGithub(channel)
if (preReleaseUrl) { if (releaseUrl) {
logger.info(`prerelease url is ${preReleaseUrl}, set channel to ${channel}`) logger.info(`release url is ${releaseUrl}, set channel to ${channel}`)
this._setChannel(channel, preReleaseUrl) this._setChannel(channel, releaseUrl)
return return
} }
// if no prerelease url, use github latest to avoid error // if no prerelease url, use github latest to get release
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST) this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
return return
} }
@ -195,17 +217,6 @@ export default class AppUpdater {
`update check result: ${this.updateCheckResult?.isUpdateAvailable}, channel: ${this.autoUpdater.channel}, currentVersion: ${this.autoUpdater.currentVersion}` `update check result: ${this.updateCheckResult?.isUpdateAvailable}, channel: ${this.autoUpdater.channel}, currentVersion: ${this.autoUpdater.currentVersion}`
) )
// if the update is not available, and the test plan is enabled, set the feed url to the github latest
if (
!this.updateCheckResult?.isUpdateAvailable &&
configManager.getTestPlan() &&
this.autoUpdater.channel !== UpgradeChannel.LATEST
) {
logger.info('test plan is enabled, but update is not available, set channel to latest')
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
this.updateCheckResult = await this.autoUpdater.checkForUpdates()
}
if (this.updateCheckResult?.isUpdateAvailable && !this.autoUpdater.autoDownload) { if (this.updateCheckResult?.isUpdateAvailable && !this.autoUpdater.autoDownload) {
// 如果 autoDownload 为 false则需要再调用下面的函数触发下 // 如果 autoDownload 为 false则需要再调用下面的函数触发下
// do not use await, because it will block the return of this function // do not use await, because it will block the return of this function

View File

@ -21,6 +21,27 @@ class BackupManager {
private tempDir = path.join(app.getPath('temp'), 'cherry-studio', 'backup', 'temp') private tempDir = path.join(app.getPath('temp'), 'cherry-studio', 'backup', 'temp')
private backupDir = path.join(app.getPath('temp'), 'cherry-studio', 'backup') private backupDir = path.join(app.getPath('temp'), 'cherry-studio', 'backup')
// 缓存实例,避免重复创建
private s3Storage: S3Storage | null = null
private webdavInstance: WebDav | null = null
// 缓存核心连接配置,用于检测连接配置是否变更
private cachedS3ConnectionConfig: {
endpoint: string
region: string
bucket: string
accessKeyId: string
secretAccessKey: string
root?: string
} | null = null
private cachedWebdavConnectionConfig: {
webdavHost: string
webdavUser?: string
webdavPass?: string
webdavPath?: string
} | null = null
constructor() { constructor() {
this.checkConnection = this.checkConnection.bind(this) this.checkConnection = this.checkConnection.bind(this)
this.backup = this.backup.bind(this) this.backup = this.backup.bind(this)
@ -87,6 +108,88 @@ class BackupManager {
} }
} }
/**
* fileName
*/
private isS3ConfigEqual(cachedConfig: typeof this.cachedS3ConnectionConfig, config: S3Config): boolean {
if (!cachedConfig) return false
return (
cachedConfig.endpoint === config.endpoint &&
cachedConfig.region === config.region &&
cachedConfig.bucket === config.bucket &&
cachedConfig.accessKeyId === config.accessKeyId &&
cachedConfig.secretAccessKey === config.secretAccessKey &&
cachedConfig.root === config.root
)
}
/**
* WebDAV fileName
*/
private isWebDavConfigEqual(cachedConfig: typeof this.cachedWebdavConnectionConfig, config: WebDavConfig): boolean {
if (!cachedConfig) return false
return (
cachedConfig.webdavHost === config.webdavHost &&
cachedConfig.webdavUser === config.webdavUser &&
cachedConfig.webdavPass === config.webdavPass &&
cachedConfig.webdavPath === config.webdavPath
)
}
/**
* S3Storage
*
*/
private getS3Storage(config: S3Config): S3Storage {
// 检查核心连接配置是否变更
const configChanged = !this.isS3ConfigEqual(this.cachedS3ConnectionConfig, config)
if (configChanged || !this.s3Storage) {
this.s3Storage = new S3Storage(config)
// 只缓存连接相关的配置字段
this.cachedS3ConnectionConfig = {
endpoint: config.endpoint,
region: config.region,
bucket: config.bucket,
accessKeyId: config.accessKeyId,
secretAccessKey: config.secretAccessKey,
root: config.root
}
logger.debug('[BackupManager] Created new S3Storage instance')
} else {
logger.debug('[BackupManager] Reusing existing S3Storage instance')
}
return this.s3Storage
}
/**
* WebDav
*
*/
private getWebDavInstance(config: WebDavConfig): WebDav {
// 检查核心连接配置是否变更
const configChanged = !this.isWebDavConfigEqual(this.cachedWebdavConnectionConfig, config)
if (configChanged || !this.webdavInstance) {
this.webdavInstance = new WebDav(config)
// 只缓存连接相关的配置字段
this.cachedWebdavConnectionConfig = {
webdavHost: config.webdavHost,
webdavUser: config.webdavUser,
webdavPass: config.webdavPass,
webdavPath: config.webdavPath
}
logger.debug('[BackupManager] Created new WebDav instance')
} else {
logger.debug('[BackupManager] Reusing existing WebDav instance')
}
return this.webdavInstance
}
async backup( async backup(
_: Electron.IpcMainInvokeEvent, _: Electron.IpcMainInvokeEvent,
fileName: string, fileName: string,
@ -322,7 +425,7 @@ class BackupManager {
async backupToWebdav(_: Electron.IpcMainInvokeEvent, data: string, webdavConfig: WebDavConfig) { async backupToWebdav(_: Electron.IpcMainInvokeEvent, data: string, webdavConfig: WebDavConfig) {
const filename = webdavConfig.fileName || 'cherry-studio.backup.zip' const filename = webdavConfig.fileName || 'cherry-studio.backup.zip'
const backupedFilePath = await this.backup(_, filename, data, undefined, webdavConfig.skipBackupFile) const backupedFilePath = await this.backup(_, filename, data, undefined, webdavConfig.skipBackupFile)
const webdavClient = new WebDav(webdavConfig) const webdavClient = this.getWebDavInstance(webdavConfig)
try { try {
let result let result
if (webdavConfig.disableStream) { if (webdavConfig.disableStream) {
@ -349,7 +452,7 @@ class BackupManager {
async restoreFromWebdav(_: Electron.IpcMainInvokeEvent, webdavConfig: WebDavConfig) { async restoreFromWebdav(_: Electron.IpcMainInvokeEvent, webdavConfig: WebDavConfig) {
const filename = webdavConfig.fileName || 'cherry-studio.backup.zip' const filename = webdavConfig.fileName || 'cherry-studio.backup.zip'
const webdavClient = new WebDav(webdavConfig) const webdavClient = this.getWebDavInstance(webdavConfig)
try { try {
const retrievedFile = await webdavClient.getFileContents(filename) const retrievedFile = await webdavClient.getFileContents(filename)
const backupedFilePath = path.join(this.backupDir, filename) const backupedFilePath = path.join(this.backupDir, filename)
@ -377,7 +480,7 @@ class BackupManager {
listWebdavFiles = async (_: Electron.IpcMainInvokeEvent, config: WebDavConfig) => { listWebdavFiles = async (_: Electron.IpcMainInvokeEvent, config: WebDavConfig) => {
try { try {
const client = new WebDav(config) const client = this.getWebDavInstance(config)
const response = await client.getDirectoryContents() const response = await client.getDirectoryContents()
const files = Array.isArray(response) ? response : response.data const files = Array.isArray(response) ? response : response.data
@ -467,7 +570,7 @@ class BackupManager {
} }
async checkConnection(_: Electron.IpcMainInvokeEvent, webdavConfig: WebDavConfig) { async checkConnection(_: Electron.IpcMainInvokeEvent, webdavConfig: WebDavConfig) {
const webdavClient = new WebDav(webdavConfig) const webdavClient = this.getWebDavInstance(webdavConfig)
return await webdavClient.checkConnection() return await webdavClient.checkConnection()
} }
@ -477,13 +580,13 @@ class BackupManager {
path: string, path: string,
options?: CreateDirectoryOptions options?: CreateDirectoryOptions
) { ) {
const webdavClient = new WebDav(webdavConfig) const webdavClient = this.getWebDavInstance(webdavConfig)
return await webdavClient.createDirectory(path, options) return await webdavClient.createDirectory(path, options)
} }
async deleteWebdavFile(_: Electron.IpcMainInvokeEvent, fileName: string, webdavConfig: WebDavConfig) { async deleteWebdavFile(_: Electron.IpcMainInvokeEvent, fileName: string, webdavConfig: WebDavConfig) {
try { try {
const webdavClient = new WebDav(webdavConfig) const webdavClient = this.getWebDavInstance(webdavConfig)
return await webdavClient.deleteFile(fileName) return await webdavClient.deleteFile(fileName)
} catch (error: any) { } catch (error: any) {
logger.error('Failed to delete WebDAV file:', error) logger.error('Failed to delete WebDAV file:', error)
@ -525,7 +628,7 @@ class BackupManager {
logger.debug(`Starting S3 backup to ${filename}`) logger.debug(`Starting S3 backup to ${filename}`)
const backupedFilePath = await this.backup(_, filename, data, undefined, s3Config.skipBackupFile) const backupedFilePath = await this.backup(_, filename, data, undefined, s3Config.skipBackupFile)
const s3Client = new S3Storage(s3Config) const s3Client = this.getS3Storage(s3Config)
try { try {
const fileBuffer = await fs.promises.readFile(backupedFilePath) const fileBuffer = await fs.promises.readFile(backupedFilePath)
const result = await s3Client.putFileContents(filename, fileBuffer) const result = await s3Client.putFileContents(filename, fileBuffer)
@ -603,7 +706,7 @@ class BackupManager {
logger.debug(`Starting restore from S3: ${filename}`) logger.debug(`Starting restore from S3: ${filename}`)
const s3Client = new S3Storage(s3Config) const s3Client = this.getS3Storage(s3Config)
try { try {
const retrievedFile = await s3Client.getFileContents(filename) const retrievedFile = await s3Client.getFileContents(filename)
const backupedFilePath = path.join(this.backupDir, filename) const backupedFilePath = path.join(this.backupDir, filename)
@ -628,7 +731,7 @@ class BackupManager {
listS3Files = async (_: Electron.IpcMainInvokeEvent, s3Config: S3Config) => { listS3Files = async (_: Electron.IpcMainInvokeEvent, s3Config: S3Config) => {
try { try {
const s3Client = new S3Storage(s3Config) const s3Client = this.getS3Storage(s3Config)
const objects = await s3Client.listFiles() const objects = await s3Client.listFiles()
const files = objects const files = objects
@ -652,7 +755,7 @@ class BackupManager {
async deleteS3File(_: Electron.IpcMainInvokeEvent, fileName: string, s3Config: S3Config) { async deleteS3File(_: Electron.IpcMainInvokeEvent, fileName: string, s3Config: S3Config) {
try { try {
const s3Client = new S3Storage(s3Config) const s3Client = this.getS3Storage(s3Config)
return await s3Client.deleteFile(fileName) return await s3Client.deleteFile(fileName)
} catch (error: any) { } catch (error: any) {
logger.error('Failed to delete S3 file:', error) logger.error('Failed to delete S3 file:', error)
@ -661,7 +764,7 @@ class BackupManager {
} }
async checkS3Connection(_: Electron.IpcMainInvokeEvent, s3Config: S3Config) { async checkS3Connection(_: Electron.IpcMainInvokeEvent, s3Config: S3Config) {
const s3Client = new S3Storage(s3Config) const s3Client = this.getS3Storage(s3Config)
return await s3Client.checkConnection() return await s3Client.checkConnection()
} }
} }

View File

@ -5,6 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { AihubmixAPIClient } from '../AihubmixAPIClient' import { AihubmixAPIClient } from '../AihubmixAPIClient'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient' import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '../ApiClientFactory' import { ApiClientFactory } from '../ApiClientFactory'
import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient' import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { VertexAPIClient } from '../gemini/VertexAPIClient' import { VertexAPIClient } from '../gemini/VertexAPIClient'
import { NewAPIClient } from '../NewAPIClient' import { NewAPIClient } from '../NewAPIClient'
@ -54,6 +55,19 @@ vi.mock('../openai/OpenAIResponseAPIClient', () => ({
vi.mock('../ppio/PPIOAPIClient', () => ({ vi.mock('../ppio/PPIOAPIClient', () => ({
PPIOAPIClient: vi.fn().mockImplementation(() => ({})) PPIOAPIClient: vi.fn().mockImplementation(() => ({}))
})) }))
vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []
}
}))
describe('ApiClientFactory', () => { describe('ApiClientFactory', () => {
beforeEach(() => { beforeEach(() => {
@ -144,6 +158,15 @@ describe('ApiClientFactory', () => {
expect(client).toBeDefined() expect(client).toBeDefined()
}) })
it('should create AwsBedrockAPIClient for aws-bedrock type', () => {
const provider = createTestProvider('aws-bedrock', 'aws-bedrock')
const client = ApiClientFactory.create(provider)
expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// 测试默认情况 // 测试默认情况
it('should create OpenAIAPIClient as default for unknown type', () => { it('should create OpenAIAPIClient as default for unknown type', () => {
const provider = createTestProvider('unknown', 'unknown-type') const provider = createTestProvider('unknown', 'unknown-type')

View File

@ -2,19 +2,23 @@ import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesComman
import { import {
BedrockRuntimeClient, BedrockRuntimeClient,
ConverseCommand, ConverseCommand,
ConverseStreamCommand, InvokeModelCommand,
InvokeModelCommand InvokeModelWithResponseStreamCommand
} from '@aws-sdk/client-bedrock-runtime' } from '@aws-sdk/client-bedrock-runtime'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
import { import {
getAwsBedrockAccessKeyId, getAwsBedrockAccessKeyId,
getAwsBedrockRegion, getAwsBedrockRegion,
getAwsBedrockSecretAccessKey getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock' } from '@renderer/hooks/useAwsBedrock'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import { estimateTextTokens } from '@renderer/services/TokenService' import { estimateTextTokens } from '@renderer/services/TokenService'
import { import {
Assistant,
EFFORT_RATIO,
GenerateImageParams, GenerateImageParams,
MCPCallToolResponse, MCPCallToolResponse,
MCPTool, MCPTool,
@ -23,7 +27,13 @@ import {
Provider, Provider,
ToolCallResponse ToolCallResponse
} from '@renderer/types' } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk' import {
ChunkType,
MCPToolCreatedChunk,
TextDeltaChunk,
ThinkingDeltaChunk,
ThinkingStartChunk
} from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage' import { Message } from '@renderer/types/newMessage'
import { import {
AwsBedrockSdkInstance, AwsBedrockSdkInstance,
@ -33,6 +43,7 @@ import {
AwsBedrockSdkRawOutput, AwsBedrockSdkRawOutput,
AwsBedrockSdkTool, AwsBedrockSdkTool,
AwsBedrockSdkToolCall, AwsBedrockSdkToolCall,
AwsBedrockStreamChunk,
SdkModel SdkModel
} from '@renderer/types/sdk' } from '@renderer/types/sdk'
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils' import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
@ -103,46 +114,65 @@ export class AwsBedrockAPIClient extends BaseApiClient<
override async createCompletions(payload: AwsBedrockSdkParams): Promise<AwsBedrockSdkRawOutput> { override async createCompletions(payload: AwsBedrockSdkParams): Promise<AwsBedrockSdkRawOutput> {
const sdk = await this.getSdkInstance() const sdk = await this.getSdkInstance()
// 转换消息格式到AWS SDK原生格式 // 转换消息格式(用于 InvokeModelWithResponseStreamCommand
const awsMessages = payload.messages.map((msg) => ({ const awsMessages = payload.messages.map((msg) => ({
role: msg.role, role: msg.role,
content: msg.content.map((content) => { content: msg.content.map((content) => {
if (content.text) { if (content.text) {
return { text: content.text } return { type: 'text', text: content.text }
} }
if (content.image) { if (content.image) {
// 处理图片数据,将 Uint8Array 或数字数组转换为 base64 字符串
let base64Data = ''
if (content.image.source.bytes) {
if (typeof content.image.source.bytes === 'string') {
// 如果已经是字符串,直接使用
base64Data = content.image.source.bytes
} else {
// 如果是数组或 Uint8Array转换为 base64
const uint8Array = new Uint8Array(Object.values(content.image.source.bytes))
const binaryString = Array.from(uint8Array)
.map((byte) => String.fromCharCode(byte))
.join('')
base64Data = btoa(binaryString)
}
}
return { return {
image: { type: 'image',
format: content.image.format, source: {
source: content.image.source type: 'base64',
media_type: `image/${content.image.format}`,
data: base64Data
} }
} }
} }
if (content.toolResult) { if (content.toolResult) {
return { return {
toolResult: { type: 'tool_result',
toolUseId: content.toolResult.toolUseId, tool_use_id: content.toolResult.toolUseId,
content: content.toolResult.content, content: content.toolResult.content
status: content.toolResult.status
}
} }
} }
if (content.toolUse) { if (content.toolUse) {
return { return {
toolUse: { type: 'tool_use',
toolUseId: content.toolUse.toolUseId, id: content.toolUse.toolUseId,
name: content.toolUse.name, name: content.toolUse.name,
input: content.toolUse.input input: content.toolUse.input
}
} }
} }
// 返回符合AWS SDK ContentBlock类型的对象 return { type: 'text', text: 'Unknown content type' }
return { text: 'Unknown content type' }
}) })
})) }))
logger.info('Creating completions with model ID:', { modelId: payload.modelId }) logger.info('Creating completions with model ID:', { modelId: payload.modelId })
const excludeKeys = ['modelId', 'messages', 'system', 'maxTokens', 'temperature', 'topP', 'stream', 'tools']
const additionalParams = Object.keys(payload)
.filter((key) => !excludeKeys.includes(key))
.reduce((acc, key) => ({ ...acc, [key]: payload[key] }), {})
const commonParams = { const commonParams = {
modelId: payload.modelId, modelId: payload.modelId,
messages: awsMessages as any, messages: awsMessages as any,
@ -162,10 +192,18 @@ export class AwsBedrockAPIClient extends BaseApiClient<
try { try {
if (payload.stream) { if (payload.stream) {
const command = new ConverseStreamCommand(commonParams) // 根据模型类型选择正确的 API 格式
const requestBody = this.createRequestBodyForModel(commonParams, additionalParams)
const command = new InvokeModelWithResponseStreamCommand({
modelId: commonParams.modelId,
body: JSON.stringify(requestBody),
contentType: 'application/json',
accept: 'application/json'
})
const response = await sdk.client.send(command) const response = await sdk.client.send(command)
// 直接返回AWS Bedrock流式响应的异步迭代器 return this.createInvokeModelStreamIterator(response)
return this.createStreamIterator(response)
} else { } else {
const command = new ConverseCommand(commonParams) const command = new ConverseCommand(commonParams)
const response = await sdk.client.send(command) const response = await sdk.client.send(command)
@ -177,32 +215,236 @@ export class AwsBedrockAPIClient extends BaseApiClient<
} }
} }
private async *createStreamIterator(response: any): AsyncIterable<AwsBedrockSdkRawChunk> { /**
try { *
if (response.stream) { */
for await (const chunk of response.stream) { private createRequestBodyForModel(commonParams: any, additionalParams: any): any {
logger.debug('AWS Bedrock chunk received:', chunk) const modelId = commonParams.modelId.toLowerCase()
// AWS Bedrock的流式响应格式转换为标准格式 // Claude 系列模型使用 Anthropic API 格式
if (chunk.contentBlockDelta?.delta?.text) { if (modelId.includes('claude')) {
yield { return {
contentBlockDelta: { anthropic_version: 'bedrock-2023-05-31',
delta: { text: chunk.contentBlockDelta.delta.text } max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP,
messages: commonParams.messages,
...(commonParams.system && commonParams.system[0]?.text ? { system: commonParams.system[0].text } : {}),
...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {}),
...additionalParams
}
}
// OpenAI 系列模型
if (modelId.includes('gpt') || modelId.includes('openai')) {
const messages: any[] = []
// 添加系统消息
if (commonParams.system && commonParams.system[0]?.text) {
messages.push({
role: 'system',
content: commonParams.system[0].text
})
}
// 转换消息格式
for (const message of commonParams.messages) {
const content: any[] = []
for (const part of message.content) {
if (part.text) {
content.push({ type: 'text', text: part.text })
} else if (part.image) {
content.push({
type: 'image_url',
image_url: {
url: `data:image/${part.image.format};base64,${part.image.source.bytes}`
}
})
}
}
messages.push({
role: message.role,
content: content.length === 1 && content[0].type === 'text' ? content[0].text : content
})
}
const baseBody: any = {
model: commonParams.modelId,
messages: messages,
max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP,
stream: true,
...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {})
}
// OpenAI 模型的 thinking 参数格式
if (additionalParams.reasoning_effort) {
baseBody.reasoning_effort = additionalParams.reasoning_effort
delete additionalParams.reasoning_effort
}
return {
...baseBody,
...additionalParams
}
}
// Llama 系列模型
if (modelId.includes('llama')) {
const baseBody: any = {
prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
max_gen_len: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP
}
// Llama 模型的 thinking 参数格式
if (additionalParams.thinking_mode) {
baseBody.thinking_mode = additionalParams.thinking_mode
delete additionalParams.thinking_mode
}
return {
...baseBody,
...additionalParams
}
}
// Amazon Titan 系列模型
if (modelId.includes('titan')) {
const textGenerationConfig: any = {
maxTokenCount: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
topP: commonParams.inferenceConfig.topP
}
// 将 thinking 相关参数添加到 textGenerationConfig 中
if (additionalParams.thinking) {
textGenerationConfig.thinking = additionalParams.thinking
delete additionalParams.thinking
}
return {
inputText: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
textGenerationConfig: {
...textGenerationConfig,
...Object.keys(additionalParams).reduce((acc, key) => {
if (['thinking_tokens', 'reasoning_mode'].includes(key)) {
acc[key] = additionalParams[key]
delete additionalParams[key]
}
return acc
}, {} as any)
},
...additionalParams
}
}
// Cohere Command 系列模型
if (modelId.includes('cohere') || modelId.includes('command')) {
const baseBody: any = {
message: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
p: commonParams.inferenceConfig.topP
}
// Cohere 模型的 thinking 参数格式
if (additionalParams.thinking) {
baseBody.thinking = additionalParams.thinking
delete additionalParams.thinking
}
if (additionalParams.reasoning_tokens) {
baseBody.reasoning_tokens = additionalParams.reasoning_tokens
delete additionalParams.reasoning_tokens
}
return {
...baseBody,
...additionalParams
}
}
// 默认使用通用格式
const baseBody: any = {
prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP
}
return {
...baseBody,
...additionalParams
}
}
/**
* prompt
*/
private convertMessagesToPrompt(messages: any[], system?: any[]): string {
let prompt = ''
// 添加系统消息
if (system && system[0]?.text) {
prompt += `System: ${system[0].text}\n\n`
}
// 添加对话消息
for (const message of messages) {
const role = message.role === 'assistant' ? 'Assistant' : 'Human'
let content = ''
for (const part of message.content) {
if (part.text) {
content += part.text
} else if (part.image) {
content += '[Image]'
}
}
prompt += `${role}: ${content}\n\n`
}
prompt += 'Assistant:'
return prompt
}
private async *createInvokeModelStreamIterator(response: any): AsyncIterable<AwsBedrockSdkRawChunk> {
try {
if (response.body) {
for await (const event of response.body) {
if (event.chunk) {
const chunk: AwsBedrockStreamChunk = JSON.parse(new TextDecoder().decode(event.chunk.bytes))
// 转换为标准格式
if (chunk.type === 'content_block_delta') {
yield {
contentBlockDelta: {
delta: chunk.delta,
contentBlockIndex: chunk.index
}
}
} else if (chunk.type === 'message_start') {
yield { messageStart: chunk }
} else if (chunk.type === 'message_stop') {
yield { messageStop: chunk }
} else if (chunk.type === 'content_block_start') {
yield {
contentBlockStart: {
start: chunk.content_block,
contentBlockIndex: chunk.index
}
}
} else if (chunk.type === 'content_block_stop') {
yield {
contentBlockStop: {
contentBlockIndex: chunk.index
}
} }
} }
} }
if (chunk.messageStart) {
yield { messageStart: chunk.messageStart }
}
if (chunk.messageStop) {
yield { messageStop: chunk.messageStop }
}
if (chunk.metadata) {
yield { metadata: chunk.metadata }
}
} }
} }
} catch (error) { } catch (error) {
@ -485,6 +727,38 @@ export class AwsBedrockAPIClient extends BaseApiClient<
} }
} }
// 获取推理预算token对所有支持推理的模型
const budgetTokens = this.getBudgetToken(assistant, model)
// 构建基础自定义参数
const customParams: Record<string, any> =
coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}
// 根据模型类型添加 thinking 参数
if (budgetTokens) {
const modelId = model.id.toLowerCase()
if (modelId.includes('claude')) {
// Claude 模型使用 Anthropic 格式
customParams.thinking = { type: 'enabled', budget_tokens: budgetTokens }
} else if (modelId.includes('gpt') || modelId.includes('openai')) {
// OpenAI 模型格式
customParams.reasoning_effort = assistant?.settings?.reasoning_effort
} else if (modelId.includes('llama')) {
// Llama 模型格式
customParams.thinking_mode = true
customParams.thinking_tokens = budgetTokens
} else if (modelId.includes('titan')) {
// Titan 模型格式
customParams.thinking = { enabled: true }
customParams.thinking_tokens = budgetTokens
} else if (modelId.includes('cohere') || modelId.includes('command')) {
// Cohere 模型格式
customParams.thinking = { enabled: true }
customParams.reasoning_tokens = budgetTokens
}
}
const payload: AwsBedrockSdkParams = { const payload: AwsBedrockSdkParams = {
modelId: model.id, modelId: model.id,
messages: messages:
@ -497,9 +771,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
topP: this.getTopP(assistant, model), topP: this.getTopP(assistant, model),
stream: streamOutput !== false, stream: streamOutput !== false,
tools: tools.length > 0 ? tools : undefined, tools: tools.length > 0 ? tools : undefined,
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 ...customParams
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
} }
const timeout = this.getTimeout(model) const timeout = this.getTimeout(model)
@ -511,6 +783,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
getResponseChunkTransformer(): ResponseChunkTransformer<AwsBedrockSdkRawChunk> { getResponseChunkTransformer(): ResponseChunkTransformer<AwsBedrockSdkRawChunk> {
return () => { return () => {
let hasStartedText = false let hasStartedText = false
let hasStartedThinking = false
let accumulatedJson = '' let accumulatedJson = ''
const toolCalls: Record<number, AwsBedrockSdkToolCall> = {} const toolCalls: Record<number, AwsBedrockSdkToolCall> = {}
@ -570,6 +843,24 @@ export class AwsBedrockAPIClient extends BaseApiClient<
} as TextDeltaChunk) } as TextDeltaChunk)
} }
// 处理thinking增量
if (
rawChunk.contentBlockDelta?.delta?.type === 'thinking_delta' &&
rawChunk.contentBlockDelta?.delta?.thinking
) {
if (!hasStartedThinking) {
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
hasStartedThinking = true
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: rawChunk.contentBlockDelta.delta.thinking
} as ThinkingDeltaChunk)
}
// 处理内容块停止事件 - 参考 Anthropic 的 content_block_stop 处理 // 处理内容块停止事件 - 参考 Anthropic 的 content_block_stop 处理
if (rawChunk.contentBlockStop) { if (rawChunk.contentBlockStop) {
const blockIndex = rawChunk.contentBlockStop.contentBlockIndex || 0 const blockIndex = rawChunk.contentBlockStop.contentBlockIndex || 0
@ -708,4 +999,49 @@ export class AwsBedrockAPIClient extends BaseApiClient<
extractMessagesFromSdkPayload(sdkPayload: AwsBedrockSdkParams): AwsBedrockSdkMessageParam[] { extractMessagesFromSdkPayload(sdkPayload: AwsBedrockSdkParams): AwsBedrockSdkMessageParam[] {
return sdkPayload.messages || [] return sdkPayload.messages || []
} }
/**
* AWS Bedrock token
* @param assistant - The assistant
* @param model - The model
* @returns The budget tokens for reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model): number | undefined {
try {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
return undefined
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimits = findTokenLimit(model.id)
if (tokenLimits) {
// 使用模型特定的 token 限制
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(tokenLimits.max - tokenLimits.min) * effortRatio + tokenLimits.min,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return budgetTokens
} else {
// 对于没有特定限制的模型,使用简化计算
const budgetTokens = Math.max(1024, Math.floor((maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
return budgetTokens
}
} catch (error) {
logger.warn('Failed to calculate budget tokens for reasoning effort:', error as Error)
return undefined
}
}
} }

View File

@ -112,3 +112,129 @@ export function MdiLightbulbOn(props: SVGProps<SVGSVGElement>) {
</svg> </svg>
) )
} }
export function BingLogo(props: SVGProps<SVGSVGElement>) {
return (
<svg
fill="currentColor"
fill-rule="evenodd"
width="1em"
height="1em"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
{...props}>
<path d="M4.842.005a.966.966 0 01.604.142l2.62 1.813c.369.256.492.352.637.496.471.47.752 1.09.797 1.765l.008.847.003 1.441.004 13.002.144-.094 7.015-4.353.015.003.029.01c-.398-.17-.893-.339-1.655-.566l-.484-.146c-.584-.18-.71-.238-.921-.38a2.009 2.009 0 01-.37-.312 2.172 2.172 0 01-.41-.592L11.32 9.063c-.166-.444-.166-.49-.156-.63a.92.92 0 01.806-.864l.094-.01c.044-.005.22.023.29.044l.052.021c.06.026.16.075.313.154l3.63 1.908a6.626 6.626 0 013.292 4.531c.194.99.159 2.037-.102 3.012-.216.805-.639 1.694-1.054 2.213l-.08.099-.047.05c-.01.01-.013.01-.01.002l.043-.074-.072.114c-.011.031-.233.28-.38.425l-.17.161c-.22.202-.431.36-.832.62L13.544 23c-.941.6-1.86.912-2.913.992-.23.018-.854.008-1.074-.017a6.31 6.31 0 01-1.658-.412c-1.854-.738-3.223-2.288-3.705-4.195a8.077 8.077 0 01-.121-.57l-.046-.325a1.123 1.123 0 01-.014-.168l-.006-.029L4 11.617 4.01.866a.981.981 0 01.007-.111.943.943 0 01.825-.75z"></path>
</svg>
)
}
export function SearXNGLogo(props: SVGProps<SVGSVGElement>) {
return (
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 265 265" style={{ display: 'block' }} {...props}>
<g transform="translate(-40.921 -17.417)">
<circle
cx="142.2"
cy="122.9"
r="85"
fill="none"
stroke="currentColor"
strokeWidth="28.3465"
strokeLinecap="round"
strokeLinejoin="round"
strokeMiterlimit="11.3386"
/>
<path
d="M118.4 77.6c19.8-10.2 44-6.4 59.7 9.4s19.3 40 8.9 59.7"
fill="none"
stroke="currentColor"
strokeWidth="14.1732"
strokeLinecap="round"
strokeLinejoin="round"
strokeMiterlimit="11.3386"
/>
<path d="m184.2 202 37-38.6 81.8 78.3-37 38.6z" fill="currentColor" />
</g>
</svg>
)
}
export function TavilyLogo(props: SVGProps<SVGSVGElement>) {
return (
<svg width="42" height="42" viewBox="0 0 42 42" fill="none" xmlns="http://www.w3.org/2000/svg" {...props}>
<path
d="m16.44.964 4.921 7.79c.79 1.252-.108 2.883-1.588 2.883H17.76V23.3h-2.91V.088c.61 0 1.22.292 1.59.876z"
fill="currentColor"
/>
<path
d="M8.342 8.755 13.263.964a1.864 1.864 0 0 1 1.59-.876V23.3a4.87 4.87 0 0 0-.252-.006c-.99 0-1.907.311-2.658.842V11.637H9.93c-1.48 0-2.38-1.631-1.589-2.882z"
fill="currentColor"
/>
<path
d="M30.278 31H18.031a4.596 4.596 0 0 0 1.219-2.91h22.577c0 .61-.292 1.22-.875 1.59L33.16 34.6c-1.251.791-2.883-.108-2.883-1.588V31z"
fill="currentColor"
/>
<path
d="m33.16 21.581 7.79 4.921c.585.369.876.979.876 1.589H19.25a4.619 4.619 0 0 0-.858-2.91h11.887V23.17c0-1.48 1.631-2.38 2.882-1.589z"
fill="currentColor"
/>
<path
d="m8.24 34.25-7.107 7.108a1.864 1.864 0 0 0 1.742.504l8.989-2.03c1.443-.325 1.961-2.114.915-3.16l-1.423-1.423 5.356-5.356a2.805 2.805 0 0 0 0-3.966l-.074-.075L8.24 34.25z"
fill="currentColor"
/>
<path
d="m7.243 31.135 5.355-5.356a2.805 2.805 0 0 1 3.967 0l.074.074-8.397 8.397-7.108 7.108a1.864 1.864 0 0 1-.504-1.742l2.029-8.989c.325-1.444 2.115-1.961 3.161-.915l1.423 1.423z"
fill="currentColor"
/>
</svg>
)
}
export function ExaLogo(props: SVGProps<SVGSVGElement>) {
return (
<svg
fill="currentColor"
fill-rule="evenodd"
width="1em"
height="1em"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
{...props}>
<title>Exa</title>
<path
clip-rule="evenodd"
d="M3 0h19v1.791L13.892 12 22 22.209V24H3V0zm9.62 10.348l6.589-8.557H6.03l6.59 8.557zM5.138 3.935v7.17h5.52l-5.52-7.17zm5.52 8.96h-5.52v7.17l5.52-7.17zM6.03 22.21l6.59-8.557 6.589 8.557H6.03z"></path>
</svg>
)
}
export function BochaLogo(props: SVGProps<SVGSVGElement>) {
return (
<svg width="1em" height="1em" viewBox="0 0 135 116" fill="none" xmlns="http://www.w3.org/2000/svg" {...props}>
<path
fill-rule="evenodd"
clip-rule="evenodd"
d="M12.5754 13.8123C24.6109 7.94459 39.1223 12.9435 44.9955 24.9805L57.5355 50.6805C60.4695 56.6936 57.9756 63.9478 51.9652 66.8832C51.9627 66.8844 51.9602 66.8856 51.9577 66.8868C45.94 69.8206 38.6843 67.3212 35.7477 61.3027L12.5754 13.8123Z"
fill="currentColor"
/>
<path
opacity="0.64774"
fill-rule="evenodd"
clip-rule="evenodd"
d="M0 38.3013C9.46916 28.836 24.813 28.836 34.2822 38.3013L55.2526 59.2631C59.9819 63.9904 59.9852 71.6582 55.2601 76.3896C55.2576 76.3921 55.2551 76.3946 55.2526 76.397C50.5181 81.1297 42.8461 81.1297 38.1116 76.397L0 38.3013Z"
fill="currentColor"
/>
<path
fill-rule="evenodd"
clip-rule="evenodd"
d="M86.8777 18.0444C113.939 18.0444 135.876 39.9725 135.876 67.0222C135.876 80.2286 129.086 93.6477 120.585 102.457L117.065 98.2367C111.026 90.9998 108.882 81.2777 111.314 72.1702C111.755 70.5198 111.976 69.0033 111.976 67.6209C111.976 53.6689 100.661 42.3586 86.7029 42.3586C72.7452 42.3586 61.4303 53.6689 61.4303 67.6209C61.4303 81.5728 72.7452 92.8831 86.7029 92.8831C89.3159 92.8831 91.8363 92.4867 94.2071 91.7508C101.312 89.5455 109.054 91.3768 114.419 96.5322L120.585 102.457C111.83 110.626 99.7992 116 86.8777 116C59.8168 116 37.8796 94.0719 37.8796 67.0222C37.8796 39.9725 59.8168 18.0444 86.8777 18.0444Z"
fill="currentColor"
/>
<path
fill-rule="evenodd"
clip-rule="evenodd"
d="M37.8796 0C51.2677 0 62.1208 10.8581 62.1208 24.2522V41.7389C62.1208 55.133 51.2677 65.9911 37.8796 65.9911V0Z"
fill="currentColor"
/>
</svg>
)
}

View File

@ -3,7 +3,14 @@ import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import SelectProviderModelPopup from '@renderer/pages/settings/ProviderSettings/SelectProviderModelPopup' import SelectProviderModelPopup from '@renderer/pages/settings/ProviderSettings/SelectProviderModelPopup'
import { checkApi } from '@renderer/services/ApiService' import { checkApi } from '@renderer/services/ApiService'
import WebSearchService from '@renderer/services/WebSearchService' import WebSearchService from '@renderer/services/WebSearchService'
import { Model, PreprocessProvider, Provider, WebSearchProvider } from '@renderer/types' import {
isPreprocessProviderId,
isWebSearchProviderId,
Model,
PreprocessProvider,
Provider,
WebSearchProvider
} from '@renderer/types'
import { ApiKeyConnectivity, ApiKeyWithStatus, HealthStatus } from '@renderer/types/healthCheck' import { ApiKeyConnectivity, ApiKeyWithStatus, HealthStatus } from '@renderer/types/healthCheck'
import { formatApiKeys, splitApiKeyString } from '@renderer/utils/api' import { formatApiKeys, splitApiKeyString } from '@renderer/utils/api'
import { formatErrorMessage } from '@renderer/utils/error' import { formatErrorMessage } from '@renderer/utils/error'
@ -12,12 +19,11 @@ import { isEmpty } from 'lodash'
import { useCallback, useMemo, useState } from 'react' import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { ApiKeyValidity, ApiProviderKind, ApiProviderUnion } from './types' import { ApiKeyValidity, ApiProvider, UpdateApiProviderFunc } from './types'
interface UseApiKeysProps { interface UseApiKeysProps {
provider: ApiProviderUnion provider: ApiProvider
updateProvider: (provider: Partial<ApiProviderUnion>) => void updateProvider: UpdateApiProviderFunc
providerKind: ApiProviderKind
} }
const logger = loggerService.withContext('ApiKeyListPopup') const logger = loggerService.withContext('ApiKeyListPopup')
@ -25,7 +31,7 @@ const logger = loggerService.withContext('ApiKeyListPopup')
/** /**
* API Keys hook * API Keys hook
*/ */
export function useApiKeys({ provider, updateProvider, providerKind }: UseApiKeysProps) { export function useApiKeys({ provider, updateProvider }: UseApiKeysProps) {
const { t } = useTranslation() const { t } = useTranslation()
// 连通性检查的 UI 状态管理 // 连通性检查的 UI 状态管理
@ -199,11 +205,13 @@ export function useApiKeys({ provider, updateProvider, providerKind }: UseApiKey
try { try {
const startTime = Date.now() const startTime = Date.now()
if (isLlmProvider(provider, providerKind) && model) { if (isLlmProvider(provider) && model) {
await checkApi({ ...provider, apiKey: keyToCheck }, model) await checkApi({ ...provider, apiKey: keyToCheck }, model)
} else { } else if (isWebSearchProvider(provider)) {
const result = await WebSearchService.checkSearch({ ...provider, apiKey: keyToCheck }) const result = await WebSearchService.checkSearch({ ...provider, apiKey: keyToCheck })
if (!result.valid) throw new Error(result.error) if (!result.valid) throw new Error(result.error)
} else {
// 不处理预处理供应商
} }
const latency = Date.now() - startTime const latency = Date.now() - startTime
@ -228,7 +236,7 @@ export function useApiKeys({ provider, updateProvider, providerKind }: UseApiKey
logger.error('failed to validate the connectivity of the api key', error) logger.error('failed to validate the connectivity of the api key', error)
} }
}, },
[keys, connectivityStates, updateConnectivityState, provider, providerKind] [keys, connectivityStates, updateConnectivityState, provider]
) )
// 检查单个 key 的连通性 // 检查单个 key 的连通性
@ -240,23 +248,23 @@ export function useApiKeys({ provider, updateProvider, providerKind }: UseApiKey
const currentState = connectivityStates.get(keyToCheck) const currentState = connectivityStates.get(keyToCheck)
if (currentState?.checking) return if (currentState?.checking) return
const model = isLlmProvider(provider, providerKind) ? await getModelForCheck(provider, t) : undefined const model = isLlmProvider(provider) ? await getModelForCheck(provider, t) : undefined
if (model === null) return if (model === null) return
await runConnectivityCheck(index, model) await runConnectivityCheck(index, model)
}, },
[provider, keys, connectivityStates, providerKind, t, runConnectivityCheck] [provider, keys, connectivityStates, t, runConnectivityCheck]
) )
// 检查所有 keys 的连通性 // 检查所有 keys 的连通性
const checkAllKeysConnectivity = useCallback(async () => { const checkAllKeysConnectivity = useCallback(async () => {
if (!provider || keys.length === 0) return if (!provider || keys.length === 0) return
const model = isLlmProvider(provider, providerKind) ? await getModelForCheck(provider, t) : undefined const model = isLlmProvider(provider) ? await getModelForCheck(provider, t) : undefined
if (model === null) return if (model === null) return
await Promise.allSettled(keys.map((_, index) => runConnectivityCheck(index, model))) await Promise.allSettled(keys.map((_, index) => runConnectivityCheck(index, model)))
}, [provider, keys, providerKind, t, runConnectivityCheck]) }, [provider, keys, t, runConnectivityCheck])
// 计算是否有 key 正在检查 // 计算是否有 key 正在检查
const isChecking = useMemo(() => { const isChecking = useMemo(() => {
@ -275,16 +283,18 @@ export function useApiKeys({ provider, updateProvider, providerKind }: UseApiKey
} }
} }
export function isLlmProvider(obj: any, kind: ApiProviderKind): obj is Provider { export function isLlmProvider(provider: ApiProvider): provider is Provider {
return kind === 'llm' && 'type' in obj && 'models' in obj return 'models' in provider
} }
export function isWebSearchProvider(obj: any, kind: ApiProviderKind): obj is WebSearchProvider { export function isWebSearchProvider(provider: ApiProvider): provider is WebSearchProvider {
return kind === 'websearch' && ('url' in obj || 'engines' in obj) return isWebSearchProviderId(provider.id)
} }
export function isPreprocessProvider(obj: any, kind: ApiProviderKind): obj is PreprocessProvider { export function isPreprocessProvider(provider: ApiProvider): provider is PreprocessProvider {
return kind === 'doc-preprocess' && ('quota' in obj || 'options' in obj) // NOTE: mistral 同时提供预处理和llm服务所以其llm provier可能被误判为预处理provider
// 后面需要使用更严格的判断方式
return isPreprocessProviderId(provider.id) && !isLlmProvider(provider)
} }
// 获取模型用于检查 // 获取模型用于检查

View File

@ -6,6 +6,7 @@ import { useProvider } from '@renderer/hooks/useProvider'
import { useWebSearchProvider } from '@renderer/hooks/useWebSearchProviders' import { useWebSearchProvider } from '@renderer/hooks/useWebSearchProviders'
import { SettingHelpText } from '@renderer/pages/settings' import { SettingHelpText } from '@renderer/pages/settings'
import { isProviderSupportAuth } from '@renderer/services/ProviderService' import { isProviderSupportAuth } from '@renderer/services/ProviderService'
import { PreprocessProviderId, WebSearchProviderId } from '@renderer/types'
import { ApiKeyWithStatus, HealthStatus } from '@renderer/types/healthCheck' import { ApiKeyWithStatus, HealthStatus } from '@renderer/types/healthCheck'
import { Button, Card, Flex, List, Popconfirm, Space, Tooltip, Typography } from 'antd' import { Button, Card, Flex, List, Popconfirm, Space, Tooltip, Typography } from 'antd'
import { Plus } from 'lucide-react' import { Plus } from 'lucide-react'
@ -15,19 +16,18 @@ import styled from 'styled-components'
import { isLlmProvider, useApiKeys } from './hook' import { isLlmProvider, useApiKeys } from './hook'
import ApiKeyItem from './item' import ApiKeyItem from './item'
import { ApiProviderKind, ApiProviderUnion } from './types' import { ApiProvider, UpdateApiProviderFunc } from './types'
interface ApiKeyListProps { interface ApiKeyListProps {
provider: ApiProviderUnion provider: ApiProvider
updateProvider: (provider: Partial<ApiProviderUnion>) => void updateProvider: UpdateApiProviderFunc
providerKind: ApiProviderKind
showHealthCheck?: boolean showHealthCheck?: boolean
} }
/** /**
* Api key CRUD * Api key CRUD
*/ */
export const ApiKeyList: FC<ApiKeyListProps> = ({ provider, updateProvider, providerKind, showHealthCheck = true }) => { export const ApiKeyList: FC<ApiKeyListProps> = ({ provider, updateProvider, showHealthCheck = true }) => {
const { t } = useTranslation() const { t } = useTranslation()
// 临时新项状态 // 临时新项状态
@ -42,7 +42,7 @@ export const ApiKeyList: FC<ApiKeyListProps> = ({ provider, updateProvider, prov
checkKeyConnectivity, checkKeyConnectivity,
checkAllKeysConnectivity, checkAllKeysConnectivity,
isChecking isChecking
} = useApiKeys({ provider, updateProvider, providerKind: providerKind }) } = useApiKeys({ provider, updateProvider })
// 创建一个临时新项 // 创建一个临时新项
const handleAddNew = () => { const handleAddNew = () => {
@ -73,7 +73,7 @@ export const ApiKeyList: FC<ApiKeyListProps> = ({ provider, updateProvider, prov
const shouldAutoFocus = () => { const shouldAutoFocus = () => {
if (provider.apiKey) return false if (provider.apiKey) return false
return isLlmProvider(provider, providerKind) && provider.enabled && !isProviderSupportAuth(provider) return isLlmProvider(provider) && provider.enabled && !isProviderSupportAuth(provider)
} }
// 合并真实 keys 和临时新项 // 合并真实 keys 和临时新项
@ -179,55 +179,33 @@ export const ApiKeyList: FC<ApiKeyListProps> = ({ provider, updateProvider, prov
interface SpecificApiKeyListProps { interface SpecificApiKeyListProps {
providerId: string providerId: string
providerKind: ApiProviderKind
showHealthCheck?: boolean showHealthCheck?: boolean
} }
export const LlmApiKeyList: FC<SpecificApiKeyListProps> = ({ providerId, providerKind, showHealthCheck = true }) => { type WebSearchApiKeyList = SpecificApiKeyListProps & {
providerId: WebSearchProviderId
}
type DocPreprocessApiKeyListProps = SpecificApiKeyListProps & {
providerId: PreprocessProviderId
}
export const LlmApiKeyList: FC<SpecificApiKeyListProps> = ({ providerId, showHealthCheck = true }) => {
const { provider, updateProvider } = useProvider(providerId) const { provider, updateProvider } = useProvider(providerId)
return ( return <ApiKeyList provider={provider} updateProvider={updateProvider} showHealthCheck={showHealthCheck} />
<ApiKeyList
provider={provider}
updateProvider={updateProvider}
providerKind={providerKind}
showHealthCheck={showHealthCheck}
/>
)
} }
export const WebSearchApiKeyList: FC<SpecificApiKeyListProps> = ({ export const WebSearchApiKeyList: FC<WebSearchApiKeyList> = ({ providerId, showHealthCheck = true }) => {
providerId,
providerKind,
showHealthCheck = true
}) => {
const { provider, updateProvider } = useWebSearchProvider(providerId) const { provider, updateProvider } = useWebSearchProvider(providerId)
return ( return <ApiKeyList provider={provider} updateProvider={updateProvider} showHealthCheck={showHealthCheck} />
<ApiKeyList
provider={provider}
updateProvider={updateProvider}
providerKind={providerKind}
showHealthCheck={showHealthCheck}
/>
)
} }
export const DocPreprocessApiKeyList: FC<SpecificApiKeyListProps> = ({ export const DocPreprocessApiKeyList: FC<DocPreprocessApiKeyListProps> = ({ providerId, showHealthCheck = true }) => {
providerId,
providerKind,
showHealthCheck = true
}) => {
const { provider, updateProvider } = usePreprocessProvider(providerId) const { provider, updateProvider } = usePreprocessProvider(providerId)
return ( return <ApiKeyList provider={provider} updateProvider={updateProvider} showHealthCheck={showHealthCheck} />
<ApiKeyList
provider={provider}
updateProvider={updateProvider}
providerKind={providerKind}
showHealthCheck={showHealthCheck}
/>
)
} }
const ListContainer = styled.div` const ListContainer = styled.div`

View File

@ -1,14 +1,13 @@
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { isPreprocessProviderId, isWebSearchProviderId } from '@renderer/types'
import { Modal } from 'antd' import { Modal } from 'antd'
import { useMemo, useState } from 'react' import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { DocPreprocessApiKeyList, LlmApiKeyList, WebSearchApiKeyList } from './list' import { DocPreprocessApiKeyList, LlmApiKeyList, WebSearchApiKeyList } from './list'
import { ApiProviderKind } from './types'
interface ShowParams { interface ShowParams {
providerId: string providerId: string
providerKind: ApiProviderKind
title?: string title?: string
showHealthCheck?: boolean showHealthCheck?: boolean
} }
@ -20,7 +19,7 @@ interface Props extends ShowParams {
/** /**
* API Key * API Key
*/ */
const PopupContainer: React.FC<Props> = ({ providerId, providerKind, title, resolve, showHealthCheck = true }) => { const PopupContainer: React.FC<Props> = ({ providerId, title, resolve, showHealthCheck = true }) => {
const [open, setOpen] = useState(true) const [open, setOpen] = useState(true)
const { t } = useTranslation() const { t } = useTranslation()
@ -33,17 +32,14 @@ const PopupContainer: React.FC<Props> = ({ providerId, providerKind, title, reso
} }
const ListComponent = useMemo(() => { const ListComponent = useMemo(() => {
switch (providerKind) { if (isWebSearchProviderId(providerId)) {
case 'llm': return <WebSearchApiKeyList providerId={providerId} showHealthCheck={showHealthCheck} />
return LlmApiKeyList
case 'websearch':
return WebSearchApiKeyList
case 'doc-preprocess':
return DocPreprocessApiKeyList
default:
return null
} }
}, [providerKind]) if (isPreprocessProviderId(providerId)) {
return <DocPreprocessApiKeyList providerId={providerId} showHealthCheck={showHealthCheck} />
}
return <LlmApiKeyList providerId={providerId} showHealthCheck={showHealthCheck} />
}, [providerId, showHealthCheck])
return ( return (
<Modal <Modal
@ -55,9 +51,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, providerKind, title, reso
centered centered
width={600} width={600}
footer={null}> footer={null}>
{ListComponent && ( {ListComponent}
<ListComponent providerId={providerId} providerKind={providerKind} showHealthCheck={showHealthCheck} />
)}
</Modal> </Modal>
) )
} }

View File

@ -8,6 +8,12 @@ export type ApiKeyValidity = {
error?: string error?: string
} }
export type ApiProviderUnion = Provider | WebSearchProvider | PreprocessProvider export type ApiProvider = Provider | WebSearchProvider | PreprocessProvider
export type ApiProviderKind = 'llm' | 'websearch' | 'doc-preprocess' export type UpdateProviderFunc = (p: Partial<Provider>) => void
export type UpdateWebSearchProviderFunc = (p: Partial<WebSearchProvider>) => void
export type UpdatePreprocessProviderFunc = (p: Partial<PreprocessProvider>) => void
export type UpdateApiProviderFunc = UpdateProviderFunc | UpdateWebSearchProviderFunc | UpdatePreprocessProviderFunc

View File

@ -2721,7 +2721,7 @@ export function isSupportedThinkingTokenDoubaoModel(model?: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/') const modelId = getLowerBaseModelName(model.id, '/')
return DOUBAO_THINKING_MODEL_REGEX.test(modelId) || DOUBAO_THINKING_MODEL_REGEX.test(modelId) return DOUBAO_THINKING_MODEL_REGEX.test(modelId) || DOUBAO_THINKING_MODEL_REGEX.test(model.name)
} }
export function isClaudeReasoningModel(model?: Model): boolean { export function isClaudeReasoningModel(model?: Model): boolean {

View File

@ -1,8 +1,9 @@
import Doc2xLogo from '@renderer/assets/images/ocr/doc2x.png' import Doc2xLogo from '@renderer/assets/images/ocr/doc2x.png'
import MinerULogo from '@renderer/assets/images/ocr/mineru.jpg' import MinerULogo from '@renderer/assets/images/ocr/mineru.jpg'
import MistralLogo from '@renderer/assets/images/providers/mistral.png' import MistralLogo from '@renderer/assets/images/providers/mistral.png'
import { PreprocessProviderId } from '@renderer/types'
export function getPreprocessProviderLogo(providerId: string) { export function getPreprocessProviderLogo(providerId: PreprocessProviderId) {
switch (providerId) { switch (providerId) {
case 'doc2x': case 'doc2x':
return Doc2xLogo return Doc2xLogo
@ -15,7 +16,9 @@ export function getPreprocessProviderLogo(providerId: string) {
} }
} }
export const PREPROCESS_PROVIDER_CONFIG = { type PreprocessProviderConfig = { websites: { official: string; apiKey: string } }
export const PREPROCESS_PROVIDER_CONFIG: Record<PreprocessProviderId, PreprocessProviderConfig> = {
doc2x: { doc2x: {
websites: { websites: {
official: 'https://doc2x.noedgeai.com', official: 'https://doc2x.noedgeai.com',

View File

@ -1,24 +1,13 @@
import BochaLogo from '@renderer/assets/images/search/bocha.webp' import { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
import ExaLogo from '@renderer/assets/images/search/exa.png'
import SearxngLogo from '@renderer/assets/images/search/searxng.svg'
import TavilyLogo from '@renderer/assets/images/search/tavily.png'
export function getWebSearchProviderLogo(providerId: string) { type WebSearchProviderConfig = {
switch (providerId) { websites: {
case 'tavily': official: string
return TavilyLogo apiKey?: string
case 'searxng':
return SearxngLogo
case 'exa':
return ExaLogo
case 'bocha':
return BochaLogo
default:
return undefined
} }
} }
export const WEB_SEARCH_PROVIDER_CONFIG = { export const WEB_SEARCH_PROVIDER_CONFIG: Record<WebSearchProviderId, WebSearchProviderConfig> = {
tavily: { tavily: {
websites: { websites: {
official: 'https://tavily.com', official: 'https://tavily.com',
@ -58,3 +47,46 @@ export const WEB_SEARCH_PROVIDER_CONFIG = {
} }
} }
} }
export const WEB_SEARCH_PROVIDERS: WebSearchProvider[] = [
{
id: 'tavily',
name: 'Tavily',
apiHost: 'https://api.tavily.com',
apiKey: ''
},
{
id: 'searxng',
name: 'Searxng',
apiHost: '',
basicAuthUsername: '',
basicAuthPassword: ''
},
{
id: 'exa',
name: 'Exa',
apiHost: 'https://api.exa.ai',
apiKey: ''
},
{
id: 'bocha',
name: 'Bocha',
apiHost: 'https://api.bochaai.com',
apiKey: ''
},
{
id: 'local-google',
name: 'Google',
url: 'https://www.google.com/search?q=%s'
},
{
id: 'local-bing',
name: 'Bing',
url: 'https://cn.bing.com/search?q=%s&ensearch=1'
},
{
id: 'local-baidu',
name: 'Baidu',
url: 'https://www.baidu.com/s?wd=%s'
}
] as const

View File

@ -4,10 +4,10 @@ import {
updatePreprocessProvider as _updatePreprocessProvider, updatePreprocessProvider as _updatePreprocessProvider,
updatePreprocessProviders as _updatePreprocessProviders updatePreprocessProviders as _updatePreprocessProviders
} from '@renderer/store/preprocess' } from '@renderer/store/preprocess'
import { PreprocessProvider } from '@renderer/types' import { PreprocessProvider, PreprocessProviderId } from '@renderer/types'
import { useDispatch, useSelector } from 'react-redux' import { useDispatch, useSelector } from 'react-redux'
export const usePreprocessProvider = (id: string) => { export const usePreprocessProvider = (id: PreprocessProviderId) => {
const dispatch = useDispatch() const dispatch = useDispatch()
const preprocessProviders = useSelector((state: RootState) => state.preprocess.providers) const preprocessProviders = useSelector((state: RootState) => state.preprocess.providers)
const provider = preprocessProviders.find((provider) => provider.id === id) const provider = preprocessProviders.find((provider) => provider.id === id)

View File

@ -11,7 +11,7 @@ import {
updateWebSearchProvider, updateWebSearchProvider,
updateWebSearchProviders updateWebSearchProviders
} from '@renderer/store/websearch' } from '@renderer/store/websearch'
import { WebSearchProvider } from '@renderer/types' import { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
export const useDefaultWebSearchProvider = () => { export const useDefaultWebSearchProvider = () => {
const defaultProvider = useAppSelector((state) => state.websearch.defaultProvider) const defaultProvider = useAppSelector((state) => state.websearch.defaultProvider)
@ -49,7 +49,7 @@ export const useWebSearchProviders = () => {
} }
} }
export const useWebSearchProvider = (id: string) => { export const useWebSearchProvider = (id: WebSearchProviderId) => {
const providers = useAppSelector((state) => state.websearch.providers) const providers = useAppSelector((state) => state.websearch.providers)
const provider = providers.find((provider) => provider.id === id) const provider = providers.find((provider) => provider.id === id)
const dispatch = useAppDispatch() const dispatch = useAppDispatch()
@ -60,7 +60,9 @@ export const useWebSearchProvider = (id: string) => {
return { return {
provider, provider,
updateProvider: (updates: Partial<WebSearchProvider>) => dispatch(updateWebSearchProvider({ id, ...updates })) updateProvider: (updates: Partial<WebSearchProvider>) => {
dispatch(updateWebSearchProvider({ id, ...updates }))
}
} }
} }

View File

@ -24,7 +24,7 @@ const GenerateImageButton: FC<Props> = ({ model, ToolbarButton, assistant, onEna
mouseLeaveDelay={0} mouseLeaveDelay={0}
arrow> arrow>
<ToolbarButton type="text" disabled={!isGenerateImageModel(model)} onClick={onEnableGenerateImage}> <ToolbarButton type="text" disabled={!isGenerateImageModel(model)} onClick={onEnableGenerateImage}>
<Image size={18} color={assistant.enableGenerateImage ? 'var(--color-link)' : 'var(--color-icon)'} /> <Image size={18} color={assistant.enableGenerateImage ? 'var(--color-primary)' : 'var(--color-icon)'} />
</ToolbarButton> </ToolbarButton>
</Tooltip> </Tooltip>
) )

View File

@ -87,7 +87,10 @@ const KnowledgeBaseButton: FC<Props> = ({ ref, selectedBases, onSelect, disabled
return ( return (
<Tooltip placement="top" title={t('chat.input.knowledge_base')} mouseLeaveDelay={0} arrow> <Tooltip placement="top" title={t('chat.input.knowledge_base')} mouseLeaveDelay={0} arrow>
<ToolbarButton type="text" onClick={handleOpenQuickPanel} disabled={disabled}> <ToolbarButton type="text" onClick={handleOpenQuickPanel} disabled={disabled}>
<FileSearch size={18} /> <FileSearch
size={18}
color={selectedBases && selectedBases.length > 0 ? 'var(--color-primary)' : 'var(--color-icon)'}
/>
</ToolbarButton> </ToolbarButton>
</Tooltip> </Tooltip>
) )

View File

@ -195,7 +195,7 @@ const MentionModelsButton: FC<Props> = ({
return ( return (
<Tooltip placement="top" title={t('agents.edit.model.select.title')} mouseLeaveDelay={0} arrow> <Tooltip placement="top" title={t('agents.edit.model.select.title')} mouseLeaveDelay={0} arrow>
<ToolbarButton type="text" onClick={handleOpenQuickPanel}> <ToolbarButton type="text" onClick={handleOpenQuickPanel}>
<AtSign size={18} /> <AtSign size={18} color={mentionedModels.length > 0 ? 'var(--color-primary)' : 'var(--color-icon)'} />
</ToolbarButton> </ToolbarButton>
</Tooltip> </Tooltip>
) )

View File

@ -33,7 +33,7 @@ const UrlContextButton: FC<Props> = ({ assistant, ToolbarButton }) => {
<Link <Link
size={18} size={18}
style={{ style={{
color: assistant.enableUrlContext ? 'var(--color-link)' : 'var(--color-icon)' color: assistant.enableUrlContext ? 'var(--color-primary)' : 'var(--color-icon)'
}} }}
/> />
</ToolbarButton> </ToolbarButton>

View File

@ -1,9 +1,11 @@
import { BaiduOutlined, GoogleOutlined } from '@ant-design/icons'
import { BingLogo, BochaLogo, ExaLogo, SearXNGLogo, TavilyLogo } from '@renderer/components/Icons'
import { QuickPanelListItem, useQuickPanel } from '@renderer/components/QuickPanel' import { QuickPanelListItem, useQuickPanel } from '@renderer/components/QuickPanel'
import { isWebSearchModel } from '@renderer/config/models' import { isWebSearchModel } from '@renderer/config/models'
import { useAssistant } from '@renderer/hooks/useAssistant' import { useAssistant } from '@renderer/hooks/useAssistant'
import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders' import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders'
import WebSearchService from '@renderer/services/WebSearchService' import WebSearchService from '@renderer/services/WebSearchService'
import { Assistant, WebSearchProvider } from '@renderer/types' import { Assistant, WebSearchProvider, WebSearchProviderId } from '@renderer/types'
import { hasObjectKey } from '@renderer/utils' import { hasObjectKey } from '@renderer/utils'
import { Tooltip } from 'antd' import { Tooltip } from 'antd'
import { Globe } from 'lucide-react' import { Globe } from 'lucide-react'
@ -28,6 +30,33 @@ const WebSearchButton: FC<Props> = ({ ref, assistant, ToolbarButton }) => {
const enableWebSearch = assistant?.webSearchProviderId || assistant.enableWebSearch const enableWebSearch = assistant?.webSearchProviderId || assistant.enableWebSearch
const WebSearchIcon = useCallback(
({ pid, size = 18 }: { pid?: WebSearchProviderId; size?: number }) => {
const iconColor = enableWebSearch ? 'var(--color-primary)' : 'var(--color-icon)'
switch (pid) {
case 'bocha':
return <BochaLogo width={size} height={size} color={iconColor} />
case 'exa':
// size微调视觉上和其他图标平衡一些
return <ExaLogo width={size - 2} height={size} color={iconColor} />
case 'tavily':
return <TavilyLogo width={size} height={size} color={iconColor} />
case 'searxng':
return <SearXNGLogo width={size} height={size} color={iconColor} />
case 'local-baidu':
return <BaiduOutlined size={size} style={{ color: iconColor, fontSize: size }} />
case 'local-bing':
return <BingLogo width={size} height={size} color={iconColor} />
case 'local-google':
return <GoogleOutlined size={size} style={{ color: iconColor, fontSize: size }} />
default:
return <Globe size={size} style={{ color: iconColor, fontSize: size }} />
}
},
[enableWebSearch]
)
const updateSelectedWebSearchProvider = useCallback( const updateSelectedWebSearchProvider = useCallback(
async (providerId?: WebSearchProvider['id']) => { async (providerId?: WebSearchProvider['id']) => {
// TODO: updateAssistant有性能问题会导致关闭快捷面板卡顿 // TODO: updateAssistant有性能问题会导致关闭快捷面板卡顿
@ -58,7 +87,7 @@ const WebSearchButton: FC<Props> = ({ ref, assistant, ToolbarButton }) => {
? t('settings.tool.websearch.apikey') ? t('settings.tool.websearch.apikey')
: t('settings.tool.websearch.free') : t('settings.tool.websearch.free')
: t('chat.input.web_search.enable_content'), : t('chat.input.web_search.enable_content'),
icon: <Globe />, icon: <WebSearchIcon size={13} pid={p.id} />,
isSelected: p.id === assistant?.webSearchProviderId, isSelected: p.id === assistant?.webSearchProviderId,
disabled: !WebSearchService.isWebSearchEnabled(p.id), disabled: !WebSearchService.isWebSearchEnabled(p.id),
action: () => updateSelectedWebSearchProvider(p.id) action: () => updateSelectedWebSearchProvider(p.id)
@ -80,6 +109,7 @@ const WebSearchButton: FC<Props> = ({ ref, assistant, ToolbarButton }) => {
return items return items
}, [ }, [
WebSearchIcon,
assistant.enableWebSearch, assistant.enableWebSearch,
assistant.model, assistant.model,
assistant?.webSearchProviderId, assistant?.webSearchProviderId,
@ -135,12 +165,7 @@ const WebSearchButton: FC<Props> = ({ ref, assistant, ToolbarButton }) => {
mouseLeaveDelay={0} mouseLeaveDelay={0}
arrow> arrow>
<ToolbarButton type="text" onClick={handleOpenQuickPanel}> <ToolbarButton type="text" onClick={handleOpenQuickPanel}>
<Globe <WebSearchIcon pid={assistant.webSearchProviderId} />
size={18}
style={{
color: enableWebSearch ? 'var(--color-primary)' : 'var(--color-icon)'
}}
/>
</ToolbarButton> </ToolbarButton>
</Tooltip> </Tooltip>
) )

View File

@ -2,14 +2,14 @@ import { loggerService } from '@logger'
import { usePreprocessProvider } from '@renderer/hooks/usePreprocess' import { usePreprocessProvider } from '@renderer/hooks/usePreprocess'
import { getStoreSetting } from '@renderer/hooks/useSettings' import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService' import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
import { KnowledgeBase } from '@renderer/types' import { KnowledgeBase, PreprocessProviderId } from '@renderer/types'
import { Tag } from 'antd' import { Tag } from 'antd'
import { FC, useEffect, useState } from 'react' import { FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
const logger = loggerService.withContext('QuotaTag') const logger = loggerService.withContext('QuotaTag')
const QuotaTag: FC<{ base: KnowledgeBase; providerId: string; quota?: number }> = ({ const QuotaTag: FC<{ base: KnowledgeBase; providerId: PreprocessProviderId; quota?: number }> = ({
base, base,
providerId, providerId,
quota: _quota quota: _quota

View File

@ -4,23 +4,14 @@ import { getPreprocessProviderLogo, PREPROCESS_PROVIDER_CONFIG } from '@renderer
import { usePreprocessProvider } from '@renderer/hooks/usePreprocess' import { usePreprocessProvider } from '@renderer/hooks/usePreprocess'
import { PreprocessProvider } from '@renderer/types' import { PreprocessProvider } from '@renderer/types'
import { formatApiKeys, hasObjectKey } from '@renderer/utils' import { formatApiKeys, hasObjectKey } from '@renderer/utils'
import { Avatar, Button, Divider, Flex, Input, InputNumber, Segmented, Tooltip } from 'antd' import { Avatar, Button, Divider, Flex, Input, Tooltip } from 'antd'
import Link from 'antd/es/typography/Link' import Link from 'antd/es/typography/Link'
import { List } from 'lucide-react' import { List } from 'lucide-react'
import { FC, useEffect, useState } from 'react' import { FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import styled from 'styled-components' import styled from 'styled-components'
import { import { SettingHelpLink, SettingHelpText, SettingHelpTextRow, SettingSubtitle, SettingTitle } from '..'
SettingDivider,
SettingHelpLink,
SettingHelpText,
SettingHelpTextRow,
SettingRow,
SettingRowTitle,
SettingSubtitle,
SettingTitle
} from '..'
interface Props { interface Props {
provider: PreprocessProvider provider: PreprocessProvider
@ -31,7 +22,7 @@ const PreprocessProviderSettings: FC<Props> = ({ provider: _provider }) => {
const { t } = useTranslation() const { t } = useTranslation()
const [apiKey, setApiKey] = useState(preprocessProvider.apiKey || '') const [apiKey, setApiKey] = useState(preprocessProvider.apiKey || '')
const [apiHost, setApiHost] = useState(preprocessProvider.apiHost || '') const [apiHost, setApiHost] = useState(preprocessProvider.apiHost || '')
const [options, setOptions] = useState(preprocessProvider.options || {}) // const [options, setOptions] = useState(preprocessProvider.options || {})
const preprocessProviderConfig = PREPROCESS_PROVIDER_CONFIG[preprocessProvider.id] const preprocessProviderConfig = PREPROCESS_PROVIDER_CONFIG[preprocessProvider.id]
const apiKeyWebsite = preprocessProviderConfig?.websites?.apiKey const apiKeyWebsite = preprocessProviderConfig?.websites?.apiKey
@ -40,7 +31,7 @@ const PreprocessProviderSettings: FC<Props> = ({ provider: _provider }) => {
useEffect(() => { useEffect(() => {
setApiKey(preprocessProvider.apiKey ?? '') setApiKey(preprocessProvider.apiKey ?? '')
setApiHost(preprocessProvider.apiHost ?? '') setApiHost(preprocessProvider.apiHost ?? '')
setOptions(preprocessProvider.options ?? {}) // setOptions(preprocessProvider.options ?? {})
}, [preprocessProvider.apiKey, preprocessProvider.apiHost, preprocessProvider.options]) }, [preprocessProvider.apiKey, preprocessProvider.apiHost, preprocessProvider.options])
const onUpdateApiKey = () => { const onUpdateApiKey = () => {
@ -52,7 +43,6 @@ const PreprocessProviderSettings: FC<Props> = ({ provider: _provider }) => {
const openApiKeyList = async () => { const openApiKeyList = async () => {
await ApiKeyListPopup.show({ await ApiKeyListPopup.show({
providerId: preprocessProvider.id, providerId: preprocessProvider.id,
providerKind: 'doc-preprocess',
title: `${preprocessProvider.name} ${t('settings.provider.api.key.list.title')}`, title: `${preprocessProvider.name} ${t('settings.provider.api.key.list.title')}`,
showHealthCheck: false // FIXME: 目前还没有检查功能 showHealthCheck: false // FIXME: 目前还没有检查功能
}) })
@ -70,11 +60,11 @@ const PreprocessProviderSettings: FC<Props> = ({ provider: _provider }) => {
} }
} }
const onUpdateOptions = (key: string, value: any) => { // const onUpdateOptions = (key: string, value: any) => {
const newOptions = { ...options, [key]: value } // const newOptions = { ...options, [key]: value }
setOptions(newOptions) // setOptions(newOptions)
updateProvider({ options: newOptions }) // updateProvider({ options: newOptions })
} // }
return ( return (
<> <>
@ -145,7 +135,7 @@ const PreprocessProviderSettings: FC<Props> = ({ provider: _provider }) => {
)} )}
{/* 这部分看起来暂时用不上了 */} {/* 这部分看起来暂时用不上了 */}
{hasObjectKey(preprocessProvider, 'options') && preprocessProvider.id === 'system' && ( {/* {hasObjectKey(preprocessProvider, 'options') && preprocessProvider.id === 'system' && (
<> <>
<SettingDivider style={{ marginTop: 15, marginBottom: 12 }} /> <SettingDivider style={{ marginTop: 15, marginBottom: 12 }} />
<SettingRow> <SettingRow>
@ -177,7 +167,7 @@ const PreprocessProviderSettings: FC<Props> = ({ provider: _provider }) => {
/> />
</SettingRow> </SettingRow>
</> </>
)} )} */}
</> </>
) )
} }

View File

@ -1,4 +1,3 @@
import { isMac } from '@renderer/config/constant'
import { useTheme } from '@renderer/context/ThemeProvider' import { useTheme } from '@renderer/context/ThemeProvider'
import { useDefaultPreprocessProvider, usePreprocessProviders } from '@renderer/hooks/usePreprocess' import { useDefaultPreprocessProvider, usePreprocessProviders } from '@renderer/hooks/usePreprocess'
import { PreprocessProvider } from '@renderer/types' import { PreprocessProvider } from '@renderer/types'
@ -40,8 +39,9 @@ const PreprocessSettings: FC = () => {
placeholder={t('settings.tool.preprocess.provider_placeholder')} placeholder={t('settings.tool.preprocess.provider_placeholder')}
options={preprocessProviders.map((p) => ({ options={preprocessProviders.map((p) => ({
value: p.id, value: p.id,
label: p.name, label: p.name
disabled: !isMac && p.id === 'system' // 在非 Mac 系统下禁用 system 选项 // 由于system字段实际未使用先注释掉
// disabled: !isMac && p.id === 'system' // 在非 Mac 系统下禁用 system 选项
}))} }))}
/> />
</div> </div>

View File

@ -128,7 +128,6 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
const openApiKeyList = async () => { const openApiKeyList = async () => {
await ApiKeyListPopup.show({ await ApiKeyListPopup.show({
providerId: provider.id, providerId: provider.id,
providerKind: 'llm',
title: `${fancyProviderName} ${t('settings.provider.api.key.list.title')}` title: `${fancyProviderName} ${t('settings.provider.api.key.list.title')}`
}) })
} }

View File

@ -1,9 +1,14 @@
import { CheckOutlined, ExportOutlined, LoadingOutlined } from '@ant-design/icons' import { CheckOutlined, ExportOutlined, LoadingOutlined } from '@ant-design/icons'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import BochaLogo from '@renderer/assets/images/search/bocha.webp'
import ExaLogo from '@renderer/assets/images/search/exa.png'
import SearxngLogo from '@renderer/assets/images/search/searxng.svg'
import TavilyLogo from '@renderer/assets/images/search/tavily.png'
import ApiKeyListPopup from '@renderer/components/Popups/ApiKeyListPopup/popup' import ApiKeyListPopup from '@renderer/components/Popups/ApiKeyListPopup/popup'
import { getWebSearchProviderLogo, WEB_SEARCH_PROVIDER_CONFIG } from '@renderer/config/webSearchProviders' import { WEB_SEARCH_PROVIDER_CONFIG } from '@renderer/config/webSearchProviders'
import { useWebSearchProvider } from '@renderer/hooks/useWebSearchProviders' import { useWebSearchProvider } from '@renderer/hooks/useWebSearchProviders'
import WebSearchService from '@renderer/services/WebSearchService' import WebSearchService from '@renderer/services/WebSearchService'
import { WebSearchProviderId } from '@renderer/types'
import { formatApiKeys, hasObjectKey } from '@renderer/utils' import { formatApiKeys, hasObjectKey } from '@renderer/utils'
import { Button, Divider, Flex, Form, Input, Space, Tooltip } from 'antd' import { Button, Divider, Flex, Form, Input, Space, Tooltip } from 'antd'
import Link from 'antd/es/typography/Link' import Link from 'antd/es/typography/Link'
@ -16,7 +21,7 @@ import { SettingDivider, SettingHelpLink, SettingHelpText, SettingHelpTextRow, S
const logger = loggerService.withContext('WebSearchProviderSetting') const logger = loggerService.withContext('WebSearchProviderSetting')
interface Props { interface Props {
providerId: string providerId: WebSearchProviderId
} }
const WebSearchProviderSetting: FC<Props> = ({ providerId }) => { const WebSearchProviderSetting: FC<Props> = ({ providerId }) => {
@ -74,7 +79,6 @@ const WebSearchProviderSetting: FC<Props> = ({ providerId }) => {
const openApiKeyList = async () => { const openApiKeyList = async () => {
await ApiKeyListPopup.show({ await ApiKeyListPopup.show({
providerId: provider.id, providerId: provider.id,
providerKind: 'websearch',
title: `${provider.name} ${t('settings.provider.api.key.list.title')}` title: `${provider.name} ${t('settings.provider.api.key.list.title')}`
}) })
} }
@ -132,6 +136,21 @@ const WebSearchProviderSetting: FC<Props> = ({ providerId }) => {
setBasicAuthPassword(provider.basicAuthPassword ?? '') setBasicAuthPassword(provider.basicAuthPassword ?? '')
}, [provider.apiKey, provider.apiHost, provider.basicAuthUsername, provider.basicAuthPassword]) }, [provider.apiKey, provider.apiHost, provider.basicAuthUsername, provider.basicAuthPassword])
const getWebSearchProviderLogo = (providerId: WebSearchProviderId) => {
switch (providerId) {
case 'tavily':
return TavilyLogo
case 'searxng':
return SearxngLogo
case 'exa':
return ExaLogo
case 'bocha':
return BochaLogo
default:
return undefined
}
}
return ( return (
<> <>
<SettingTitle> <SettingTitle>

View File

@ -104,9 +104,9 @@ async function fetchExternalTool(
const showListTools = enabledMCPs && enabledMCPs.length > 0 const showListTools = enabledMCPs && enabledMCPs.length > 0
// 是否使用工具 // 是否使用工具
const hasAnyTool = shouldWebSearch || shouldKnowledgeSearch || shouldSearchMemory || showListTools const hasAnyTool = shouldWebSearch || shouldKnowledgeSearch || showListTools
// 在工具链开始时发送进度通知 // 在工具链开始时发送进度通知(不包括记忆搜索)
if (hasAnyTool) { if (hasAnyTool) {
onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }) onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
} }
@ -456,8 +456,6 @@ export async function fetchChatCompletion({
const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer) const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer)
const model = assistant.model || getDefaultModel() const model = assistant.model || getDefaultModel()
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
const { maxTokens, contextCount } = getAssistantSettings(assistant) const { maxTokens, contextCount } = getAssistantSettings(assistant)
const filteredMessages2 = filterUsefulMessages(filteredMessages1) const filteredMessages2 = filterUsefulMessages(filteredMessages1)
@ -488,7 +486,7 @@ export async function fetchChatCompletion({
isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true) isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true)
// --- Call AI Completions --- // --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
const completionsParams: CompletionsParams = { const completionsParams: CompletionsParams = {
callType: 'chat', callType: 'chat',
messages: _messages, messages: _messages,

View File

@ -40,6 +40,7 @@ export const createCitationCallbacks = (deps: CitationCallbacksDependencies) =>
status: MessageBlockStatus.SUCCESS status: MessageBlockStatus.SUCCESS
} }
blockManager.smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION, true) blockManager.smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION, true)
citationBlockId = null
} else { } else {
logger.error('[onExternalToolComplete] citationBlockId is null. Cannot update.') logger.error('[onExternalToolComplete] citationBlockId is null. Cannot update.')
} }

View File

@ -1,4 +1,5 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit' import { createSlice, PayloadAction } from '@reduxjs/toolkit'
import { WEB_SEARCH_PROVIDERS } from '@renderer/config/webSearchProviders'
import type { Model, WebSearchProvider } from '@renderer/types' import type { Model, WebSearchProvider } from '@renderer/types'
export interface SubscribeSource { export interface SubscribeSource {
key: number key: number
@ -42,48 +43,7 @@ export interface WebSearchState {
export const initialState: WebSearchState = { export const initialState: WebSearchState = {
defaultProvider: 'local-bing', defaultProvider: 'local-bing',
providers: [ providers: WEB_SEARCH_PROVIDERS,
{
id: 'tavily',
name: 'Tavily',
apiHost: 'https://api.tavily.com',
apiKey: ''
},
{
id: 'searxng',
name: 'Searxng',
apiHost: '',
basicAuthUsername: '',
basicAuthPassword: ''
},
{
id: 'exa',
name: 'Exa',
apiHost: 'https://api.exa.ai',
apiKey: ''
},
{
id: 'bocha',
name: 'Bocha',
apiHost: 'https://api.bochaai.com',
apiKey: ''
},
{
id: 'local-google',
name: 'Google',
url: 'https://www.google.com/search?q=%s'
},
{
id: 'local-bing',
name: 'Bing',
url: 'https://cn.bing.com/search?q=%s&ensearch=1'
},
{
id: 'local-baidu',
name: 'Baidu',
url: 'https://www.baidu.com/s?wd=%s'
}
],
searchWithTime: true, searchWithTime: true,
maxResults: 5, maxResults: 5,
excludeDomains: [], excludeDomains: [],
@ -111,7 +71,7 @@ const websearchSlice = createSlice({
updateWebSearchProviders: (state, action: PayloadAction<WebSearchProvider[]>) => { updateWebSearchProviders: (state, action: PayloadAction<WebSearchProvider[]>) => {
state.providers = action.payload state.providers = action.payload
}, },
updateWebSearchProvider: (state, action: PayloadAction<Partial<WebSearchProvider> & { id: string }>) => { updateWebSearchProvider: (state, action: PayloadAction<Partial<WebSearchProvider>>) => {
const index = state.providers.findIndex((provider) => provider.id === action.payload.id) const index = state.providers.findIndex((provider) => provider.id === action.payload.id)
if (index !== -1) { if (index !== -1) {
Object.assign(state.providers[index], action.payload) Object.assign(state.providers[index], action.payload)

View File

@ -609,8 +609,20 @@ export type KnowledgeBaseParams = {
} }
} }
export const PreprocessProviderIds = {
doc2x: 'doc2x',
mistral: 'mistral',
mineru: 'mineru'
} as const
export type PreprocessProviderId = keyof typeof PreprocessProviderIds
export const isPreprocessProviderId = (id: string): id is PreprocessProviderId => {
return Object.hasOwn(PreprocessProviderIds, id)
}
export interface PreprocessProvider { export interface PreprocessProvider {
id: string id: PreprocessProviderId
name: string name: string
apiKey?: string apiKey?: string
apiHost?: string apiHost?: string
@ -675,8 +687,24 @@ export type ExternalToolResult = {
memories?: MemoryItem[] memories?: MemoryItem[]
} }
export const WebSearchProviderIds = {
tavily: 'tavily',
searxng: 'searxng',
exa: 'exa',
bocha: 'bocha',
'local-google': 'local-google',
'local-bing': 'local-bing',
'local-baidu': 'local-baidu'
} as const
export type WebSearchProviderId = keyof typeof WebSearchProviderIds
export const isWebSearchProviderId = (id: string): id is WebSearchProviderId => {
return Object.hasOwn(WebSearchProviderIds, id)
}
export type WebSearchProvider = { export type WebSearchProvider = {
id: string id: WebSearchProviderId
name: string name: string
apiKey?: string apiKey?: string
apiHost?: string apiHost?: string

View File

@ -162,6 +162,7 @@ export interface AwsBedrockSdkParams {
topP?: number topP?: number
stream?: boolean stream?: boolean
tools?: AwsBedrockSdkTool[] tools?: AwsBedrockSdkTool[]
[key: string]: any // Allow any additional custom parameters
} }
export interface AwsBedrockSdkMessageParam { export interface AwsBedrockSdkMessageParam {
@ -206,6 +207,22 @@ export interface AwsBedrockSdkMessageParam {
}> }>
} }
export interface AwsBedrockStreamChunk {
type: string
delta?: {
text?: string
toolUse?: { input?: string }
type?: string
thinking?: string
}
index?: number
content_block?: any
usage?: {
inputTokens?: number
outputTokens?: number
}
}
export interface AwsBedrockSdkRawChunk { export interface AwsBedrockSdkRawChunk {
contentBlockStart?: { contentBlockStart?: {
start?: { start?: {
@ -222,6 +239,8 @@ export interface AwsBedrockSdkRawChunk {
toolUse?: { toolUse?: {
input?: string input?: string
} }
type?: string // 支持 'thinking_delta' 等类型
thinking?: string // 支持 thinking 内容
} }
contentBlockIndex?: number contentBlockIndex?: number
} }

View File

@ -8629,7 +8629,7 @@ __metadata:
remove-markdown: "npm:^0.6.2" remove-markdown: "npm:^0.6.2"
rollup-plugin-visualizer: "npm:^5.12.0" rollup-plugin-visualizer: "npm:^5.12.0"
sass: "npm:^1.88.0" sass: "npm:^1.88.0"
selection-hook: "npm:^1.0.8" selection-hook: "npm:^1.0.9"
shiki: "npm:^3.9.1" shiki: "npm:^3.9.1"
strict-url-sanitise: "npm:^0.0.1" strict-url-sanitise: "npm:^0.0.1"
string-width: "npm:^7.2.0" string-width: "npm:^7.2.0"
@ -20066,14 +20066,14 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"selection-hook@npm:^1.0.8": "selection-hook@npm:^1.0.9":
version: 1.0.8 version: 1.0.9
resolution: "selection-hook@npm:1.0.8" resolution: "selection-hook@npm:1.0.9"
dependencies: dependencies:
node-addon-api: "npm:^8.4.0" node-addon-api: "npm:^8.4.0"
node-gyp: "npm:latest" node-gyp: "npm:latest"
node-gyp-build: "npm:^4.8.4" node-gyp-build: "npm:^4.8.4"
checksum: 10c0/ed7e230ddf10fcd1974b166c5e73170900260664e40454e4e1fcdf0ba21d2a08cf95824c085fa07069aa99b663e0ee3f2aed74c3fbdba0f4e99abe6956bd51dc checksum: 10c0/5f3114b528d9e1545a5dc4b99927a0ab441570063bb348b52784d757c8f250f0d6a875175d371adf5dc2bfc82bf6bb86f99d3ee66fefe0749040c0b50f3217c3
languageName: node languageName: node
linkType: hard linkType: hard