From 65c15c6d87f0fc3d8225d638b6cc87df55c0d763 Mon Sep 17 00:00:00 2001 From: suyao Date: Mon, 25 Aug 2025 14:40:48 +0800 Subject: [PATCH] feat(aiCore): update ai-sdk-provider and enhance message conversion logic - Upgraded `@openrouter/ai-sdk-provider` to version ^1.1.2 in package.json and yarn.lock for improved functionality. - Enhanced `convertMessageToSdkParam` and related functions to support additional model parameters, improving message conversion for various AI models. - Integrated logging for error handling in file processing functions to aid in debugging and user feedback. - Added support for native PDF input handling based on model capabilities, enhancing file processing features. --- package.json | 2 +- .../src/aiCore/provider/providerConfigs.ts | 5 +- .../src/aiCore/transformParameters.ts | 95 +++++++++++++++++-- yarn.lock | 12 +-- 4 files changed, 93 insertions(+), 21 deletions(-) diff --git a/package.json b/package.json index d5c0338ddd..7e9fd4c67a 100644 --- a/package.json +++ b/package.json @@ -126,7 +126,7 @@ "@modelcontextprotocol/sdk": "^1.17.0", "@mozilla/readability": "^0.6.0", "@notionhq/client": "^2.2.15", - "@openrouter/ai-sdk-provider": "1.0.0-beta.6", + "@openrouter/ai-sdk-provider": "^1.1.2", "@opentelemetry/api": "^1.9.0", "@opentelemetry/core": "2.0.0", "@opentelemetry/exporter-trace-otlp-http": "^0.200.0", diff --git a/src/renderer/src/aiCore/provider/providerConfigs.ts b/src/renderer/src/aiCore/provider/providerConfigs.ts index 39bdce9551..306f86ea43 100644 --- a/src/renderer/src/aiCore/provider/providerConfigs.ts +++ b/src/renderer/src/aiCore/provider/providerConfigs.ts @@ -1,4 +1,4 @@ -import type { ProviderConfig } from '@cherrystudio/ai-core' +import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core' import { loggerService } from '@logger' const logger = loggerService.withContext('ProviderConfigs') @@ -49,9 +49,6 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & { */ export async function initializeNewProviders(): Promise { try { - // 动态导入以避免循环依赖 - const { registerMultipleProviders } = await import('@cherrystudio/ai-core') - const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS) if (successCount < NEW_PROVIDER_CONFIGS.length) { diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index a16e4a887b..105f3d0342 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -13,6 +13,7 @@ import { TextPart, UserModelMessage } from '@cherrystudio/ai-core' +import { loggerService } from '@logger' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { isGenerateImageModel, @@ -26,7 +27,7 @@ import { isVisionModel, isWebSearchModel } from '@renderer/config/models' -import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService' +import { getAssistantSettings, getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService' import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types' import { FileTypes } from '@renderer/types' import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage' @@ -39,11 +40,14 @@ import { } from '@renderer/utils/messageUtils/find' import { defaultTimeout } from '@shared/config/constant' +import { getAiSdkProviderId } from './provider/factory' // import { webSearchTool } from './tools/WebSearchTool' // import { jsonSchemaToZod } from 'json-schema-to-zod' import { setupToolsConfig } from './utils/mcp' import { buildProviderOptions } from './utils/options' +const logger = loggerService.withContext('transformParameters') + /** * 获取温度参数 */ @@ -100,15 +104,19 @@ export async function extractFileContent(message: Message): Promise { * 转换消息为 AI SDK 参数格式 * 基于 OpenAI 格式的通用转换,支持文本、图片和文件 */ -export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise { +export async function convertMessageToSdkParam( + message: Message, + isVisionModel = false, + model?: Model +): Promise { const content = getMainTextContent(message) const fileBlocks = findFileBlocks(message) const imageBlocks = findImageBlocks(message) const reasoningBlocks = findThinkingBlocks(message) if (message.role === 'user' || message.role === 'system') { - return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel) + return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model) } else { - return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks) + return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks, model) } } @@ -116,7 +124,8 @@ async function convertMessageToUserModelMessage( content: string, fileBlocks: FileMessageBlock[], imageBlocks: ImageMessageBlock[], - isVisionModel = false + isVisionModel = false, + model?: Model ): Promise { const parts: Array = [] if (content) { @@ -135,7 +144,7 @@ async function convertMessageToUserModelMessage( mediaType: image.mime }) } catch (error) { - console.warn('Failed to load image:', error) + logger.warn('Failed to load image:', error as Error) } } else if (imageBlock.url) { parts.push({ @@ -148,6 +157,16 @@ async function convertMessageToUserModelMessage( // 处理文件 for (const fileBlock of fileBlocks) { + // 优先尝试原生文件支持(PDF等) + if (model) { + const filePart = await convertFileBlockToFilePart(fileBlock, model) + if (filePart) { + parts.push(filePart) + continue + } + } + + // 回退到文本处理 const textPart = await convertFileBlockToTextPart(fileBlock) if (textPart) { parts.push(textPart) @@ -163,7 +182,8 @@ async function convertMessageToUserModelMessage( async function convertMessageToAssistantModelMessage( content: string, fileBlocks: FileMessageBlock[], - thinkingBlocks: ThinkingMessageBlock[] + thinkingBlocks: ThinkingMessageBlock[], + model?: Model ): Promise { const parts: Array = [] if (content) { @@ -175,6 +195,16 @@ async function convertMessageToAssistantModelMessage( } for (const fileBlock of fileBlocks) { + // 优先尝试原生文件支持(PDF等) + if (model) { + const filePart = await convertFileBlockToFilePart(fileBlock, model) + if (filePart) { + parts.push(filePart) + continue + } + } + + // 回退到文本处理 const textPart = await convertFileBlockToTextPart(fileBlock) if (textPart) { parts.push(textPart) @@ -190,7 +220,7 @@ async function convertMessageToAssistantModelMessage( async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise { const file = fileBlock.file - if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + if (file.type === FileTypes.TEXT) { try { const fileContent = await window.api.file.read(file.id + file.ext) return { @@ -198,7 +228,52 @@ async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise< text: `${file.origin_name}\n${fileContent.trim()}` } } catch (error) { - console.warn('Failed to read file:', error) + logger.warn('Failed to read file:', error as Error) + } + } + + return null +} + +/** + * 检查模型是否支持原生PDF输入 + */ +function supportsPdfInput(model: Model): boolean { + // 基于AI SDK文档,这些提供商支持PDF输入 + const supportedProviders = [ + 'openai', + 'azure-openai', + 'anthropic', + 'google', + 'google-generative-ai', + 'google-vertex', + 'bedrock', + 'amazon-bedrock' + ] + + const provider = getProviderByModel(model) + const aiSdkId = getAiSdkProviderId(provider) + + return supportedProviders.some((provider) => aiSdkId === provider) +} + +/** + * 将文件块转换为FilePart(用于原生文件支持) + */ +async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, model: Model): Promise { + const file = fileBlock.file + + if (file.type === FileTypes.DOCUMENT && file.ext === '.pdf' && supportsPdfInput(model)) { + try { + const base64Data = await window.api.file.base64File(file.id + file.ext) + return { + type: 'file', + data: base64Data, + mediaType: 'application/pdf', + filename: file.origin_name + } + } catch (error) { + logger.warn('Failed to read PDF file:', error as Error) } } @@ -216,7 +291,7 @@ export async function convertMessagesToSdkMessages( const isVision = isVisionModel(model) for (const message of messages) { - const sdkMessage = await convertMessageToSdkParam(message, isVision) + const sdkMessage = await convertMessageToSdkParam(message, isVision, model) sdkMessages.push(sdkMessage) } diff --git a/yarn.lock b/yarn.lock index 55c0e12a12..03573ad797 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4803,13 +4803,13 @@ __metadata: languageName: node linkType: hard -"@openrouter/ai-sdk-provider@npm:1.0.0-beta.6": - version: 1.0.0-beta.6 - resolution: "@openrouter/ai-sdk-provider@npm:1.0.0-beta.6" +"@openrouter/ai-sdk-provider@npm:^1.1.2": + version: 1.1.2 + resolution: "@openrouter/ai-sdk-provider@npm:1.1.2" peerDependencies: - ai: ^5.0.0-beta.12 + ai: ^5.0.0 zod: ^3.24.1 || ^v4 - checksum: 10c0/7d3a7b2556b2387e6f15d25037b050f12de47c0339d43dbaac309de113d4ad7446228050fcf26747bf0b400205343c3829a072de09d4093b4cb9a190fb3a159e + checksum: 10c0/1ad50804189910d52c2c10e479bec40dfbd2109820e43135d001f4f8706be6ace532d4769a8c30111f5870afdfa97b815c7334b2e4d8d36ca68b1578ce5d9a41 languageName: node linkType: hard @@ -8896,7 +8896,7 @@ __metadata: "@modelcontextprotocol/sdk": "npm:^1.17.0" "@mozilla/readability": "npm:^0.6.0" "@notionhq/client": "npm:^2.2.15" - "@openrouter/ai-sdk-provider": "npm:1.0.0-beta.6" + "@openrouter/ai-sdk-provider": "npm:^1.1.2" "@opentelemetry/api": "npm:^1.9.0" "@opentelemetry/core": "npm:2.0.0" "@opentelemetry/exporter-trace-otlp-http": "npm:^0.200.0"