feat(aiCore): update ai-sdk-provider and enhance message conversion logic

- Upgraded `@openrouter/ai-sdk-provider` from the 1.0.0-beta.6 prerelease to the stable ^1.1.2 release in package.json and yarn.lock.
- Enhanced `convertMessageToSdkParam` and related functions to support additional model parameters, improving message conversion for various AI models.
- Integrated logging for error handling in file processing functions to aid in debugging and user feedback.
- Added support for native PDF input handling based on model capabilities, enhancing file processing features.
This commit is contained in:
suyao 2025-08-25 14:40:48 +08:00
parent ca4e7e3d2b
commit 65c15c6d87
No known key found for this signature in database
4 changed files with 93 additions and 21 deletions

View File

@ -126,7 +126,7 @@
"@modelcontextprotocol/sdk": "^1.17.0",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "1.0.0-beta.6",
"@openrouter/ai-sdk-provider": "^1.1.2",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",

View File

@ -1,4 +1,4 @@
import type { ProviderConfig } from '@cherrystudio/ai-core'
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
const logger = loggerService.withContext('ProviderConfigs')
@ -49,9 +49,6 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
*/
export async function initializeNewProviders(): Promise<void> {
try {
// 动态导入以避免循环依赖
const { registerMultipleProviders } = await import('@cherrystudio/ai-core')
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
if (successCount < NEW_PROVIDER_CONFIGS.length) {

View File

@ -13,6 +13,7 @@ import {
TextPart,
UserModelMessage
} from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
isGenerateImageModel,
@ -26,7 +27,7 @@ import {
isVisionModel,
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import { getAssistantSettings, getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
@ -39,11 +40,14 @@ import {
} from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import { getAiSdkProviderId } from './provider/factory'
// import { webSearchTool } from './tools/WebSearchTool'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'
import { buildProviderOptions } from './utils/options'
const logger = loggerService.withContext('transformParameters')
/**
*
*/
@ -100,15 +104,19 @@ export async function extractFileContent(message: Message): Promise<string> {
* AI SDK
* OpenAI
*/
export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise<ModelMessage> {
export async function convertMessageToSdkParam(
message: Message,
isVisionModel = false,
model?: Model
): Promise<ModelMessage> {
const content = getMainTextContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
const reasoningBlocks = findThinkingBlocks(message)
if (message.role === 'user' || message.role === 'system') {
return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel)
return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model)
} else {
return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks)
return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks, model)
}
}
@ -116,7 +124,8 @@ async function convertMessageToUserModelMessage(
content: string,
fileBlocks: FileMessageBlock[],
imageBlocks: ImageMessageBlock[],
isVisionModel = false
isVisionModel = false,
model?: Model
): Promise<UserModelMessage> {
const parts: Array<TextPart | FilePart | ImagePart> = []
if (content) {
@ -135,7 +144,7 @@ async function convertMessageToUserModelMessage(
mediaType: image.mime
})
} catch (error) {
console.warn('Failed to load image:', error)
logger.warn('Failed to load image:', error as Error)
}
} else if (imageBlock.url) {
parts.push({
@ -148,6 +157,16 @@ async function convertMessageToUserModelMessage(
// 处理文件
for (const fileBlock of fileBlocks) {
// 优先尝试原生文件支持PDF等
if (model) {
const filePart = await convertFileBlockToFilePart(fileBlock, model)
if (filePart) {
parts.push(filePart)
continue
}
}
// 回退到文本处理
const textPart = await convertFileBlockToTextPart(fileBlock)
if (textPart) {
parts.push(textPart)
@ -163,7 +182,8 @@ async function convertMessageToUserModelMessage(
async function convertMessageToAssistantModelMessage(
content: string,
fileBlocks: FileMessageBlock[],
thinkingBlocks: ThinkingMessageBlock[]
thinkingBlocks: ThinkingMessageBlock[],
model?: Model
): Promise<AssistantModelMessage> {
const parts: Array<TextPart | FilePart> = []
if (content) {
@ -175,6 +195,16 @@ async function convertMessageToAssistantModelMessage(
}
for (const fileBlock of fileBlocks) {
// 优先尝试原生文件支持PDF等
if (model) {
const filePart = await convertFileBlockToFilePart(fileBlock, model)
if (filePart) {
parts.push(filePart)
continue
}
}
// 回退到文本处理
const textPart = await convertFileBlockToTextPart(fileBlock)
if (textPart) {
parts.push(textPart)
@ -190,7 +220,7 @@ async function convertMessageToAssistantModelMessage(
async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise<TextPart | null> {
const file = fileBlock.file
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.type === FileTypes.TEXT) {
try {
const fileContent = await window.api.file.read(file.id + file.ext)
return {
@ -198,7 +228,52 @@ async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise<
text: `${file.origin_name}\n${fileContent.trim()}`
}
} catch (error) {
console.warn('Failed to read file:', error)
logger.warn('Failed to read file:', error as Error)
}
}
return null
}
/**
* PDF输入
*/
/**
 * Whether the model's provider supports native PDF file input.
 *
 * Per the AI SDK documentation, these provider ids accept a `FilePart`
 * with mediaType 'application/pdf' directly, so PDFs do not need to be
 * converted to text before sending.
 */
function supportsPdfInput(model: Model): boolean {
  // AI SDK provider ids known to accept PDF file parts natively.
  const PDF_CAPABLE_PROVIDERS: readonly string[] = [
    'openai',
    'azure-openai',
    'anthropic',
    'google',
    'google-generative-ai',
    'google-vertex',
    'bedrock',
    'amazon-bedrock'
  ]
  const provider = getProviderByModel(model)
  const aiSdkId = getAiSdkProviderId(provider)
  // includes() replaces the previous some() whose callback parameter
  // shadowed the outer `provider` variable.
  return PDF_CAPABLE_PROVIDERS.includes(aiSdkId)
}
/**
* FilePart
*/
async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, model: Model): Promise<FilePart | null> {
const file = fileBlock.file
if (file.type === FileTypes.DOCUMENT && file.ext === '.pdf' && supportsPdfInput(model)) {
try {
const base64Data = await window.api.file.base64File(file.id + file.ext)
return {
type: 'file',
data: base64Data,
mediaType: 'application/pdf',
filename: file.origin_name
}
} catch (error) {
logger.warn('Failed to read PDF file:', error as Error)
}
}
@ -216,7 +291,7 @@ export async function convertMessagesToSdkMessages(
const isVision = isVisionModel(model)
for (const message of messages) {
const sdkMessage = await convertMessageToSdkParam(message, isVision)
const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
sdkMessages.push(sdkMessage)
}

View File

@ -4803,13 +4803,13 @@ __metadata:
languageName: node
linkType: hard
"@openrouter/ai-sdk-provider@npm:1.0.0-beta.6":
version: 1.0.0-beta.6
resolution: "@openrouter/ai-sdk-provider@npm:1.0.0-beta.6"
"@openrouter/ai-sdk-provider@npm:^1.1.2":
version: 1.1.2
resolution: "@openrouter/ai-sdk-provider@npm:1.1.2"
peerDependencies:
ai: ^5.0.0-beta.12
ai: ^5.0.0
zod: ^3.24.1 || ^v4
checksum: 10c0/7d3a7b2556b2387e6f15d25037b050f12de47c0339d43dbaac309de113d4ad7446228050fcf26747bf0b400205343c3829a072de09d4093b4cb9a190fb3a159e
checksum: 10c0/1ad50804189910d52c2c10e479bec40dfbd2109820e43135d001f4f8706be6ace532d4769a8c30111f5870afdfa97b815c7334b2e4d8d36ca68b1578ce5d9a41
languageName: node
linkType: hard
@ -8896,7 +8896,7 @@ __metadata:
"@modelcontextprotocol/sdk": "npm:^1.17.0"
"@mozilla/readability": "npm:^0.6.0"
"@notionhq/client": "npm:^2.2.15"
"@openrouter/ai-sdk-provider": "npm:1.0.0-beta.6"
"@openrouter/ai-sdk-provider": "npm:^1.1.2"
"@opentelemetry/api": "npm:^1.9.0"
"@opentelemetry/core": "npm:2.0.0"
"@opentelemetry/exporter-trace-otlp-http": "npm:^0.200.0"