feat: integrate @cherry-studio/ai-core and enhance AI SDK support

- Added @cherry-studio/ai-core as a workspace dependency in package.json for improved modularity.
- Updated tsconfig to include paths for the new AI core package, enhancing type resolution.
- Refactored aiCore package to use source files directly, improving build efficiency.
- Introduced a new AiSdkToChunkAdapter for converting AI SDK streams to Cherry Studio chunk format.
- Implemented a modernized AI provider interface in index_new.ts, allowing fallback to legacy implementations.
- Enhanced parameter transformation logic for better integration with AI SDK features.
- Updated ApiService to utilize the new AI provider, streamlining chat completion requests.
This commit is contained in:
MyPrototypeWhat 2025-06-19 18:55:59 +08:00
parent 1c5a30cf49
commit 43d55b7e45
11 changed files with 856 additions and 47 deletions

View File

@ -73,6 +73,7 @@
"@agentic/tavily": "^7.3.3",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@cherry-studio/ai-core": "workspace:*",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",

View File

@ -2,8 +2,8 @@
"name": "@cherry-studio/ai-core",
"version": "1.0.0",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"main": "src/index.ts",
"types": "src/index.ts",
"scripts": {
"build": "tsc",
"dev": "tsc -w",
@ -114,9 +114,9 @@
],
"exports": {
".": {
"types": "./dist/index.d.ts",
"import": "./dist/index.js",
"require": "./dist/index.js"
"types": "./src/index.ts",
"import": "./src/index.ts",
"require": "./src/index.ts"
}
}
}

View File

