diff --git a/package.json b/package.json
index 57f14e728d..0be12b6185 100644
--- a/package.json
+++ b/package.json
@@ -73,6 +73,7 @@
     "@agentic/tavily": "^7.3.3",
     "@ant-design/v5-patch-for-react-19": "^1.0.3",
     "@anthropic-ai/sdk": "^0.41.0",
+    "@cherry-studio/ai-core": "workspace:*",
     "@cherrystudio/embedjs": "^0.1.31",
     "@cherrystudio/embedjs-libsql": "^0.1.31",
     "@cherrystudio/embedjs-loader-csv": "^0.1.31",
diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json
index d1f810e404..69f73c341b 100644
--- a/packages/aiCore/package.json
+++ b/packages/aiCore/package.json
@@ -2,8 +2,8 @@
   "name": "@cherry-studio/ai-core",
   "version": "1.0.0",
   "description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
-  "main": "dist/index.js",
-  "types": "dist/index.d.ts",
+  "main": "src/index.ts",
+  "types": "src/index.ts",
   "scripts": {
     "build": "tsc",
     "dev": "tsc -w",
@@ -114,9 +114,9 @@
   ],
   "exports": {
     ".": {
-      "types": "./dist/index.d.ts",
-      "import": "./dist/index.js",
-      "require": "./dist/index.js"
+      "types": "./src/index.ts",
+      "import": "./src/index.ts",
+      "require": "./src/index.ts"
     }
   }
 }
diff --git a/packages/aiCore/src/clients/PluginEnabledAiClient.ts b/packages/aiCore/src/clients/PluginEnabledAiClient.ts
index 271c29d58d..62707e5f5e 100644
--- a/packages/aiCore/src/clients/PluginEnabledAiClient.ts
+++ b/packages/aiCore/src/clients/PluginEnabledAiClient.ts
@@ -19,10 +19,10 @@
  * })
  * ```
  */
-
 import { generateObject, generateText, streamObject, streamText } from 'ai'

 import { AiPlugin, createContext, PluginManager } from '../plugins'
+import { isProviderSupported } from '../providers/registry'
 import { ApiClientFactory } from './ApiClientFactory'
 import { type ProviderId, type ProviderSettingsMap } from './types'
 import { UniversalAiSdkClient } from './UniversalAiSdkClient'
@@ -178,8 +178,7 @@ export class PluginEnabledAiClient {
       async (finalModelId, transformedParams, streamTransforms) => {
         // 对于流式调用,需要直接调用 AI SDK 以支持流转换器
         const model = await ApiClientFactory.createClient(this.providerId, finalModelId, this.options)
-
-        return streamText({
+        return await streamText({
           model,
           ...transformedParams,
           experimental_transform: streamTransforms.length > 0 ?
streamTransforms : undefined @@ -196,7 +195,7 @@ export class PluginEnabledAiClient { params: Omit[0], 'model'> ): Promise> { return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => { - return this.baseClient.generateText(finalModelId, transformedParams) + return await this.baseClient.generateText(finalModelId, transformedParams) }) } @@ -208,7 +207,7 @@ export class PluginEnabledAiClient { params: Omit[0], 'model'> ): Promise> { return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => { - return this.baseClient.generateObject(finalModelId, transformedParams) + return await this.baseClient.generateObject(finalModelId, transformedParams) }) } @@ -221,7 +220,7 @@ export class PluginEnabledAiClient { params: Omit[0], 'model'> ): Promise> { return this.executeWithPlugins('streamObject', modelId, params, async (finalModelId, transformedParams) => { - return this.baseClient.streamObject(finalModelId, transformedParams) + return await this.baseClient.streamObject(finalModelId, transformedParams) }) } @@ -267,7 +266,7 @@ export class PluginEnabledAiClient { ): PluginEnabledAiClient<'openai-compatible'> static create(providerId: string, options: any, plugins: AiPlugin[] = []): PluginEnabledAiClient { - if (providerId in ({} as ProviderSettingsMap)) { + if (isProviderSupported(providerId)) { return new PluginEnabledAiClient(providerId as ProviderId, options, plugins) } else { // 对于未知 provider,使用 openai-compatible diff --git a/packages/aiCore/src/clients/types.ts b/packages/aiCore/src/clients/types.ts index f9c6b37fe2..424d6ca55f 100644 --- a/packages/aiCore/src/clients/types.ts +++ b/packages/aiCore/src/clients/types.ts @@ -1,38 +1,14 @@ -import { FetchFunction } from '@ai-sdk/provider-utils' +import { generateObject, generateText, streamObject, streamText } from 'ai' import type { ProviderSettingsMap } from '../providers/registry' // ProviderSettings 是所有 Provider Settings 的联合类型 export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap] -// 基础 Provider 配置类型(为了向后兼容和通用场景) -export type BaseProviderSettings = { - /** - * API key for authentication - */ - apiKey?: string - /** - * Base URL for the API calls - */ - baseURL?: string - /** - * Custom headers to include in the requests - */ - headers?: Record - /** - * Optional custom url query parameters to include in request urls - */ - queryParams?: Record - /** - * Custom fetch implementation. You can use it as a middleware to intercept requests, - * or to provide a custom fetch implementation for e.g. testing. 
- */ - fetch?: FetchFunction - /** - * Allow additional properties for provider-specific settings - */ - [key: string]: any -} +export type StreamTextParams = Omit[0], 'model'> +export type GenerateTextParams = Omit[0], 'model'> +export type StreamObjectParams = Omit[0], 'model'> +export type GenerateObjectParams = Omit[0], 'model'> // 重新导出 ProviderSettingsMap 中的所有类型 export type { diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index 12da6b844d..7968645d68 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -39,9 +39,44 @@ export { aiProviderRegistry } from './providers/registry' // ==================== 类型定义 ==================== export type { ClientFactoryError } from './clients/ApiClientFactory' -export type { BaseProviderSettings, ProviderSettings } from './clients/types' +export type { + GenerateObjectParams, + GenerateTextParams, + ProviderSettings, + StreamObjectParams, + StreamTextParams +} from './clients/types' export type { ProviderConfig } from './providers/registry' export type { ProviderError } from './providers/types' +export * as aiSdk from 'ai' + +// ==================== AI SDK 常用类型导出 ==================== +// 直接导出 AI SDK 的常用类型,方便使用 +export type { + CoreAssistantMessage, + // 消息相关类型 + CoreMessage, + CoreSystemMessage, + CoreToolMessage, + CoreUserMessage, + // 通用类型 + FinishReason, + GenerateObjectResult, + // 生成相关类型 + GenerateTextResult, + InvalidToolArgumentsError, + LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage + // 错误类型 + NoSuchToolError, + StreamTextResult, + // 流相关类型 + TextStreamPart, + // 工具相关类型 + Tool, + ToolCall, + ToolExecutionError, + ToolResult +} from 'ai' // 重新导出所有 Provider Settings 类型 export type { diff --git a/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts new file mode 100644 index 0000000000..c15ebdebeb --- /dev/null +++ b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts @@ -0,0 +1,296 @@ +/** + * AI SDK 到 Cherry Studio Chunk 适配器 + * 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式 + */ + +import { TextStreamPart } from '@cherry-studio/ai-core' +import { Chunk, ChunkType } from '@renderer/types/chunk' + +export interface CherryStudioChunk { + type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error' + text?: string + toolCall?: any + toolResult?: any + finishReason?: string + usage?: any + error?: any +} + +/** + * AI SDK 到 Cherry Studio Chunk 适配器类 + * 处理 fullStream 到 Cherry Studio chunk 的转换 + */ +export class AiSdkToChunkAdapter { + constructor(private onChunk: (chunk: Chunk) => void) {} + + /** + * 处理 AI SDK 流结果 + * @param aiSdkResult AI SDK 的流结果对象 + * @returns 最终的文本内容 + */ + async processStream(aiSdkResult: any): Promise { + // 如果是流式且有 fullStream + if (aiSdkResult.fullStream) { + await this.readFullStream(aiSdkResult.fullStream) + } + + // 使用 streamResult.text 获取最终结果 + return await aiSdkResult.text + } + + /** + * 读取 fullStream 并转换为 Cherry Studio chunks + * @param fullStream AI SDK 的 fullStream (ReadableStream) + */ + private async readFullStream(fullStream: ReadableStream>) { + const reader = fullStream.getReader() + const final = { + text: '', + reasoning_content: '' + } + try { + while (true) { + const { done, value } = await reader.read() + + if (done) { + break + } + + // 转换并发送 chunk + this.convertAndEmitChunk(value, final) + } + } finally { + reader.releaseLock() + } + } + + /** + * 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调 + * @param chunk AI SDK 的 chunk 数据 + */ + private 
convertAndEmitChunk(chunk: TextStreamPart, final: { text: string; reasoning_content: string }) { + console.log('AI SDK chunk type:', chunk.type, chunk) + switch (chunk.type) { + // === 文本相关事件 === + case 'text-delta': + final.text += chunk.textDelta || '' + this.onChunk({ + type: ChunkType.TEXT_DELTA, + text: chunk.textDelta || '' + }) + if (final.reasoning_content) { + this.onChunk({ + type: ChunkType.THINKING_COMPLETE, + text: final.reasoning_content || '' + }) + final.reasoning_content = '' + } + break + + // === 推理相关事件 === + case 'reasoning': + final.reasoning_content += chunk.textDelta || '' + this.onChunk({ + type: ChunkType.THINKING_DELTA, + text: chunk.textDelta || '' + }) + break + + case 'reasoning-signature': + // 推理签名,可以映射到思考完成 + this.onChunk({ + type: ChunkType.THINKING_COMPLETE, + text: chunk.signature || '' + }) + break + + case 'redacted-reasoning': + // 被编辑的推理内容,也映射到思考 + this.onChunk({ + type: ChunkType.THINKING_DELTA, + text: chunk.data || '' + }) + break + + // === 工具调用相关事件 === + case 'tool-call-streaming-start': + // 开始流式工具调用 + this.onChunk({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: [ + { + id: chunk.toolCallId, + name: chunk.toolName, + args: {} + } + ] + }) + break + + case 'tool-call-delta': + // 工具调用参数的增量更新 + this.onChunk({ + type: ChunkType.MCP_TOOL_IN_PROGRESS, + responses: [ + { + id: chunk.toolCallId, + tool: { + id: chunk.toolName, + // TODO: serverId,serverName + serverId: 'ai-sdk', + serverName: 'AI SDK', + name: chunk.toolName, + description: '', + inputSchema: { + type: 'object', + title: chunk.toolName, + properties: {} + } + }, + arguments: {}, + status: 'invoking', + response: chunk.argsTextDelta, + toolCallId: chunk.toolCallId + } + ] + }) + break + + case 'tool-call': + // 完整的工具调用 + this.onChunk({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: [ + { + id: chunk.toolCallId, + name: chunk.toolName, + args: chunk.args + } + ] + }) + break + + case 'tool-result': + // 工具调用结果 + this.onChunk({ + type: ChunkType.MCP_TOOL_COMPLETE, + responses: [ + { + id: chunk.toolCallId, + tool: { + id: chunk.toolName, + // TODO: serverId,serverName + serverId: 'ai-sdk', + serverName: 'AI SDK', + name: chunk.toolName, + description: '', + inputSchema: { + type: 'object', + title: chunk.toolName, + properties: {} + } + }, + arguments: chunk.args || {}, + status: 'done', + response: chunk.result, + toolCallId: chunk.toolCallId + } + ] + }) + break + + // === 步骤相关事件 === + // case 'step-start': + // this.onChunk({ + // type: ChunkType.LLM_RESPONSE_CREATED + // }) + // break + case 'step-finish': + this.onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + text: final.text || '', + reasoning_content: final.reasoning_content || '', + usage: { + completion_tokens: chunk.usage.completionTokens || 0, + prompt_tokens: chunk.usage.promptTokens || 0, + total_tokens: chunk.usage.totalTokens || 0 + }, + metrics: chunk.usage + ? { + completion_tokens: chunk.usage.completionTokens || 0, + time_completion_millsec: 0 + } + : undefined + } + }) + break + + case 'finish': + this.onChunk({ + type: ChunkType.TEXT_COMPLETE, + text: final.text || '' // TEXT_COMPLETE 需要 text 字段 + }) + this.onChunk({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + text: final.text || '', + reasoning_content: final.reasoning_content || '', + usage: { + completion_tokens: chunk.usage.completionTokens || 0, + prompt_tokens: chunk.usage.promptTokens || 0, + total_tokens: chunk.usage.totalTokens || 0 + }, + metrics: chunk.usage + ? 
{ + completion_tokens: chunk.usage.completionTokens || 0, + time_completion_millsec: 0 + } + : undefined + } + }) + break + + // === 源和文件相关事件 === + case 'source': + // 源信息,可以映射到知识搜索完成 + this.onChunk({ + type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE, + knowledge: [ + { + id: Number(chunk.source.id) || Date.now(), + content: chunk.source.title || '', + sourceUrl: chunk.source.url || '', + type: 'url' + } + ] + }) + break + + case 'file': + // 文件相关事件,可能是图片生成 + this.onChunk({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [chunk.base64] + } + }) + break + case 'error': + this.onChunk({ + type: ChunkType.ERROR, + error: { + message: chunk.error || 'Unknown error' + } + }) + break + + default: + // 其他类型的 chunk 可以忽略或记录日志 + console.log('Unhandled AI SDK chunk type:', chunk.type, chunk) + } + } +} + +export default AiSdkToChunkAdapter diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts new file mode 100644 index 0000000000..12ec392cec --- /dev/null +++ b/src/renderer/src/aiCore/index_new.ts @@ -0,0 +1,230 @@ +/** + * Cherry Studio AI Core - 新版本入口 + * 集成 @cherry-studio/ai-core 库的渐进式重构方案 + * + * 融合方案:简化实现,专注于核心功能 + * 1. 优先使用新AI SDK + * 2. 失败时fallback到原有实现 + * 3. 暂时保持接口兼容性 + */ + +import { + AiClient, + AiCore, + createClient, + type OpenAICompatibleProviderSettings, + type ProviderId +} from '@cherry-studio/ai-core' +import { isDedicatedImageGenerationModel } from '@renderer/config/models' +import type { GenerateImageParams, Model, Provider } from '@renderer/types' +import { Chunk, ChunkType } from '@renderer/types/chunk' +import { RequestOptions } from '@renderer/types/sdk' + +// 引入适配器 +import AiSdkToChunkAdapter from './AiSdkToChunkAdapter' +// 引入原有的AiProvider作为fallback +import LegacyAiProvider from './index' +import { CompletionsParams, CompletionsResult } from './middleware/schemas' +// 引入参数转换模块 +import { buildStreamTextParams } from './transformParameters' + +/** + * 将现有 Provider 类型映射到 AI SDK 的 Provider ID + * 根据 registry.ts 中的支持列表进行映射 + */ +function mapProviderTypeToAiSdkId(providerType: string): string { + // Cherry Studio Provider Type -> AI SDK Provider ID 映射表 + const typeMapping: Record = { + // 需要转换的映射 + grok: 'xai', // grok -> xai + 'azure-openai': 'azure', // azure-openai -> azure + gemini: 'google' // gemini -> google + } + + return typeMapping[providerType] +} + +/** + * 将 Provider 配置转换为新 AI SDK 格式 + */ +function providerToAiSdkConfig(provider: Provider): { + providerId: ProviderId | 'openai-compatible' + options: any +} { + console.log('provider', provider) + // 1. 先映射 provider 类型到 AI SDK ID + const mappedProviderId = mapProviderTypeToAiSdkId(provider.id) + + // 2. 检查映射后的 provider ID 是否在 AI SDK 注册表中 + const isSupported = AiCore.isSupported(mappedProviderId) + + console.log(`Provider mapping: ${provider.type} -> ${mappedProviderId}, supported: ${isSupported}`) + + // 3. 
如果映射的 provider 不支持,则使用 openai-compatible + if (isSupported) { + return { + providerId: mappedProviderId as ProviderId, + options: { + apiKey: provider.apiKey + } + } + } else { + console.log(`Using openai-compatible fallback for provider: ${provider.type}`) + const compatibleConfig: OpenAICompatibleProviderSettings = { + name: provider.name || provider.type, + apiKey: provider.apiKey, + baseURL: provider.apiHost + } + + return { + providerId: 'openai-compatible', + options: compatibleConfig + } + } +} + +/** + * 检查是否支持使用新的AI SDK + */ +function isModernSdkSupported(provider: Provider, model?: Model): boolean { + // 目前支持主要的providers + const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai'] + + // 检查provider类型 + if (!supportedProviders.includes(provider.type)) { + return false + } + + // 检查是否为图像生成模型(暂时不支持) + if (model && isDedicatedImageGenerationModel(model)) { + return false + } + + return true +} + +export default class ModernAiProvider { + private modernClient?: AiClient + private legacyProvider: LegacyAiProvider + private provider: Provider + + constructor(provider: Provider) { + this.provider = provider + this.legacyProvider = new LegacyAiProvider(provider) + + const config = providerToAiSdkConfig(provider) + this.modernClient = createClient(config.providerId, config.options) + } + + public async completions(params: CompletionsParams, options?: RequestOptions): Promise { + // const model = params.assistant.model + + // 检查是否应该使用现代化客户端 + // if (this.modernClient && model && isModernSdkSupported(this.provider, model)) { + // try { + return await this.modernCompletions(params, options) + // } catch (error) { + // console.warn('Modern client failed, falling back to legacy:', error) + // fallback到原有实现 + // } + // } + + // 使用原有实现 + // return this.legacyProvider.completions(params, options) + } + + /** + * 使用现代化AI SDK的completions实现 + * 使用 AiSdkUtils 工具模块进行参数构建 + */ + private async modernCompletions(params: CompletionsParams, options?: RequestOptions): Promise { + if (!this.modernClient || !params.assistant.model) { + throw new Error('Modern client not available') + } + + console.log('Modern completions with params:', params, 'options:', options) + + const model = params.assistant.model + const assistant = params.assistant + + // 检查 messages 类型并转换 + const messages = Array.isArray(params.messages) ? 
params.messages : [] + if (typeof params.messages === 'string') { + console.warn('Messages is string, using empty array') + } + + // 使用 transformParameters 模块构建参数 + const aiSdkParams = await buildStreamTextParams(messages, assistant, model, { + maxTokens: params.maxTokens, + mcpTools: params.mcpTools + }) + + console.log('Built AI SDK params:', aiSdkParams) + const chunks: Chunk[] = [] + + try { + if (params.streamOutput && params.onChunk) { + // 流式处理 - 使用适配器 + const adapter = new AiSdkToChunkAdapter(params.onChunk) + const streamResult = await this.modernClient.streamText(model.id, aiSdkParams) + const finalText = await adapter.processStream(streamResult) + + return { + getText: () => finalText + } + } else if (params.streamOutput) { + // 流式处理但没有 onChunk 回调 + const streamResult = await this.modernClient.streamText(model.id, aiSdkParams) + const finalText = await streamResult.text + + return { + getText: () => finalText + } + } else { + // 非流式处理 + const result = await this.modernClient.generateText(model.id, aiSdkParams) + + const cherryChunk: Chunk = { + type: ChunkType.TEXT_COMPLETE, + text: result.text || '' + } + chunks.push(cherryChunk) + + if (params.onChunk) { + params.onChunk(cherryChunk) + } + + return { + getText: () => result.text || '' + } + } + } catch (error) { + console.error('Modern AI SDK error:', error) + throw error + } + } + + // 代理其他方法到原有实现 + public async models() { + return this.legacyProvider.models() + } + + public async getEmbeddingDimensions(model: Model): Promise { + return this.legacyProvider.getEmbeddingDimensions(model) + } + + public async generateImage(params: GenerateImageParams): Promise { + return this.legacyProvider.generateImage(params) + } + + public getBaseURL(): string { + return this.legacyProvider.getBaseURL() + } + + public getApiKey(): string { + return this.legacyProvider.getApiKey() + } +} + +// 为了方便调试,导出一些工具函数 +export { isModernSdkSupported, providerToAiSdkConfig } diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts new file mode 100644 index 0000000000..25c5a400c6 --- /dev/null +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -0,0 +1,269 @@ +/** + * AI SDK 参数转换模块 + * 统一管理从各个 apiClient 提取的参数处理和转换功能 + */ + +import type { StreamTextParams } from '@cherry-studio/ai-core' +import { isNotSupportTemperatureAndTopP, isSupportedFlexServiceTier } from '@renderer/config/models' +import type { Assistant, MCPTool, Message, Model } from '@renderer/types' +import { FileTypes } from '@renderer/types' +import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' +import { buildSystemPrompt } from '@renderer/utils/prompt' +import { defaultTimeout } from '@shared/config/constant' + +/** + * 获取温度参数 + */ +export function getTemperature(assistant: Assistant, model: Model): number | undefined { + return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.temperature +} + +/** + * 获取 TopP 参数 + */ +export function getTopP(assistant: Assistant, model: Model): number | undefined { + return isNotSupportTemperatureAndTopP(model) ? 
undefined : assistant.settings?.topP +} + +/** + * 获取超时设置 + */ +export function getTimeout(model: Model): number { + if (isSupportedFlexServiceTier(model)) { + return 15 * 1000 * 60 + } + return defaultTimeout +} + +/** + * 构建系统提示词 + */ +export async function buildSystemPromptWithTools( + prompt: string, + mcpTools?: MCPTool[], + assistant?: Assistant +): Promise { + return await buildSystemPrompt(prompt, mcpTools, assistant) +} + +// /** +// * 转换 MCP 工具为 AI SDK 工具格式 +// * 注意:这里返回通用格式,实际使用时需要根据具体 provider 转换 +// TODO: 需要使用ai-sdk的mcp +// */ +// export function convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Pick { +// return mcpTools.map((tool) => ({ +// type: 'function', +// function: { +// name: tool.id, +// description: tool.description, +// parameters: tool.inputSchema || {} +// } +// })) +// } + +/** + * 提取文件内容 + */ +export async function extractFileContent(message: Message): Promise { + const fileBlocks = findFileBlocks(message) + if (fileBlocks.length > 0) { + const textFileBlocks = fileBlocks.filter( + (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type) + ) + + if (textFileBlocks.length > 0) { + let text = '' + const divider = '\n\n---\n\n' + + for (const fileBlock of textFileBlocks) { + const file = fileBlock.file + const fileContent = (await window.api.file.read(file.id + file.ext)).trim() + const fileNameRow = 'file: ' + file.origin_name + '\n\n' + text = text + fileNameRow + fileContent + divider + } + + return text + } + } + + return '' +} + +/** + * 转换消息为 AI SDK 参数格式 + * 基于 OpenAI 格式的通用转换,支持文本、图片和文件 + */ +export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise { + const content = getMainTextContent(message) + const fileBlocks = findFileBlocks(message) + const imageBlocks = findImageBlocks(message) + + // 简单消息(无文件无图片) + if (fileBlocks.length === 0 && imageBlocks.length === 0) { + return { + role: message.role === 'system' ? 'user' : message.role, + content + } + } + + // 复杂消息(包含文件或图片) + const parts: any[] = [] + + if (content) { + parts.push({ type: 'text', text: content }) + } + + // 处理图片(仅在支持视觉的模型中) + if (isVisionModel) { + for (const imageBlock of imageBlocks) { + if (imageBlock.file) { + try { + const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext) + parts.push({ + type: 'image_url', + image_url: { url: image.data } + }) + } catch (error) { + console.warn('Failed to load image:', error) + } + } else if (imageBlock.url && imageBlock.url.startsWith('data:')) { + parts.push({ + type: 'image_url', + image_url: { url: imageBlock.url } + }) + } + } + } + + // 处理文件 + for (const fileBlock of fileBlocks) { + const file = fileBlock.file + if (!file) continue + + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { + try { + const fileContent = await window.api.file.read(file.id + file.ext) + parts.push({ + type: 'text', + text: `${file.origin_name}\n${fileContent.trim()}` + }) + } catch (error) { + console.warn('Failed to read file:', error) + } + } + } + + return { + role: message.role === 'system' ? 'user' : message.role, + content: parts.length === 1 && parts[0].type === 'text' ? 
parts[0].text : parts + } +} + +/** + * 转换 Cherry Studio 消息数组为 AI SDK 消息数组 + */ +export async function convertMessagesToSdkMessages( + messages: Message[], + model: Model +): Promise { + const sdkMessages: StreamTextParams['messages'] = [] + const isVision = model.id.includes('vision') || model.id.includes('gpt-4') // 简单的视觉模型检测 + + for (const message of messages) { + const sdkMessage = await convertMessageToSdkParam(message, isVision) + sdkMessages.push(sdkMessage) + } + + return sdkMessages +} + +/** + * 构建 AI SDK 流式参数 + * 这是主要的参数构建函数,整合所有转换逻辑 + */ +export async function buildStreamTextParams( + messages: Message[], + assistant: Assistant, + model: Model, + options: { + maxTokens?: number + mcpTools?: MCPTool[] + enableTools?: boolean + } = {} +): Promise { + const { maxTokens, mcpTools, enableTools = false } = options + + // 转换消息 + const sdkMessages = await convertMessagesToSdkMessages(messages, model) + + // 构建系统提示 + let systemPrompt = assistant.prompt || '' + if (mcpTools && mcpTools.length > 0) { + systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant) + } + + // 构建基础参数 + const params: StreamTextParams = { + messages: sdkMessages, + maxTokens: maxTokens || 1000, + temperature: getTemperature(assistant, model), + topP: getTopP(assistant, model), + system: systemPrompt || undefined, + ...getCustomParameters(assistant) + } + + // 添加工具(如果启用且有工具) + if (enableTools && mcpTools && mcpTools.length > 0) { + // TODO: 暂时注释掉工具支持,等类型问题解决后再启用 + // params.tools = convertMcpToolsToSdkTools(mcpTools) + } + + return params +} + +/** + * 构建非流式的 generateText 参数 + */ +export async function buildGenerateTextParams( + messages: Message[], + assistant: Assistant, + model: Model, + options: { + maxTokens?: number + mcpTools?: MCPTool[] + enableTools?: boolean + } = {} +): Promise { + // 复用流式参数的构建逻辑 + return await buildStreamTextParams(messages, assistant, model, options) +} + +/** + * 获取自定义参数 + * 从 assistant 设置中提取自定义参数 + */ +export function getCustomParameters(assistant: Assistant): Record { + return ( + assistant?.settings?.customParameters?.reduce((acc, param) => { + if (!param.name?.trim()) { + return acc + } + if (param.type === 'json') { + const value = param.value as string + if (value === 'undefined') { + return { ...acc, [param.name]: undefined } + } + try { + return { ...acc, [param.name]: JSON.parse(value) } + } catch { + return { ...acc, [param.name]: value } + } + } + return { + ...acc, + [param.name]: param.value + } + }, {}) || {} + ) +} diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 46c5fd849c..587434a4f2 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -37,7 +37,7 @@ import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils import { findLast, isEmpty, takeRight } from 'lodash' import AiProvider from '../aiCore' -import store from '../store' +import AiProviderNew from '../aiCore/index_new' import { getAssistantProvider, getAssistantSettings, @@ -313,7 +313,7 @@ export async function fetchChatCompletion({ console.log('fetchChatCompletion', messages, assistant) const provider = getAssistantProvider(assistant) - const AI = new AiProvider(provider) + const AI = new AiProviderNew(provider) // Make sure that 'Clear Context' works for all scenarios including external tool and normal chat. 
   messages = filterContextMessages(messages)
diff --git a/tsconfig.web.json b/tsconfig.web.json
index ddcece6352..11b6799d61 100644
--- a/tsconfig.web.json
+++ b/tsconfig.web.json
@@ -4,7 +4,8 @@
     "src/renderer/src/**/*",
     "src/preload/*.d.ts",
     "local/src/renderer/**/*",
-    "packages/shared/**/*"
+    "packages/shared/**/*",
+    "packages/aiCore/src/**/*"
   ],
   "compilerOptions": {
     "composite": true,
@@ -14,7 +15,8 @@
     "paths": {
       "@renderer/*": ["src/renderer/src/*"],
       "@shared/*": ["packages/shared/*"],
-      "@types": ["src/renderer/src/types/index.ts"]
+      "@types": ["src/renderer/src/types/index.ts"],
+      "@cherry-studio/ai-core": ["packages/aiCore/src/"]
     }
   }
 }
diff --git a/yarn.lock b/yarn.lock
index c4628e7160..335fcdb547 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -960,7 +960,7 @@ __metadata:
   languageName: node
   linkType: hard

-"@cherry-studio/ai-core@workspace:packages/aiCore":
+"@cherry-studio/ai-core@workspace:*, @cherry-studio/ai-core@workspace:packages/aiCore":
   version: 0.0.0-use.local
   resolution: "@cherry-studio/ai-core@workspace:packages/aiCore"
   dependencies:
@@ -6392,6 +6392,7 @@ __metadata:
     "@agentic/tavily": "npm:^7.3.3"
     "@ant-design/v5-patch-for-react-19": "npm:^1.0.3"
     "@anthropic-ai/sdk": "npm:^0.41.0"
+    "@cherry-studio/ai-core": "workspace:*"
     "@cherrystudio/embedjs": "npm:^0.1.31"
     "@cherrystudio/embedjs-libsql": "npm:^0.1.31"
     "@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"
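The param-type aliases added in packages/aiCore/src/clients/types.ts carry most of the renderer-side typing: each one is the first argument of the corresponding AI SDK call with `model` removed, so callers pass the model id separately and the client resolves the model instance itself. Spelled out in full (a sketch; the generic arguments are inferred from the `import { generateObject, generateText, streamObject, streamText } from 'ai'` line):

import { generateObject, generateText, streamObject, streamText } from 'ai'

// First parameter of each AI SDK call, minus `model`.
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>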
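The one behavioural change in PluginEnabledAiClient.ts is the static factory: the old check `providerId in ({} as ProviderSettingsMap)` tested membership against an empty runtime object and was therefore always false, so every id fell into the openai-compatible branch; `isProviderSupported(providerId)` consults the real registry. A usage sketch, assuming the class is re-exported from the package entry (only the class file itself appears in the diff, so that export is an assumption):

import { PluginEnabledAiClient } from '@cherry-studio/ai-core' // assumption: re-exported from src/index.ts

// A registered id keeps its native settings type and client.
const anthropic = PluginEnabledAiClient.create('anthropic', { apiKey: 'placeholder' })

// An unregistered id now reaches the openai-compatible fallback deliberately,
// not because the old membership check could never succeed.
const localGateway = PluginEnabledAiClient.create('my-local-gateway', {
  name: 'my-local-gateway',
  apiKey: 'placeholder',
  baseURL: 'http://localhost:8000/v1'
})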
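AiSdkToChunkAdapter is the glue between the AI SDK's fullStream and the existing onChunk pipeline; index_new.ts drives it roughly as below. A condensed sketch, with the provider credentials and model id as placeholders:

import { createClient } from '@cherry-studio/ai-core'
import AiSdkToChunkAdapter from '@renderer/aiCore/AiSdkToChunkAdapter'
import { Chunk, ChunkType } from '@renderer/types/chunk'

async function streamToCherryChunks(prompt: string): Promise<string> {
  const client = createClient('openai', { apiKey: 'placeholder' })

  // Every fullStream part is translated into a Cherry Studio chunk and handed to this callback.
  const onChunk = (chunk: Chunk) => {
    if (chunk.type === ChunkType.TEXT_DELTA) {
      console.log('delta:', chunk.text)
    }
  }

  const adapter = new AiSdkToChunkAdapter(onChunk)
  const streamResult = await client.streamText('gpt-4o-mini', {
    messages: [{ role: 'user', content: prompt }]
  })

  // processStream drains fullStream, emits chunks, then resolves with the final text.
  return adapter.processStream(streamResult)
}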
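One thing worth flagging in index_new.ts: the typeMapping table in mapProviderTypeToAiSdkId only lists the ids that need renaming, so it returns undefined for everything else ('openai', 'anthropic', ...), AiCore.isSupported(undefined) cannot succeed, and those providers land on the openai-compatible fallback even though the SDK supports them natively; the function is also called with provider.id while the comments and the log line describe provider.type. A hedged sketch of a fall-through variant (an assumption about the intended behaviour, not part of the diff):

function mapProviderTypeToAiSdkId(providerType: string): string {
  // Only ids whose AI SDK name differs need an entry.
  const typeMapping: Record<string, string> = {
    grok: 'xai', // grok -> xai
    'azure-openai': 'azure', // azure-openai -> azure
    gemini: 'google' // gemini -> google
  }

  // Fall back to the original id so natively supported providers
  // are not routed onto the openai-compatible path by accident.
  return typeMapping[providerType] ?? providerType
}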
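convertMessageToSdkParam builds OpenAI-style image_url parts (its own comment notes the conversion is based on the OpenAI format), yet the result is fed into StreamTextParams['messages'], i.e. the AI SDK's CoreMessage shape, whose multimodal parts are slightly different. To the best of my knowledge of AI SDK v4 the expected form is the following, so treat it as a point to verify rather than a drop-in fix:

import type { CoreUserMessage } from 'ai'

// AI SDK v4 user content parts use `type: 'image'` with an `image` payload
// (base64 string, data URL, Uint8Array or URL), not OpenAI's image_url wrapper.
const userMessage: CoreUserMessage = {
  role: 'user',
  content: [
    { type: 'text', text: 'What is in this picture?' },
    { type: 'image', image: 'data:image/png;base64,iVBORw0KGgo...' } // sample data URL, truncated
  ]
}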
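getCustomParameters in transformParameters.ts has some subtle behaviour worth pinning down: blank names are dropped, 'json' parameters are parsed (falling back to the raw string when parsing fails), and the literal string 'undefined' resets the key. A small worked example with hypothetical assistant settings:

import type { Assistant } from '@renderer/types'

import { getCustomParameters } from '@renderer/aiCore/transformParameters'

// Hypothetical settings; only customParameters matters here.
const assistant = {
  settings: {
    customParameters: [
      { name: 'top_k', value: 40, type: 'number' },
      { name: 'response_format', value: '{"type":"json_object"}', type: 'json' },
      { name: 'seed', value: 'undefined', type: 'json' }, // literal 'undefined' -> key set to undefined
      { name: '   ', value: 'ignored', type: 'string' } // blank names are skipped
    ]
  }
} as unknown as Assistant

getCustomParameters(assistant)
// => { top_k: 40, response_format: { type: 'json_object' }, seed: undefined }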