diff --git a/.yarn/patches/@ai-sdk-google-npm-2.0.20-b9102f9d54.patch b/.yarn/patches/@ai-sdk-google-npm-2.0.20-b9102f9d54.patch
deleted file mode 100644
index 34babfe803..0000000000
--- a/.yarn/patches/@ai-sdk-google-npm-2.0.20-b9102f9d54.patch
+++ /dev/null
@@ -1,13 +0,0 @@
-diff --git a/dist/index.mjs b/dist/index.mjs
-index 69ab1599c76801dc1167551b6fa283dded123466..f0af43bba7ad1196fe05338817e65b4ebda40955 100644
---- a/dist/index.mjs
-+++ b/dist/index.mjs
-@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
- 
- // src/get-model-path.ts
- function getModelPath(modelId) {
--  return modelId.includes("/") ? modelId : `models/${modelId}`;
-+  return modelId?.includes("models/") ? modelId : `models/${modelId}`;
- }
- 
- // src/google-generative-ai-options.ts
diff --git a/.yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch b/.yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch
new file mode 100644
index 0000000000..ba4cd59d4c
--- /dev/null
+++ b/.yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch
@@ -0,0 +1,26 @@
+diff --git a/dist/index.js b/dist/index.js
+index 4cc66d83af1cef39f6447dc62e680251e05ddf9f..eb9819cb674c1808845ceb29936196c4bb355172 100644
+--- a/dist/index.js
++++ b/dist/index.js
+@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
+ 
+ // src/get-model-path.ts
+ function getModelPath(modelId) {
+-  return modelId.includes("/") ? modelId : `models/${modelId}`;
++  return modelId.includes("models/") ? modelId : `models/${modelId}`;
+ }
+ 
+ // src/google-generative-ai-options.ts
+diff --git a/dist/index.mjs b/dist/index.mjs
+index a032505ec54e132dc386dde001dc51f710f84c83..5efada51b9a8b56e3f01b35e734908ebe3c37043 100644
+--- a/dist/index.mjs
++++ b/dist/index.mjs
+@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
+ 
+ // src/get-model-path.ts
+ function getModelPath(modelId) {
+-  return modelId.includes("/") ? modelId : `models/${modelId}`;
++  return modelId.includes("models/") ? modelId : `models/${modelId}`;
+ }
+ 
+ // src/google-generative-ai-options.ts
diff --git a/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch b/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch
new file mode 100644
index 0000000000..a7985ddfcd
--- /dev/null
+++ b/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch
@@ -0,0 +1,76 @@
+diff --git a/dist/index.js b/dist/index.js
+index cc6652c4e7f32878a64a2614115bf7eeb3b7c890..76e989017549c89b45d633525efb1f318026d9b2 100644
+--- a/dist/index.js
++++ b/dist/index.js
+@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
+         message: import_v42.z.object({
+           role: import_v42.z.literal("assistant").nullish(),
+           content: import_v42.z.string().nullish(),
++          reasoning_content: import_v42.z.string().nullish(),
+           tool_calls: import_v42.z.array(
+             import_v42.z.object({
+               id: import_v42.z.string().nullish(),
+@@ -340,6 +341,7 @@ var openaiChatChunkSchema = (0, import_provider_utils3.lazyValidator)(
+         delta: import_v42.z.object({
+           role: import_v42.z.enum(["assistant"]).nullish(),
+           content: import_v42.z.string().nullish(),
++          reasoning_content: import_v42.z.string().nullish(),
+           tool_calls: import_v42.z.array(
+             import_v42.z.object({
+               index: import_v42.z.number(),
+@@ -785,6 +787,14 @@ var OpenAIChatLanguageModel = class {
+     if (text != null && text.length > 0) {
+       content.push({ type: "text", text });
+     }
++    const reasoning =
++      choice.message.reasoning_content;
++    if (reasoning != null && reasoning.length > 0) {
++      content.push({
++        type: 'reasoning',
++        text: reasoning,
++      });
++    }
+     for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
+       content.push({
+         type: "tool-call",
+@@ -866,6 +876,7 @@
+     };
+     let isFirstChunk = true;
+     let isActiveText = false;
++    let isActiveReasoning = false;
+     const providerMetadata = { openai: {} };
+     return {
+       stream: response.pipeThrough(
+@@ -920,6 +931,22 @@
+             return;
+           }
+           const delta = choice.delta;
++          const reasoningContent = delta.reasoning_content;
++          if (reasoningContent) {
++            if (!isActiveReasoning) {
++              controller.enqueue({
++                type: 'reasoning-start',
++                id: 'reasoning-0',
++              });
++              isActiveReasoning = true;
++            }
++
++            controller.enqueue({
++              type: 'reasoning-delta',
++              id: 'reasoning-0',
++              delta: reasoningContent,
++            });
++          }
+           if (delta.content != null) {
+             if (!isActiveText) {
+               controller.enqueue({ type: "text-start", id: "0" });
+@@ -1032,6 +1059,9 @@
+           }
+         },
+         flush(controller) {
++          if (isActiveReasoning) {
++            controller.enqueue({ type: 'reasoning-end', id: 'reasoning-0' });
++          }
+           if (isActiveText) {
+             controller.enqueue({ type: "text-end", id: "0" });
+           }
diff --git a/package.json b/package.json
index 0e0176e95a..1d1213a078 100644
--- a/package.json
+++ b/package.json
@@ -103,8 +103,8 @@
     "@agentic/exa": "^7.3.3",
     "@agentic/searxng": "^7.3.3",
     "@agentic/tavily": "^7.3.3",
-    "@ai-sdk/amazon-bedrock": "^3.0.35",
-    "@ai-sdk/google-vertex": "^3.0.40",
+    "@ai-sdk/amazon-bedrock": "^3.0.42",
+    "@ai-sdk/google-vertex": "^3.0.48",
     "@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.4#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.4-8080836bc1.patch",
     "@ai-sdk/mistral": "^2.0.19",
     "@ai-sdk/perplexity": "^2.0.13",
@@ -227,7 +227,7 @@
     "@viz-js/lang-dot": "^1.0.5",
     "@viz-js/viz": "^3.14.0",
     "@xyflow/react": "^12.4.4",
-    "ai": "^5.0.68",
+    "ai": "^5.0.76",
     "antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
     "archiver": "^7.0.1",
     "async-mutex": "^0.5.0",
@@ -390,7 +390,8 @@
     "undici": "6.21.2",
     "vite": "npm:rolldown-vite@7.1.5",
     "tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
-    "@ai-sdk/google@npm:2.0.20": "patch:@ai-sdk/google@npm%3A2.0.20#~/.yarn/patches/@ai-sdk-google-npm-2.0.20-b9102f9d54.patch",
+    "@ai-sdk/google@npm:2.0.23": "patch:@ai-sdk/google@npm%3A2.0.23#~/.yarn/patches/@ai-sdk-google-npm-2.0.23-81682e07b0.patch",
+    "@ai-sdk/openai@npm:^2.0.52": "patch:@ai-sdk/openai@npm%3A2.0.52#~/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch",
     "@img/sharp-darwin-arm64": "0.34.3",
     "@img/sharp-darwin-x64": "0.34.3",
     "@img/sharp-linux-arm": "0.34.3",
diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json
index eb9d000929..8310b4164c 100644
--- a/packages/aiCore/package.json
+++ b/packages/aiCore/package.json
@@ -36,10 +36,10 @@
     "ai": "^5.0.26"
   },
   "dependencies": {
-    "@ai-sdk/anthropic": "^2.0.27",
-    "@ai-sdk/azure": "^2.0.49",
+    "@ai-sdk/anthropic": "^2.0.32",
+    "@ai-sdk/azure": "^2.0.53",
     "@ai-sdk/deepseek": "^1.0.23",
-    "@ai-sdk/openai": "^2.0.48",
+    "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.52#~/.yarn/patches/@ai-sdk-openai-npm-2.0.52-b36d949c76.patch",
     "@ai-sdk/openai-compatible": "^1.0.22",
     "@ai-sdk/provider": "^2.0.0",
     "@ai-sdk/provider-utils": "^3.0.12",
diff --git a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
index 7435ad2bb0..f1a2559af7 100644
--- a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
+++ b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
@@ -1,7 +1,6 @@
 import type { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
 import { OllamaEmbeddings } from '@cherrystudio/embedjs-ollama'
 import { OpenAiEmbeddings } from '@cherrystudio/embedjs-openai'
-import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-openai-embeddings'
 import { ApiClient } from '@types'
 
 import { VoyageEmbeddings } from './VoyageEmbeddings'
@@ -9,7 +8,7 @@ import { VoyageEmbeddings } from './VoyageEmbeddings'
 export default class EmbeddingsFactory {
   static create({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }): BaseEmbeddings {
     const batchSize = 10
-    const { model, provider, apiKey, apiVersion, baseURL } = embedApiClient
+    const { model, provider, apiKey, baseURL } = embedApiClient
     if (provider === 'voyageai') {
       return new VoyageEmbeddings({
         modelName: model,
@@ -38,16 +37,7 @@
         }
       })
     }
-    if (apiVersion !== undefined) {
-      return new AzureOpenAiEmbeddings({
-        azureOpenAIApiKey: apiKey,
-        azureOpenAIApiVersion: apiVersion,
-        azureOpenAIApiDeploymentName: model,
-        azureOpenAIEndpoint: baseURL,
-        dimensions,
-        batchSize
-      })
-    }
+    // NOTE: Azure OpenAI 也走 OpenAIEmbeddings, baseURL是https://xxxx.openai.azure.com/openai/v1
     return new OpenAiEmbeddings({
       model,
       apiKey,
diff --git a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
index 39aefc6a5f..4d18024f89 100644
--- a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
+++ b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
@@ -6,7 +6,14 @@
 import { loggerService } from '@logger'
 import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
-import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
+import {
+  BaseTool,
+  MCPCallToolResponse,
+  MCPTool,
+  MCPToolResponse,
+  MCPToolResultContent,
+  NormalToolResponse
+} from '@renderer/types'
 import { Chunk, ChunkType } from '@renderer/types/chunk'
 import type { ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
@@ -254,6 +261,7 @@
       type: 'tool-result'
     } & TypedToolResult
   ): void {
+    // TODO: 基于AI SDK为供应商内置工具做更好的展示和类型安全处理
     const { toolCallId, output, input } = chunk

     if (!toolCallId) {
@@ -299,12 +307,7 @@
       responses: [toolResponse]
     })

-    const images: string[] = []
-    for (const content of toolResponse.response?.content || []) {
-      if (content.type === 'image' && content.data) {
-        images.push(`data:${content.mimeType};base64,${content.data}`)
-      }
-    }
+    const images = extractImagesFromToolOutput(toolResponse.response)

     if (images.length) {
       this.onChunk({
@@ -351,3 +354,41 @@
 }

 export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
+
+function extractImagesFromToolOutput(output: unknown): string[] {
+  if (!output) {
+    return []
+  }
+
+  const contents: unknown[] = []
+
+  if (isMcpCallToolResponse(output)) {
+    contents.push(...output.content)
+  } else if (Array.isArray(output)) {
+    contents.push(...output)
+  } else if (hasContentArray(output)) {
+    contents.push(...output.content)
+  }
+
+  return contents
+    .filter(isMcpImageContent)
+    .map((content) => `data:${content.mimeType ?? 'image/png'};base64,${content.data}`)
+}
+
+function isMcpCallToolResponse(value: unknown): value is MCPCallToolResponse {
+  return typeof value === 'object' && value !== null && Array.isArray((value as MCPCallToolResponse).content)
+}
+
+function hasContentArray(value: unknown): value is { content: unknown[] } {
+  return typeof value === 'object' && value !== null && Array.isArray((value as { content?: unknown }).content)
+}
+
+function isMcpImageContent(content: unknown): content is MCPToolResultContent & { data: string } {
+  if (typeof content !== 'object' || content === null) {
+    return false
+  }
+
+  const resultContent = content as MCPToolResultContent
+
+  return resultContent.type === 'image' && typeof resultContent.data === 'string'
+}
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index 6211b70314..f6d6673cb9 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -14,6 +14,7 @@
 import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
 import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
 import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
 import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
+import { SUPPORTED_IMAGE_ENDPOINT_LIST } from '@renderer/utils'
 import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
 import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
@@ -77,7 +78,7 @@
     return this.actualProvider
   }
-  public async completions(modelId: string, params: StreamTextParams, config: ModernAiProviderConfig) {
+  public async completions(modelId: string, params: StreamTextParams, providerConfig: ModernAiProviderConfig) {
     // 检查model是否存在
     if (!this.model) {
       throw new Error('Model is required for completions. Please use constructor with model parameter.')
     }
@@ -85,7 +86,10 @@
     // 每次请求时重新生成配置以确保API key轮换生效
     this.config = providerToAiSdkConfig(this.actualProvider, this.model)
-
+    logger.debug('Generated provider config for completions', this.config)
+    if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
+      providerConfig.isImageGenerationEndpoint = true
+    }
     // 准备特殊配置
     await prepareSpecialProviderConfig(this.actualProvider, this.config)
@@ -96,13 +100,13 @@
     // 提前构建中间件
     const middlewares = buildAiSdkMiddlewares({
-      ...config,
+      ...providerConfig,
       provider: this.actualProvider,
-      assistant: config.assistant
+      assistant: providerConfig.assistant
     })
     logger.debug('Built middlewares in completions', {
       middlewareCount: middlewares.length,
-      isImageGeneration: config.isImageGenerationEndpoint
+      isImageGeneration: providerConfig.isImageGenerationEndpoint
     })
     if (!this.localProvider) {
       throw new Error('Local provider not created')
     }
@@ -110,7 +114,7 @@
     // 根据endpoint类型创建对应的模型
     let model: AiSdkModel | undefined
-    if (config.isImageGenerationEndpoint) {
+    if (providerConfig.isImageGenerationEndpoint) {
      model = this.localProvider.imageModel(modelId)
     } else {
       model = this.localProvider.languageModel(modelId)
     }
@@ -126,15 +130,15 @@
       params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
     }

-    if (config.topicId && getEnableDeveloperMode()) {
+    if (providerConfig.topicId && getEnableDeveloperMode()) {
       // TypeScript类型窄化:确保topicId是string类型
       const traceConfig = {
-        ...config,
-        topicId: config.topicId
+        ...providerConfig,
+        topicId: providerConfig.topicId
       }
       return await this._completionsForTrace(model, params, traceConfig)
     } else {
-      return await this._completionsOrImageGeneration(model, params, config)
+      return await this._completionsOrImageGeneration(model, params, providerConfig)
     }
   }
diff --git a/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts
index 550486afb2..b8870fdb7c 100644
--- a/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts
+++ b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts
@@ -1,5 +1,4 @@
 import { Provider } from '@renderer/types'
-import { isOpenAIProvider } from '@renderer/utils'
 import { beforeEach, describe, expect, it, vi } from 'vitest'

 import { AihubmixAPIClient } from '../aihubmix/AihubmixAPIClient'
@@ -202,36 +201,4 @@
       expect(client).toBeDefined()
     })
   })
-
-  describe('isOpenAIProvider', () => {
-    it('should return true for openai type', () => {
-      const provider = createTestProvider('openai', 'openai')
-      expect(isOpenAIProvider(provider)).toBe(true)
-    })
-
-    it('should return true for azure-openai type', () => {
-      const provider = createTestProvider('azure-openai', 'azure-openai')
-      expect(isOpenAIProvider(provider)).toBe(true)
-    })
-
-    it('should return true for unknown type (fallback to OpenAI)', () => {
-      const provider = createTestProvider('unknown', 'unknown')
-      expect(isOpenAIProvider(provider)).toBe(true)
-    })
-
-    it('should return false for vertexai type', () => {
-      const provider = createTestProvider('vertex', 'vertexai')
-      expect(isOpenAIProvider(provider)).toBe(false)
-    })
-
-    it('should return false for anthropic type', () => {
-      const provider = createTestProvider('anthropic', 'anthropic')
-      expect(isOpenAIProvider(provider)).toBe(false)
-    })
-
-    it('should return false for gemini type', () => {
-      const provider = createTestProvider('gemini', 'gemini')
-      expect(isOpenAIProvider(provider)).toBe(false)
-    })
-  })
 })
diff --git a/src/renderer/src/aiCore/plugins/PluginBuilder.ts b/src/renderer/src/aiCore/plugins/PluginBuilder.ts
index 7767564bd9..c249142330 100644
--- a/src/renderer/src/aiCore/plugins/PluginBuilder.ts
+++ b/src/renderer/src/aiCore/plugins/PluginBuilder.ts
@@ -1,5 +1,5 @@
 import { AiPlugin } from '@cherrystudio/ai-core'
-import { createPromptToolUsePlugin, googleToolsPlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
+import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
 import { loggerService } from '@logger'
 import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
 import { Assistant } from '@renderer/types'
@@ -68,9 +68,9 @@
     )
   }

-  if (middlewareConfig.enableUrlContext) {
-    plugins.push(googleToolsPlugin({ urlContext: true }))
-  }
+  // if (middlewareConfig.enableUrlContext && middlewareConfig.) {
+  //   plugins.push(googleToolsPlugin({ urlContext: true }))
+  // }

   logger.debug(
     'Final plugin list:',
diff --git a/src/renderer/src/aiCore/prepareParams/fileProcessor.ts b/src/renderer/src/aiCore/prepareParams/fileProcessor.ts
index 9e46f0c627..49048122c2 100644
--- a/src/renderer/src/aiCore/prepareParams/fileProcessor.ts
+++ b/src/renderer/src/aiCore/prepareParams/fileProcessor.ts
@@ -114,7 +114,7 @@ export async function handleGeminiFileUpload(file: FileMetadata, model: Model):
 }

 /**
- * 处理OpenAI大文件上传
+ * 处理OpenAI兼容大文件上传
 */
 export async function handleOpenAILargeFileUpload(
   file: FileMetadata,
diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
index b53293ea88..c693ed235d 100644
--- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
+++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts
@@ -3,6 +3,8 @@
  * 构建AI SDK的流式和非流式参数
  */

+import { anthropic } from '@ai-sdk/anthropic'
+import { google } from '@ai-sdk/google'
 import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
 import { vertex } from '@ai-sdk/google-vertex/edge'
 import { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
@@ -97,10 +99,6 @@
   let tools = setupToolsConfig(mcpTools)

-  // if (webSearchProviderId) {
-  //   tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
-  // }
-
   // 构建真正的 providerOptions
   const webSearchConfig: CherryWebSearchConfig = {
     maxResults: store.getState().websearch.maxResults,
@@ -143,12 +141,34 @@
     }
   }

-  // google-vertex
-  if (enableUrlContext && aiSdkProviderId === 'google-vertex') {
+  if (enableUrlContext) {
     if (!tools) {
       tools = {}
     }
-    tools.url_context = vertex.tools.urlContext({}) as ProviderDefinedTool
+    const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
+
+    switch (aiSdkProviderId) {
+      case 'google-vertex':
+        tools.url_context = vertex.tools.urlContext({}) as ProviderDefinedTool
+        break
+      case 'google':
+        tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
+        break
+      case 'anthropic':
+      case 'google-vertex-anthropic':
+        tools.web_fetch = (
+          aiSdkProviderId === 'anthropic'
+            ? anthropic.tools.webFetch_20250910({
+                maxUses: webSearchConfig.maxResults,
+                blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
+              })
+            : vertexAnthropic.tools.webFetch_20250910({
+                maxUses: webSearchConfig.maxResults,
+                blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
+              })
+        ) as ProviderDefinedTool
+        break
+    }
   }

   // 构建基础参数
diff --git a/src/renderer/src/aiCore/provider/config/aihubmix.ts b/src/renderer/src/aiCore/provider/config/aihubmix.ts
index d0ac83dc66..819e9cd28b 100644
--- a/src/renderer/src/aiCore/provider/config/aihubmix.ts
+++ b/src/renderer/src/aiCore/provider/config/aihubmix.ts
@@ -32,7 +32,8 @@ const AIHUBMIX_RULES: RuleSet = {
     match: (model) =>
       (startsWith('gemini')(model) || startsWith('imagen')(model)) &&
       !model.id.endsWith('-nothink') &&
-      !model.id.endsWith('-search'),
+      !model.id.endsWith('-search') &&
+      !model.id.includes('embedding'),
     provider: (provider: Provider) => {
       return extraProviderConfig({
         ...provider,
diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts
index 0a3bbc7b58..15323751c2 100644
--- a/src/renderer/src/aiCore/provider/providerConfig.ts
+++ b/src/renderer/src/aiCore/provider/providerConfig.ts
@@ -6,26 +6,28 @@ import {
   type ProviderSettingsMap
 } from '@cherrystudio/ai-core/provider'
 import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
-import { isNewApiProvider } from '@renderer/config/providers'
+import {
+  isAnthropicProvider,
+  isAzureOpenAIProvider,
+  isGeminiProvider,
+  isNewApiProvider
+} from '@renderer/config/providers'
 import {
   getAwsBedrockAccessKeyId,
   getAwsBedrockRegion,
   getAwsBedrockSecretAccessKey
 } from '@renderer/hooks/useAwsBedrock'
-import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
+import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
 import { getProviderByModel } from '@renderer/services/AssistantService'
-import { loggerService } from '@renderer/services/LoggerService'
 import store from '@renderer/store'
-import { isSystemProvider, type Model, type Provider } from '@renderer/types'
-import { formatApiHost } from '@renderer/utils/api'
-import { cloneDeep, trim } from 'lodash'
+import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
+import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
+import { cloneDeep } from 'lodash'
 import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
 import { COPILOT_DEFAULT_HEADERS } from './constants'
 import { getAiSdkProviderId } from './factory'

-const logger = loggerService.withContext('ProviderConfigProcessor')
-
 /**
  * 获取轮询的API key
  * 复用legacy架构的多key轮询逻辑
 */
@@ -56,13 +58,6 @@ function getRotatedApiKey(provider: Provider): string {
  * 处理特殊provider的转换逻辑
 */
 function handleSpecialProviders(model: Model, provider: Provider): Provider {
-  // if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
-  //   if (!isVertexAIConfigured()) {
-  //     throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
-  //   }
-  //   return createVertexProvider(provider)
-  // }
-
   if (isNewApiProvider(provider)) {
     return newApiResolverCreator(model, provider)
   }
@@ -79,43 +74,30 @@
 }

 /**
- * 格式化provider的API Host
+ * 主要用来对齐AISdk的BaseURL格式
+ * @param provider
+ * @returns
 */
-function formatAnthropicApiHost(host: string): string {
-  const trimmedHost = host?.trim()
-
-  if (!trimmedHost) {
-    return ''
-  }
-
-  if (trimmedHost.endsWith('/')) {
-    return trimmedHost
-  }
-
-  if (trimmedHost.endsWith('/v1')) {
-    return `${trimmedHost}/`
-  }
-
-  return formatApiHost(trimmedHost)
-}
-
 function formatProviderApiHost(provider: Provider): Provider {
   const formatted = { ...provider }

   if (formatted.anthropicApiHost) {
-    formatted.anthropicApiHost = formatAnthropicApiHost(formatted.anthropicApiHost)
+    formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
   }
-  if (formatted.type === 'anthropic') {
+  if (isAnthropicProvider(provider)) {
     const baseHost = formatted.anthropicApiHost || formatted.apiHost
-    formatted.apiHost = formatAnthropicApiHost(baseHost)
+    formatted.apiHost = formatApiHost(baseHost)
     if (!formatted.anthropicApiHost) {
       formatted.anthropicApiHost = formatted.apiHost
     }
-  } else if (formatted.id === 'copilot') {
-    const trimmed = trim(formatted.apiHost)
-    formatted.apiHost = trimmed.endsWith('/') ? trimmed.slice(0, -1) : trimmed
-  } else if (formatted.type === 'gemini') {
-    formatted.apiHost = formatApiHost(formatted.apiHost, 'v1beta')
+  } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else if (isGeminiProvider(formatted)) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
+  } else if (isAzureOpenAIProvider(formatted)) {
+    formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
+  } else if (isVertexProvider(formatted)) {
+    formatted.apiHost = formatVertexApiHost(formatted)
   } else {
     formatted.apiHost = formatApiHost(formatted.apiHost)
   }
@@ -149,15 +131,15 @@ export function providerToAiSdkConfig(
   options: ProviderSettingsMap[keyof ProviderSettingsMap]
 } {
   const aiSdkProviderId = getAiSdkProviderId(actualProvider)
-  logger.debug('providerToAiSdkConfig', { aiSdkProviderId })

   // 构建基础配置
+  const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
   const baseConfig = {
-    baseURL: trim(actualProvider.apiHost),
+    baseURL: baseURL,
     apiKey: getRotatedApiKey(actualProvider)
   }

-  const isCopilotProvider = actualProvider.id === 'copilot'
+  const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
   if (isCopilotProvider) {
     const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
     const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
@@ -178,6 +160,7 @@

   // 处理OpenAI模式
   const extraOptions: any = {}
+  extraOptions.endpoint = endpoint
   if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
     extraOptions.mode = 'responses'
   } else if (aiSdkProviderId === 'openai') {
@@ -199,13 +182,11 @@
   }

   // azure
   if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
-    extraOptions.apiVersion = actualProvider.apiVersion
-    baseConfig.baseURL += '/openai'
+    // extraOptions.apiVersion = actualProvider.apiVersion 默认使用v1,不使用azure endpoint
     if (actualProvider.apiVersion === 'preview') {
       extraOptions.mode = 'responses'
     } else {
       extraOptions.mode = 'chat'
-      extraOptions.useDeploymentBasedUrls = true
     }
   }
@@ -227,22 +208,7 @@
       ...googleCredentials,
       privateKey: formatPrivateKey(googleCredentials.privateKey)
     }
-    // extraOptions.headers = window.api.vertexAI.getAuthHeaders({
-    //   projectId: project,
-    //   serviceAccount: {
-    //     privateKey: googleCredentials.privateKey,
-    //     clientEmail: googleCredentials.clientEmail
-    //   }
-    // })
-    if (baseConfig.baseURL.endsWith('/v1/')) {
-      baseConfig.baseURL = baseConfig.baseURL.slice(0, -4)
-    } else if (baseConfig.baseURL.endsWith('/v1')) {
-      baseConfig.baseURL = baseConfig.baseURL.slice(0, -3)
-    }
-
-    if (baseConfig.baseURL && !baseConfig.baseURL.includes('publishers/google')) {
-      baseConfig.baseURL = `${baseConfig.baseURL}/v1/projects/${project}/locations/${location}/publishers/google`
-    }
+    baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
   }

   if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
diff --git a/src/renderer/src/components/DraggableList/__tests__/DraggableList.test.tsx b/src/renderer/src/components/DraggableList/__tests__/DraggableList.test.tsx
index 87765a504b..d4319d4df8 100644
--- a/src/renderer/src/components/DraggableList/__tests__/DraggableList.test.tsx
+++ b/src/renderer/src/components/DraggableList/__tests__/DraggableList.test.tsx
@@ -5,6 +5,16 @@
 import { describe, expect, it, vi } from 'vitest'

 import { DraggableList } from '../'

+vi.mock('@renderer/store', () => ({
+  default: {
+    getState: () => ({
+      llm: {
+        settings: {}
+      }
+    })
+  }
+}))
+
 // mock @hello-pangea/dnd 组件
 vi.mock('@hello-pangea/dnd', () => {
   return {
diff --git a/src/renderer/src/components/DraggableList/__tests__/DraggableVirtualList.test.tsx b/src/renderer/src/components/DraggableList/__tests__/DraggableVirtualList.test.tsx
index d931d961b8..610f6bb780 100644
--- a/src/renderer/src/components/DraggableList/__tests__/DraggableVirtualList.test.tsx
+++ b/src/renderer/src/components/DraggableList/__tests__/DraggableVirtualList.test.tsx
@@ -3,6 +3,16 @@
 import { describe, expect, it, vi } from 'vitest'

 import { DraggableVirtualList } from '../'

+vi.mock('@renderer/store', () => ({
+  default: {
+    getState: () => ({
+      llm: {
+        settings: {}
+      }
+    })
+  }
+}))
+
 // Mock 依赖项
 vi.mock('@hello-pangea/dnd', () => ({
   __esModule: true,
diff --git a/src/renderer/src/components/Preview/utils.ts b/src/renderer/src/components/Preview/utils.ts
index a209a6b4c8..7fa67b9f0b 100644
--- a/src/renderer/src/components/Preview/utils.ts
+++ b/src/renderer/src/components/Preview/utils.ts
@@ -1,4 +1,4 @@
-import { makeSvgSizeAdaptive } from '@renderer/utils'
+import { makeSvgSizeAdaptive } from '@renderer/utils/image'
 import DOMPurify from 'dompurify'

 /**
diff --git a/src/renderer/src/components/Selector.tsx b/src/renderer/src/components/Selector.tsx
index bffe2b2eaf..58b587f6e3 100644
--- a/src/renderer/src/components/Selector.tsx
+++ b/src/renderer/src/components/Selector.tsx
@@ -16,6 +16,7 @@ interface BaseSelectorProps {
   options: SelectorOption[]
   placeholder?: string
   placement?: 'topLeft' | 'topCenter' | 'topRight' | 'bottomLeft' | 'bottomCenter' | 'bottomRight' | 'top' | 'bottom'
+  style?: React.CSSProperties
   /** 字体大小 */
   size?: number
   /** 是否禁用 */
@@ -43,6 +44,7 @@ const Selector = ({
   placement = 'bottomRight',
   size = 13,
   placeholder,
+  style,
   disabled = false,
   multiple = false
 }: SelectorProps) => {
@@ -135,7 +137,7 @@ const Selector = ({
       placement={placement}
       open={open && !disabled}
       onOpenChange={handleOpenChange}>
-