@ -19,10 +19,10 @@
* })
* ```
*/
import { generateObject, generateText, streamObject, streamText } from 'ai'
import { AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry'
import { ApiClientFactory } from './ApiClientFactory'
import { type ProviderId, type ProviderSettingsMap } from './types'
import { UniversalAiSdkClient } from './UniversalAiSdkClient'
@ -178,8 +178,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
async (finalModelId, transformedParams, streamTransforms) => {
// 对于流式调用,需要直接调用 AI SDK 以支持流转换器
const model = await ApiClientFactory.createClient(this.providerId, finalModelId, this.options)
return streamText({
return await streamText({
model,
...transformedParams,
experimental_transform: streamTransforms.length > 0 ? streamTransforms : undefined
@ -196,7 +195,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>> {
return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => {
return this.baseClient.generateText(finalModelId, transformedParams)
return await this.baseClient.generateText(finalModelId, transformedParams)
})
}
@ -208,7 +207,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>> {
return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => {
return this.baseClient.generateObject(finalModelId, transformedParams)
return await this.baseClient.generateObject(finalModelId, transformedParams)
})
}
@ -221,7 +220,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>> {
return this.executeWithPlugins('streamObject', modelId, params, async (finalModelId, transformedParams) => {
return this.baseClient.streamObject(finalModelId, transformedParams)
return await this.baseClient.streamObject(finalModelId, transformedParams)
})
}
@ -267,7 +266,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
): PluginEnabledAiClient<'openai-compatible'>
static create(providerId: string, options: any, plugins: AiPlugin[] = []): PluginEnabledAiClient {
if (providerId in ({} as ProviderSettingsMap)) {
if (isProviderSupported(providerId)) {
return new PluginEnabledAiClient(providerId as ProviderId, options, plugins)
} else {
// 对于未知 provider使用 openai-compatible

View File

@ -1,38 +1,14 @@
import { FetchFunction } from '@ai-sdk/provider-utils'
import { generateObject, generateText, streamObject, streamText } from 'ai'
import type { ProviderSettingsMap } from '../providers/registry'
// ProviderSettings 是所有 Provider Settings 的联合类型
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
// 基础 Provider 配置类型(为了向后兼容和通用场景)
export type BaseProviderSettings = {
/**
* API key for authentication
*/
apiKey?: string
/**
* Base URL for the API calls
*/
baseURL?: string
/**
* Custom headers to include in the requests
*/
headers?: Record<string, string>
/**
* Optional custom url query parameters to include in request urls
*/
queryParams?: Record<string, string>
/**
* Custom fetch implementation. You can use it as a middleware to intercept requests,
* or to provide a custom fetch implementation for e.g. testing.
*/
fetch?: FetchFunction
/**
* Allow additional properties for provider-specific settings
*/
[key: string]: any
}
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
// 重新导出 ProviderSettingsMap 中的所有类型
export type {

View File

@ -39,9 +39,44 @@ export { aiProviderRegistry } from './providers/registry'
// ==================== 类型定义 ====================
export type { ClientFactoryError } from './clients/ApiClientFactory'
export type { BaseProviderSettings, ProviderSettings } from './clients/types'
export type {
GenerateObjectParams,
GenerateTextParams,
ProviderSettings,
StreamObjectParams,
StreamTextParams
} from './clients/types'
export type { ProviderConfig } from './providers/registry'
export type { ProviderError } from './providers/types'
export * as aiSdk from 'ai'
// ==================== AI SDK 常用类型导出 ====================
// 直接导出 AI SDK 的常用类型,方便使用
export type {
CoreAssistantMessage,
// 消息相关类型
CoreMessage,
CoreSystemMessage,
CoreToolMessage,
CoreUserMessage,
// 通用类型
FinishReason,
GenerateObjectResult,
// 生成相关类型
GenerateTextResult,
InvalidToolArgumentsError,
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
// 错误类型
NoSuchToolError,
StreamTextResult,
// 流相关类型
TextStreamPart,
// 工具相关类型
Tool,
ToolCall,
ToolExecutionError,
ToolResult
} from 'ai'
// 重新导出所有 Provider Settings 类型
export type {

View File

@ -0,0 +1,296 @@
/**
 * AiSdkToChunkAdapter
 * Converts AI SDK `fullStream` parts into Cherry Studio chunk format.
 */
import { TextStreamPart } from '@cherry-studio/ai-core'
import { Chunk, ChunkType } from '@renderer/types/chunk'
export interface CherryStudioChunk {
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
text?: string
toolCall?: any
toolResult?: any
finishReason?: string
usage?: any
error?: any
}
/**
* AI SDK Cherry Studio Chunk
* fullStream Cherry Studio chunk
*/
export class AiSdkToChunkAdapter {
constructor(private onChunk: (chunk: Chunk) => void) {}
/**
* AI SDK
* @param aiSdkResult AI SDK
* @returns
*/
async processStream(aiSdkResult: any): Promise<string> {
// 如果是流式且有 fullStream
if (aiSdkResult.fullStream) {
await this.readFullStream(aiSdkResult.fullStream)
}
// 使用 streamResult.text 获取最终结果
return await aiSdkResult.text
}
/**
* fullStream Cherry Studio chunks
* @param fullStream AI SDK fullStream (ReadableStream)
*/
private async readFullStream(fullStream: ReadableStream<TextStreamPart<any>>) {
const reader = fullStream.getReader()
const final = {
text: '',
reasoning_content: ''
}
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
// 转换并发送 chunk
this.convertAndEmitChunk(value, final)
}
} finally {
reader.releaseLock()
}
}
/**
* AI SDK chunk Cherry Studio chunk
* @param chunk AI SDK chunk
*/
private convertAndEmitChunk(chunk: TextStreamPart<any>, final: { text: string; reasoning_content: string }) {
console.log('AI SDK chunk type:', chunk.type, chunk)
switch (chunk.type) {
// === 文本相关事件 ===
case 'text-delta':
final.text += chunk.textDelta || ''
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: chunk.textDelta || ''
})
if (final.reasoning_content) {
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: final.reasoning_content || ''
})
final.reasoning_content = ''
}
break
// === 推理相关事件 ===
case 'reasoning':
final.reasoning_content += chunk.textDelta || ''
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.textDelta || ''
})
break
case 'reasoning-signature':
// 推理签名,可以映射到思考完成
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: chunk.signature || ''
})
break
case 'redacted-reasoning':
// 被编辑的推理内容,也映射到思考
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.data || ''
})
break
// === 工具调用相关事件 ===
case 'tool-call-streaming-start':
// 开始流式工具调用
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: {}
}
]
})
break
case 'tool-call-delta':
// 工具调用参数的增量更新
this.onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: {},
status: 'invoking',
response: chunk.argsTextDelta,
toolCallId: chunk.toolCallId
}
]
})
break
case 'tool-call':
// 完整的工具调用
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: chunk.args
}
]
})
break
case 'tool-result':
// 工具调用结果
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: chunk.args || {},
status: 'done',
response: chunk.result,
toolCallId: chunk.toolCallId
}
]
})
break
// === 步骤相关事件 ===
// case 'step-start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
case 'step-finish':
this.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
},
metrics: chunk.usage
? {
completion_tokens: chunk.usage.completionTokens || 0,
time_completion_millsec: 0
}
: undefined
}
})
break
case 'finish':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
})
this.onChunk({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
},
metrics: chunk.usage
? {
completion_tokens: chunk.usage.completionTokens || 0,
time_completion_millsec: 0
}
: undefined
}
})
break
// === 源和文件相关事件 ===
case 'source':
// 源信息,可以映射到知识搜索完成
this.onChunk({
type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE,
knowledge: [
{
id: Number(chunk.source.id) || Date.now(),
content: chunk.source.title || '',
sourceUrl: chunk.source.url || '',
type: 'url'
}
]
})
break
case 'file':
// 文件相关事件,可能是图片生成
this.onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [chunk.base64]
}
})
break
case 'error':
this.onChunk({
type: ChunkType.ERROR,
error: {
message: chunk.error || 'Unknown error'
}
})
break
default:
// 其他类型的 chunk 可以忽略或记录日志
console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
}
}
}
export default AiSdkToChunkAdapter

View File

@ -0,0 +1,230 @@
/**
 * Cherry Studio AI Core - modern provider entry point
 * Built on @cherry-studio/ai-core.
 *
 * Goals:
 * 1. Use the Vercel AI SDK for chat completions
 * 2. Fall back to the legacy implementation when a provider/model is unsupported
 * 3. Keep the public interface compatible with the legacy AiProvider
 */
import {
AiClient,
AiCore,
createClient,
type OpenAICompatibleProviderSettings,
type ProviderId
} from '@cherry-studio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { RequestOptions } from '@renderer/types/sdk'
// 引入适配器
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
// 引入原有的AiProvider作为fallback
import LegacyAiProvider from './index'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
// 引入参数转换模块
import { buildStreamTextParams } from './transformParameters'
/**
* Provider AI SDK Provider ID
* registry.ts
*/
function mapProviderTypeToAiSdkId(providerType: string): string {
// Cherry Studio Provider Type -> AI SDK Provider ID 映射表
const typeMapping: Record<string, string> = {
// 需要转换的映射
grok: 'xai', // grok -> xai
'azure-openai': 'azure', // azure-openai -> azure
gemini: 'google' // gemini -> google
}
return typeMapping[providerType]
}
/**
* Provider AI SDK
*/
function providerToAiSdkConfig(provider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: any
} {
console.log('provider', provider)
// 1. 先映射 provider 类型到 AI SDK ID
const mappedProviderId = mapProviderTypeToAiSdkId(provider.id)
// 2. 检查映射后的 provider ID 是否在 AI SDK 注册表中
const isSupported = AiCore.isSupported(mappedProviderId)
console.log(`Provider mapping: ${provider.type} -> ${mappedProviderId}, supported: ${isSupported}`)
// 3. 如果映射的 provider 不支持,则使用 openai-compatible
if (isSupported) {
return {
providerId: mappedProviderId as ProviderId,
options: {
apiKey: provider.apiKey
}
}
} else {
console.log(`Using openai-compatible fallback for provider: ${provider.type}`)
const compatibleConfig: OpenAICompatibleProviderSettings = {
name: provider.name || provider.type,
apiKey: provider.apiKey,
baseURL: provider.apiHost
}
return {
providerId: 'openai-compatible',
options: compatibleConfig
}
}
}
/**
* 使AI SDK
*/
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// 目前支持主要的providers
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai']
// 检查provider类型
if (!supportedProviders.includes(provider.type)) {
return false
}
// 检查是否为图像生成模型(暂时不支持)
if (model && isDedicatedImageGenerationModel(model)) {
return false
}
return true
}
export default class ModernAiProvider {
private modernClient?: AiClient
private legacyProvider: LegacyAiProvider
private provider: Provider
constructor(provider: Provider) {
this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider)
const config = providerToAiSdkConfig(provider)
this.modernClient = createClient(config.providerId, config.options)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// const model = params.assistant.model
// 检查是否应该使用现代化客户端
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
// try {
return await this.modernCompletions(params, options)
// } catch (error) {
// console.warn('Modern client failed, falling back to legacy:', error)
// fallback到原有实现
// }
// }
// 使用原有实现
// return this.legacyProvider.completions(params, options)
}
/**
* 使AI SDK的completions实现
* 使 AiSdkUtils
*/
private async modernCompletions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
if (!this.modernClient || !params.assistant.model) {
throw new Error('Modern client not available')
}
console.log('Modern completions with params:', params, 'options:', options)
const model = params.assistant.model
const assistant = params.assistant
// 检查 messages 类型并转换
const messages = Array.isArray(params.messages) ? params.messages : []
if (typeof params.messages === 'string') {
console.warn('Messages is string, using empty array')
}
// 使用 transformParameters 模块构建参数
const aiSdkParams = await buildStreamTextParams(messages, assistant, model, {
maxTokens: params.maxTokens,
mcpTools: params.mcpTools
})
console.log('Built AI SDK params:', aiSdkParams)
const chunks: Chunk[] = []
try {
if (params.streamOutput && params.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(params.onChunk)
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
const finalText = await adapter.processStream(streamResult)
return {
getText: () => finalText
}
} else if (params.streamOutput) {
// 流式处理但没有 onChunk 回调
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
const finalText = await streamResult.text
return {
getText: () => finalText
}
} else {
// 非流式处理
const result = await this.modernClient.generateText(model.id, aiSdkParams)
const cherryChunk: Chunk = {
type: ChunkType.TEXT_COMPLETE,
text: result.text || ''
}
chunks.push(cherryChunk)
if (params.onChunk) {
params.onChunk(cherryChunk)
}
return {
getText: () => result.text || ''
}
}
} catch (error) {
console.error('Modern AI SDK error:', error)
throw error
}
}
// 代理其他方法到原有实现
public async models() {
return this.legacyProvider.models()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.legacyProvider.getEmbeddingDimensions(model)
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.legacyProvider.generateImage(params)
}
public getBaseURL(): string {
return this.legacyProvider.getBaseURL()
}
public getApiKey(): string {
return this.legacyProvider.getApiKey()
}
}
// 为了方便调试,导出一些工具函数
export { isModernSdkSupported, providerToAiSdkConfig }

View File

@ -0,0 +1,269 @@
/**
 * Parameter transformation helpers for the AI SDK integration.
 * Converts Cherry Studio assistants/messages into AI SDK call parameters.
 */
import type { StreamTextParams } from '@cherry-studio/ai-core'
import { isNotSupportTemperatureAndTopP, isSupportedFlexServiceTier } from '@renderer/config/models'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
/**
*
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.temperature
}
/**
* TopP
*/
export function getTopP(assistant: Assistant, model: Model): number | undefined {
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
}
/**
*
*/
export function getTimeout(model: Model): number {
if (isSupportedFlexServiceTier(model)) {
return 15 * 1000 * 60
}
return defaultTimeout
}
/**
*
*/
export async function buildSystemPromptWithTools(
prompt: string,
mcpTools?: MCPTool[],
assistant?: Assistant
): Promise<string> {
return await buildSystemPrompt(prompt, mcpTools, assistant)
}
// /**
// * 转换 MCP 工具为 AI SDK 工具格式
// * 注意:这里返回通用格式,实际使用时需要根据具体 provider 转换
// TODO: 需要使用ai-sdk的mcp
// */
// export function convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Pick<StreamTextParams, 'tools'> {
// return mcpTools.map((tool) => ({
// type: 'function',
// function: {
// name: tool.id,
// description: tool.description,
// parameters: tool.inputSchema || {}
// }
// }))
// }
/**
*
*/
export async function extractFileContent(message: Message): Promise<string> {
const fileBlocks = findFileBlocks(message)
if (fileBlocks.length > 0) {
const textFileBlocks = fileBlocks.filter(
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
)
if (textFileBlocks.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const fileBlock of textFileBlocks) {
const file = fileBlock.file
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return text
}
}
return ''
}
/**
* AI SDK
* OpenAI
*/
export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise<any> {
const content = getMainTextContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
// 简单消息(无文件无图片)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
return {
role: message.role === 'system' ? 'user' : message.role,
content
}
}
// 复杂消息(包含文件或图片)
const parts: any[] = []
if (content) {
parts.push({ type: 'text', text: content })
}
// 处理图片(仅在支持视觉的模型中)
if (isVisionModel) {
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
try {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({
type: 'image_url',
image_url: { url: image.data }
})
} catch (error) {
console.warn('Failed to load image:', error)
}
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({
type: 'image_url',
image_url: { url: imageBlock.url }
})
}
}
}
// 处理文件
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) continue
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
try {
const fileContent = await window.api.file.read(file.id + file.ext)
parts.push({
type: 'text',
text: `${file.origin_name}\n${fileContent.trim()}`
})
} catch (error) {
console.warn('Failed to read file:', error)
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts.length === 1 && parts[0].type === 'text' ? parts[0].text : parts
}
}
/**
* Cherry Studio AI SDK
*/
export async function convertMessagesToSdkMessages(
messages: Message[],
model: Model
): Promise<StreamTextParams['messages']> {
const sdkMessages: StreamTextParams['messages'] = []
const isVision = model.id.includes('vision') || model.id.includes('gpt-4') // 简单的视觉模型检测
for (const message of messages) {
const sdkMessage = await convertMessageToSdkParam(message, isVision)
sdkMessages.push(sdkMessage)
}
return sdkMessages
}
/**
* AI SDK
*
*/
export async function buildStreamTextParams(
messages: Message[],
assistant: Assistant,
model: Model,
options: {
maxTokens?: number
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
): Promise<StreamTextParams> {
const { maxTokens, mcpTools, enableTools = false } = options
// 转换消息
const sdkMessages = await convertMessagesToSdkMessages(messages, model)
// 构建系统提示
let systemPrompt = assistant.prompt || ''
if (mcpTools && mcpTools.length > 0) {
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
}
// 构建基础参数
const params: StreamTextParams = {
messages: sdkMessages,
maxTokens: maxTokens || 1000,
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
system: systemPrompt || undefined,
...getCustomParameters(assistant)
}
// 添加工具(如果启用且有工具)
if (enableTools && mcpTools && mcpTools.length > 0) {
// TODO: 暂时注释掉工具支持,等类型问题解决后再启用
// params.tools = convertMcpToolsToSdkTools(mcpTools)
}
return params
}
/**
* generateText
*/
export async function buildGenerateTextParams(
messages: Message[],
assistant: Assistant,
model: Model,
options: {
maxTokens?: number
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
): Promise<any> {
// 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, model, options)
}
/**
*
* assistant
*/
export function getCustomParameters(assistant: Assistant): Record<string, any> {
return (
assistant?.settings?.customParameters?.reduce((acc, param) => {
if (!param.name?.trim()) {
return acc
}
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {
return { ...acc, [param.name]: undefined }
}
try {
return { ...acc, [param.name]: JSON.parse(value) }
} catch {
return { ...acc, [param.name]: value }
}
}
return {
...acc,
[param.name]: param.value
}
}, {}) || {}
)
}

View File

@ -37,7 +37,7 @@ import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils
import { findLast, isEmpty, takeRight } from 'lodash'
import AiProvider from '../aiCore'
import store from '../store'
import AiProviderNew from '../aiCore/index_new'
import {
getAssistantProvider,
getAssistantSettings,
@ -313,7 +313,7 @@ export async function fetchChatCompletion({
console.log('fetchChatCompletion', messages, assistant)
const provider = getAssistantProvider(assistant)
const AI = new AiProvider(provider)
const AI = new AiProviderNew(provider)
// Make sure that 'Clear Context' works for all scenarios including external tool and normal chat.
messages = filterContextMessages(messages)

View File

@ -4,7 +4,8 @@
"src/renderer/src/**/*",
"src/preload/*.d.ts",
"local/src/renderer/**/*",
"packages/shared/**/*"
"packages/shared/**/*",
"packages/aiCore/src/**/*"
],
"compilerOptions": {
"composite": true,
@ -14,7 +15,8 @@
"paths": {
"@renderer/*": ["src/renderer/src/*"],
"@shared/*": ["packages/shared/*"],
"@types": ["src/renderer/src/types/index.ts"]
"@types": ["src/renderer/src/types/index.ts"],
"@cherry-studio/ai-core": ["packages/aiCore/src/index.ts"]
}
}
}

View File

@ -960,7 +960,7 @@ __metadata:
languageName: node
linkType: hard
"@cherry-studio/ai-core@workspace:packages/aiCore":
"@cherry-studio/ai-core@workspace:*, @cherry-studio/ai-core@workspace:packages/aiCore":
version: 0.0.0-use.local
resolution: "@cherry-studio/ai-core@workspace:packages/aiCore"
dependencies:
@ -6392,6 +6392,7 @@ __metadata:
"@agentic/tavily": "npm:^7.3.3"
"@ant-design/v5-patch-for-react-19": "npm:^1.0.3"
"@anthropic-ai/sdk": "npm:^0.41.0"
"@cherry-studio/ai-core": "workspace:*"
"@cherrystudio/embedjs": "npm:^0.1.31"
"@cherrystudio/embedjs-libsql": "npm:^0.1.31"
"@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"