mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat(aiCore): update ai-sdk-provider and enhance message conversion logic
- Upgraded `@openrouter/ai-sdk-provider` to version ^1.1.2 in package.json and yarn.lock for improved functionality. - Enhanced `convertMessageToSdkParam` and related functions to support additional model parameters, improving message conversion for various AI models. - Integrated logging for error handling in file processing functions to aid in debugging and user feedback. - Added support for native PDF input handling based on model capabilities, enhancing file processing features.
This commit is contained in:
parent
ca4e7e3d2b
commit
65c15c6d87
@ -126,7 +126,7 @@
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "1.0.0-beta.6",
|
||||
"@openrouter/ai-sdk-provider": "^1.1.2",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "2.0.0",
|
||||
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { ProviderConfig } from '@cherrystudio/ai-core'
|
||||
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
const logger = loggerService.withContext('ProviderConfigs')
|
||||
@ -49,9 +49,6 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
*/
|
||||
export async function initializeNewProviders(): Promise<void> {
|
||||
try {
|
||||
// 动态导入以避免循环依赖
|
||||
const { registerMultipleProviders } = await import('@cherrystudio/ai-core')
|
||||
|
||||
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
|
||||
|
||||
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
||||
|
||||
@ -13,6 +13,7 @@ import {
|
||||
TextPart,
|
||||
UserModelMessage
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
@ -26,7 +27,7 @@ import {
|
||||
isVisionModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import { getAssistantSettings, getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
|
||||
@ -39,11 +40,14 @@ import {
|
||||
} from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
// import { webSearchTool } from './tools/WebSearchTool'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
import { buildProviderOptions } from './utils/options'
|
||||
|
||||
const logger = loggerService.withContext('transformParameters')
|
||||
|
||||
/**
|
||||
* 获取温度参数
|
||||
*/
|
||||
@ -100,15 +104,19 @@ export async function extractFileContent(message: Message): Promise<string> {
|
||||
* 转换消息为 AI SDK 参数格式
|
||||
* 基于 OpenAI 格式的通用转换,支持文本、图片和文件
|
||||
*/
|
||||
export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise<ModelMessage> {
|
||||
export async function convertMessageToSdkParam(
|
||||
message: Message,
|
||||
isVisionModel = false,
|
||||
model?: Model
|
||||
): Promise<ModelMessage> {
|
||||
const content = getMainTextContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
const reasoningBlocks = findThinkingBlocks(message)
|
||||
if (message.role === 'user' || message.role === 'system') {
|
||||
return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel)
|
||||
return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model)
|
||||
} else {
|
||||
return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks)
|
||||
return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks, model)
|
||||
}
|
||||
}
|
||||
|
||||
@ -116,7 +124,8 @@ async function convertMessageToUserModelMessage(
|
||||
content: string,
|
||||
fileBlocks: FileMessageBlock[],
|
||||
imageBlocks: ImageMessageBlock[],
|
||||
isVisionModel = false
|
||||
isVisionModel = false,
|
||||
model?: Model
|
||||
): Promise<UserModelMessage> {
|
||||
const parts: Array<TextPart | FilePart | ImagePart> = []
|
||||
if (content) {
|
||||
@ -135,7 +144,7 @@ async function convertMessageToUserModelMessage(
|
||||
mediaType: image.mime
|
||||
})
|
||||
} catch (error) {
|
||||
console.warn('Failed to load image:', error)
|
||||
logger.warn('Failed to load image:', error as Error)
|
||||
}
|
||||
} else if (imageBlock.url) {
|
||||
parts.push({
|
||||
@ -148,6 +157,16 @@ async function convertMessageToUserModelMessage(
|
||||
|
||||
// 处理文件
|
||||
for (const fileBlock of fileBlocks) {
|
||||
// 优先尝试原生文件支持(PDF等)
|
||||
if (model) {
|
||||
const filePart = await convertFileBlockToFilePart(fileBlock, model)
|
||||
if (filePart) {
|
||||
parts.push(filePart)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到文本处理
|
||||
const textPart = await convertFileBlockToTextPart(fileBlock)
|
||||
if (textPart) {
|
||||
parts.push(textPart)
|
||||
@ -163,7 +182,8 @@ async function convertMessageToUserModelMessage(
|
||||
async function convertMessageToAssistantModelMessage(
|
||||
content: string,
|
||||
fileBlocks: FileMessageBlock[],
|
||||
thinkingBlocks: ThinkingMessageBlock[]
|
||||
thinkingBlocks: ThinkingMessageBlock[],
|
||||
model?: Model
|
||||
): Promise<AssistantModelMessage> {
|
||||
const parts: Array<TextPart | FilePart> = []
|
||||
if (content) {
|
||||
@ -175,6 +195,16 @@ async function convertMessageToAssistantModelMessage(
|
||||
}
|
||||
|
||||
for (const fileBlock of fileBlocks) {
|
||||
// 优先尝试原生文件支持(PDF等)
|
||||
if (model) {
|
||||
const filePart = await convertFileBlockToFilePart(fileBlock, model)
|
||||
if (filePart) {
|
||||
parts.push(filePart)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到文本处理
|
||||
const textPart = await convertFileBlockToTextPart(fileBlock)
|
||||
if (textPart) {
|
||||
parts.push(textPart)
|
||||
@ -190,7 +220,7 @@ async function convertMessageToAssistantModelMessage(
|
||||
async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise<TextPart | null> {
|
||||
const file = fileBlock.file
|
||||
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
if (file.type === FileTypes.TEXT) {
|
||||
try {
|
||||
const fileContent = await window.api.file.read(file.id + file.ext)
|
||||
return {
|
||||
@ -198,7 +228,52 @@ async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise<
|
||||
text: `${file.origin_name}\n${fileContent.trim()}`
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to read file:', error)
|
||||
logger.warn('Failed to read file:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查模型是否支持原生PDF输入
|
||||
*/
|
||||
function supportsPdfInput(model: Model): boolean {
|
||||
// 基于AI SDK文档,这些提供商支持PDF输入
|
||||
const supportedProviders = [
|
||||
'openai',
|
||||
'azure-openai',
|
||||
'anthropic',
|
||||
'google',
|
||||
'google-generative-ai',
|
||||
'google-vertex',
|
||||
'bedrock',
|
||||
'amazon-bedrock'
|
||||
]
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
return supportedProviders.some((provider) => aiSdkId === provider)
|
||||
}
|
||||
|
||||
/**
|
||||
* 将文件块转换为FilePart(用于原生文件支持)
|
||||
*/
|
||||
async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, model: Model): Promise<FilePart | null> {
|
||||
const file = fileBlock.file
|
||||
|
||||
if (file.type === FileTypes.DOCUMENT && file.ext === '.pdf' && supportsPdfInput(model)) {
|
||||
try {
|
||||
const base64Data = await window.api.file.base64File(file.id + file.ext)
|
||||
return {
|
||||
type: 'file',
|
||||
data: base64Data,
|
||||
mediaType: 'application/pdf',
|
||||
filename: file.origin_name
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Failed to read PDF file:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,7 +291,7 @@ export async function convertMessagesToSdkMessages(
|
||||
const isVision = isVisionModel(model)
|
||||
|
||||
for (const message of messages) {
|
||||
const sdkMessage = await convertMessageToSdkParam(message, isVision)
|
||||
const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
|
||||
sdkMessages.push(sdkMessage)
|
||||
}
|
||||
|
||||
|
||||
12
yarn.lock
12
yarn.lock
@ -4803,13 +4803,13 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@openrouter/ai-sdk-provider@npm:1.0.0-beta.6":
|
||||
version: 1.0.0-beta.6
|
||||
resolution: "@openrouter/ai-sdk-provider@npm:1.0.0-beta.6"
|
||||
"@openrouter/ai-sdk-provider@npm:^1.1.2":
|
||||
version: 1.1.2
|
||||
resolution: "@openrouter/ai-sdk-provider@npm:1.1.2"
|
||||
peerDependencies:
|
||||
ai: ^5.0.0-beta.12
|
||||
ai: ^5.0.0
|
||||
zod: ^3.24.1 || ^v4
|
||||
checksum: 10c0/7d3a7b2556b2387e6f15d25037b050f12de47c0339d43dbaac309de113d4ad7446228050fcf26747bf0b400205343c3829a072de09d4093b4cb9a190fb3a159e
|
||||
checksum: 10c0/1ad50804189910d52c2c10e479bec40dfbd2109820e43135d001f4f8706be6ace532d4769a8c30111f5870afdfa97b815c7334b2e4d8d36ca68b1578ce5d9a41
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@ -8896,7 +8896,7 @@ __metadata:
|
||||
"@modelcontextprotocol/sdk": "npm:^1.17.0"
|
||||
"@mozilla/readability": "npm:^0.6.0"
|
||||
"@notionhq/client": "npm:^2.2.15"
|
||||
"@openrouter/ai-sdk-provider": "npm:1.0.0-beta.6"
|
||||
"@openrouter/ai-sdk-provider": "npm:^1.1.2"
|
||||
"@opentelemetry/api": "npm:^1.9.0"
|
||||
"@opentelemetry/core": "npm:2.0.0"
|
||||
"@opentelemetry/exporter-trace-otlp-http": "npm:^0.200.0"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user