mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: integrate @cherry-studio/ai-core and enhance AI SDK support
- Added @cherry-studio/ai-core as a workspace dependency in package.json for improved modularity. - Updated tsconfig to include paths for the new AI core package, enhancing type resolution. - Refactored aiCore package to use source files directly, improving build efficiency. - Introduced a new AiSdkToChunkAdapter for converting AI SDK streams to Cherry Studio chunk format. - Implemented a modernized AI provider interface in index_new.ts, allowing fallback to legacy implementations. - Enhanced parameter transformation logic for better integration with AI SDK features. - Updated ApiService to utilize the new AI provider, streamlining chat completion requests.
This commit is contained in:
parent
1c5a30cf49
commit
43d55b7e45
@ -73,6 +73,7 @@
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@cherry-studio/ai-core": "workspace:*",
|
||||
"@cherrystudio/embedjs": "^0.1.31",
|
||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
"name": "@cherry-studio/ai-core",
|
||||
"version": "1.0.0",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
"main": "src/index.ts",
|
||||
"types": "src/index.ts",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"dev": "tsc -w",
|
||||
@ -114,9 +114,9 @@
|
||||
],
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js",
|
||||
"require": "./dist/index.js"
|
||||
"types": "./src/index.ts",
|
||||
"import": "./src/index.ts",
|
||||
"require": "./src/index.ts"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,10 +19,10 @@
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
import { AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { isProviderSupported } from '../providers/registry'
|
||||
import { ApiClientFactory } from './ApiClientFactory'
|
||||
import { type ProviderId, type ProviderSettingsMap } from './types'
|
||||
import { UniversalAiSdkClient } from './UniversalAiSdkClient'
|
||||
@ -178,8 +178,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
async (finalModelId, transformedParams, streamTransforms) => {
|
||||
// 对于流式调用,需要直接调用 AI SDK 以支持流转换器
|
||||
const model = await ApiClientFactory.createClient(this.providerId, finalModelId, this.options)
|
||||
|
||||
return streamText({
|
||||
return await streamText({
|
||||
model,
|
||||
...transformedParams,
|
||||
experimental_transform: streamTransforms.length > 0 ? streamTransforms : undefined
|
||||
@ -196,7 +195,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => {
|
||||
return this.baseClient.generateText(finalModelId, transformedParams)
|
||||
return await this.baseClient.generateText(finalModelId, transformedParams)
|
||||
})
|
||||
}
|
||||
|
||||
@ -208,7 +207,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => {
|
||||
return this.baseClient.generateObject(finalModelId, transformedParams)
|
||||
return await this.baseClient.generateObject(finalModelId, transformedParams)
|
||||
})
|
||||
}
|
||||
|
||||
@ -221,7 +220,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
return this.executeWithPlugins('streamObject', modelId, params, async (finalModelId, transformedParams) => {
|
||||
return this.baseClient.streamObject(finalModelId, transformedParams)
|
||||
return await this.baseClient.streamObject(finalModelId, transformedParams)
|
||||
})
|
||||
}
|
||||
|
||||
@ -267,7 +266,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
): PluginEnabledAiClient<'openai-compatible'>
|
||||
|
||||
static create(providerId: string, options: any, plugins: AiPlugin[] = []): PluginEnabledAiClient {
|
||||
if (providerId in ({} as ProviderSettingsMap)) {
|
||||
if (isProviderSupported(providerId)) {
|
||||
return new PluginEnabledAiClient(providerId as ProviderId, options, plugins)
|
||||
} else {
|
||||
// 对于未知 provider,使用 openai-compatible
|
||||
|
||||
@ -1,38 +1,14 @@
|
||||
import { FetchFunction } from '@ai-sdk/provider-utils'
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
import type { ProviderSettingsMap } from '../providers/registry'
|
||||
|
||||
// ProviderSettings 是所有 Provider Settings 的联合类型
|
||||
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
|
||||
// 基础 Provider 配置类型(为了向后兼容和通用场景)
|
||||
export type BaseProviderSettings = {
|
||||
/**
|
||||
* API key for authentication
|
||||
*/
|
||||
apiKey?: string
|
||||
/**
|
||||
* Base URL for the API calls
|
||||
*/
|
||||
baseURL?: string
|
||||
/**
|
||||
* Custom headers to include in the requests
|
||||
*/
|
||||
headers?: Record<string, string>
|
||||
/**
|
||||
* Optional custom url query parameters to include in request urls
|
||||
*/
|
||||
queryParams?: Record<string, string>
|
||||
/**
|
||||
* Custom fetch implementation. You can use it as a middleware to intercept requests,
|
||||
* or to provide a custom fetch implementation for e.g. testing.
|
||||
*/
|
||||
fetch?: FetchFunction
|
||||
/**
|
||||
* Allow additional properties for provider-specific settings
|
||||
*/
|
||||
[key: string]: any
|
||||
}
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
|
||||
// 重新导出 ProviderSettingsMap 中的所有类型
|
||||
export type {
|
||||
|
||||
@ -39,9 +39,44 @@ export { aiProviderRegistry } from './providers/registry'
|
||||
|
||||
// ==================== 类型定义 ====================
|
||||
export type { ClientFactoryError } from './clients/ApiClientFactory'
|
||||
export type { BaseProviderSettings, ProviderSettings } from './clients/types'
|
||||
export type {
|
||||
GenerateObjectParams,
|
||||
GenerateTextParams,
|
||||
ProviderSettings,
|
||||
StreamObjectParams,
|
||||
StreamTextParams
|
||||
} from './clients/types'
|
||||
export type { ProviderConfig } from './providers/registry'
|
||||
export type { ProviderError } from './providers/types'
|
||||
export * as aiSdk from 'ai'
|
||||
|
||||
// ==================== AI SDK 常用类型导出 ====================
|
||||
// 直接导出 AI SDK 的常用类型,方便使用
|
||||
export type {
|
||||
CoreAssistantMessage,
|
||||
// 消息相关类型
|
||||
CoreMessage,
|
||||
CoreSystemMessage,
|
||||
CoreToolMessage,
|
||||
CoreUserMessage,
|
||||
// 通用类型
|
||||
FinishReason,
|
||||
GenerateObjectResult,
|
||||
// 生成相关类型
|
||||
GenerateTextResult,
|
||||
InvalidToolArgumentsError,
|
||||
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
|
||||
// 错误类型
|
||||
NoSuchToolError,
|
||||
StreamTextResult,
|
||||
// 流相关类型
|
||||
TextStreamPart,
|
||||
// 工具相关类型
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolExecutionError,
|
||||
ToolResult
|
||||
} from 'ai'
|
||||
|
||||
// 重新导出所有 Provider Settings 类型
|
||||
export type {
|
||||
|
||||
296
src/renderer/src/aiCore/AiSdkToChunkAdapter.ts
Normal file
296
src/renderer/src/aiCore/AiSdkToChunkAdapter.ts
Normal file
@ -0,0 +1,296 @@
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器
|
||||
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||
*/
|
||||
|
||||
import { TextStreamPart } from '@cherry-studio/ai-core'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
export interface CherryStudioChunk {
|
||||
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
|
||||
text?: string
|
||||
toolCall?: any
|
||||
toolResult?: any
|
||||
finishReason?: string
|
||||
usage?: any
|
||||
error?: any
|
||||
}
|
||||
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器类
|
||||
* 处理 fullStream 到 Cherry Studio chunk 的转换
|
||||
*/
|
||||
export class AiSdkToChunkAdapter {
|
||||
constructor(private onChunk: (chunk: Chunk) => void) {}
|
||||
|
||||
/**
|
||||
* 处理 AI SDK 流结果
|
||||
* @param aiSdkResult AI SDK 的流结果对象
|
||||
* @returns 最终的文本内容
|
||||
*/
|
||||
async processStream(aiSdkResult: any): Promise<string> {
|
||||
// 如果是流式且有 fullStream
|
||||
if (aiSdkResult.fullStream) {
|
||||
await this.readFullStream(aiSdkResult.fullStream)
|
||||
}
|
||||
|
||||
// 使用 streamResult.text 获取最终结果
|
||||
return await aiSdkResult.text
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取 fullStream 并转换为 Cherry Studio chunks
|
||||
* @param fullStream AI SDK 的 fullStream (ReadableStream)
|
||||
*/
|
||||
private async readFullStream(fullStream: ReadableStream<TextStreamPart<any>>) {
|
||||
const reader = fullStream.getReader()
|
||||
const final = {
|
||||
text: '',
|
||||
reasoning_content: ''
|
||||
}
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
|
||||
// 转换并发送 chunk
|
||||
this.convertAndEmitChunk(value, final)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
|
||||
* @param chunk AI SDK 的 chunk 数据
|
||||
*/
|
||||
private convertAndEmitChunk(chunk: TextStreamPart<any>, final: { text: string; reasoning_content: string }) {
|
||||
console.log('AI SDK chunk type:', chunk.type, chunk)
|
||||
switch (chunk.type) {
|
||||
// === 文本相关事件 ===
|
||||
case 'text-delta':
|
||||
final.text += chunk.textDelta || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: chunk.textDelta || ''
|
||||
})
|
||||
if (final.reasoning_content) {
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: final.reasoning_content || ''
|
||||
})
|
||||
final.reasoning_content = ''
|
||||
}
|
||||
break
|
||||
|
||||
// === 推理相关事件 ===
|
||||
case 'reasoning':
|
||||
final.reasoning_content += chunk.textDelta || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.textDelta || ''
|
||||
})
|
||||
break
|
||||
|
||||
case 'reasoning-signature':
|
||||
// 推理签名,可以映射到思考完成
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: chunk.signature || ''
|
||||
})
|
||||
break
|
||||
|
||||
case 'redacted-reasoning':
|
||||
// 被编辑的推理内容,也映射到思考
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.data || ''
|
||||
})
|
||||
break
|
||||
|
||||
// === 工具调用相关事件 ===
|
||||
case 'tool-call-streaming-start':
|
||||
// 开始流式工具调用
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
name: chunk.toolName,
|
||||
args: {}
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
case 'tool-call-delta':
|
||||
// 工具调用参数的增量更新
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
tool: {
|
||||
id: chunk.toolName,
|
||||
// TODO: serverId,serverName
|
||||
serverId: 'ai-sdk',
|
||||
serverName: 'AI SDK',
|
||||
name: chunk.toolName,
|
||||
description: '',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
title: chunk.toolName,
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
arguments: {},
|
||||
status: 'invoking',
|
||||
response: chunk.argsTextDelta,
|
||||
toolCallId: chunk.toolCallId
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
case 'tool-call':
|
||||
// 完整的工具调用
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
name: chunk.toolName,
|
||||
args: chunk.args
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
case 'tool-result':
|
||||
// 工具调用结果
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
tool: {
|
||||
id: chunk.toolName,
|
||||
// TODO: serverId,serverName
|
||||
serverId: 'ai-sdk',
|
||||
serverName: 'AI SDK',
|
||||
name: chunk.toolName,
|
||||
description: '',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
title: chunk.toolName,
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
arguments: chunk.args || {},
|
||||
status: 'done',
|
||||
response: chunk.result,
|
||||
toolCallId: chunk.toolCallId
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
// case 'step-start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
case 'step-finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoning_content || '',
|
||||
usage: {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
prompt_tokens: chunk.usage.promptTokens || 0,
|
||||
total_tokens: chunk.usage.totalTokens || 0
|
||||
},
|
||||
metrics: chunk.usage
|
||||
? {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
case 'finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoning_content || '',
|
||||
usage: {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
prompt_tokens: chunk.usage.promptTokens || 0,
|
||||
total_tokens: chunk.usage.totalTokens || 0
|
||||
},
|
||||
metrics: chunk.usage
|
||||
? {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
// === 源和文件相关事件 ===
|
||||
case 'source':
|
||||
// 源信息,可以映射到知识搜索完成
|
||||
this.onChunk({
|
||||
type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE,
|
||||
knowledge: [
|
||||
{
|
||||
id: Number(chunk.source.id) || Date.now(),
|
||||
content: chunk.source.title || '',
|
||||
sourceUrl: chunk.source.url || '',
|
||||
type: 'url'
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
case 'file':
|
||||
// 文件相关事件,可能是图片生成
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [chunk.base64]
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'error':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: {
|
||||
message: chunk.error || 'Unknown error'
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
default:
|
||||
// 其他类型的 chunk 可以忽略或记录日志
|
||||
console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default AiSdkToChunkAdapter
|
||||
230
src/renderer/src/aiCore/index_new.ts
Normal file
230
src/renderer/src/aiCore/index_new.ts
Normal file
@ -0,0 +1,230 @@
|
||||
/**
|
||||
* Cherry Studio AI Core - 新版本入口
|
||||
* 集成 @cherry-studio/ai-core 库的渐进式重构方案
|
||||
*
|
||||
* 融合方案:简化实现,专注于核心功能
|
||||
* 1. 优先使用新AI SDK
|
||||
* 2. 失败时fallback到原有实现
|
||||
* 3. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import {
|
||||
AiClient,
|
||||
AiCore,
|
||||
createClient,
|
||||
type OpenAICompatibleProviderSettings,
|
||||
type ProviderId
|
||||
} from '@cherry-studio/ai-core'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { RequestOptions } from '@renderer/types/sdk'
|
||||
|
||||
// 引入适配器
|
||||
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
|
||||
// 引入原有的AiProvider作为fallback
|
||||
import LegacyAiProvider from './index'
|
||||
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
// 引入参数转换模块
|
||||
import { buildStreamTextParams } from './transformParameters'
|
||||
|
||||
/**
|
||||
* 将现有 Provider 类型映射到 AI SDK 的 Provider ID
|
||||
* 根据 registry.ts 中的支持列表进行映射
|
||||
*/
|
||||
function mapProviderTypeToAiSdkId(providerType: string): string {
|
||||
// Cherry Studio Provider Type -> AI SDK Provider ID 映射表
|
||||
const typeMapping: Record<string, string> = {
|
||||
// 需要转换的映射
|
||||
grok: 'xai', // grok -> xai
|
||||
'azure-openai': 'azure', // azure-openai -> azure
|
||||
gemini: 'google' // gemini -> google
|
||||
}
|
||||
|
||||
return typeMapping[providerType]
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
*/
|
||||
function providerToAiSdkConfig(provider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: any
|
||||
} {
|
||||
console.log('provider', provider)
|
||||
// 1. 先映射 provider 类型到 AI SDK ID
|
||||
const mappedProviderId = mapProviderTypeToAiSdkId(provider.id)
|
||||
|
||||
// 2. 检查映射后的 provider ID 是否在 AI SDK 注册表中
|
||||
const isSupported = AiCore.isSupported(mappedProviderId)
|
||||
|
||||
console.log(`Provider mapping: ${provider.type} -> ${mappedProviderId}, supported: ${isSupported}`)
|
||||
|
||||
// 3. 如果映射的 provider 不支持,则使用 openai-compatible
|
||||
if (isSupported) {
|
||||
return {
|
||||
providerId: mappedProviderId as ProviderId,
|
||||
options: {
|
||||
apiKey: provider.apiKey
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.log(`Using openai-compatible fallback for provider: ${provider.type}`)
|
||||
const compatibleConfig: OpenAICompatibleProviderSettings = {
|
||||
name: provider.name || provider.type,
|
||||
apiKey: provider.apiKey,
|
||||
baseURL: provider.apiHost
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
options: compatibleConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否支持使用新的AI SDK
|
||||
*/
|
||||
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
// 目前支持主要的providers
|
||||
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai']
|
||||
|
||||
// 检查provider类型
|
||||
if (!supportedProviders.includes(provider.type)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否为图像生成模型(暂时不支持)
|
||||
if (model && isDedicatedImageGenerationModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private modernClient?: AiClient
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private provider: Provider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
this.legacyProvider = new LegacyAiProvider(provider)
|
||||
|
||||
const config = providerToAiSdkConfig(provider)
|
||||
this.modernClient = createClient(config.providerId, config.options)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// const model = params.assistant.model
|
||||
|
||||
// 检查是否应该使用现代化客户端
|
||||
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
|
||||
// try {
|
||||
return await this.modernCompletions(params, options)
|
||||
// } catch (error) {
|
||||
// console.warn('Modern client failed, falling back to legacy:', error)
|
||||
// fallback到原有实现
|
||||
// }
|
||||
// }
|
||||
|
||||
// 使用原有实现
|
||||
// return this.legacyProvider.completions(params, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
* 使用 AiSdkUtils 工具模块进行参数构建
|
||||
*/
|
||||
private async modernCompletions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
if (!this.modernClient || !params.assistant.model) {
|
||||
throw new Error('Modern client not available')
|
||||
}
|
||||
|
||||
console.log('Modern completions with params:', params, 'options:', options)
|
||||
|
||||
const model = params.assistant.model
|
||||
const assistant = params.assistant
|
||||
|
||||
// 检查 messages 类型并转换
|
||||
const messages = Array.isArray(params.messages) ? params.messages : []
|
||||
if (typeof params.messages === 'string') {
|
||||
console.warn('Messages is string, using empty array')
|
||||
}
|
||||
|
||||
// 使用 transformParameters 模块构建参数
|
||||
const aiSdkParams = await buildStreamTextParams(messages, assistant, model, {
|
||||
maxTokens: params.maxTokens,
|
||||
mcpTools: params.mcpTools
|
||||
})
|
||||
|
||||
console.log('Built AI SDK params:', aiSdkParams)
|
||||
const chunks: Chunk[] = []
|
||||
|
||||
try {
|
||||
if (params.streamOutput && params.onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
const adapter = new AiSdkToChunkAdapter(params.onChunk)
|
||||
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else if (params.streamOutput) {
|
||||
// 流式处理但没有 onChunk 回调
|
||||
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
|
||||
const finalText = await streamResult.text
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
// 非流式处理
|
||||
const result = await this.modernClient.generateText(model.id, aiSdkParams)
|
||||
|
||||
const cherryChunk: Chunk = {
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: result.text || ''
|
||||
}
|
||||
chunks.push(cherryChunk)
|
||||
|
||||
if (params.onChunk) {
|
||||
params.onChunk(cherryChunk)
|
||||
}
|
||||
|
||||
return {
|
||||
getText: () => result.text || ''
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Modern AI SDK error:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 代理其他方法到原有实现
|
||||
public async models() {
|
||||
return this.legacyProvider.models()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
return this.legacyProvider.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.legacyProvider.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.legacyProvider.getApiKey()
|
||||
}
|
||||
}
|
||||
|
||||
// 为了方便调试,导出一些工具函数
|
||||
export { isModernSdkSupported, providerToAiSdkConfig }
|
||||
269
src/renderer/src/aiCore/transformParameters.ts
Normal file
269
src/renderer/src/aiCore/transformParameters.ts
Normal file
@ -0,0 +1,269 @@
|
||||
/**
|
||||
* AI SDK 参数转换模块
|
||||
* 统一管理从各个 apiClient 提取的参数处理和转换功能
|
||||
*/
|
||||
|
||||
import type { StreamTextParams } from '@cherry-studio/ai-core'
|
||||
import { isNotSupportTemperatureAndTopP, isSupportedFlexServiceTier } from '@renderer/config/models'
|
||||
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
|
||||
/**
|
||||
* 获取温度参数
|
||||
*/
|
||||
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.temperature
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 TopP 参数
|
||||
*/
|
||||
export function getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取超时设置
|
||||
*/
|
||||
export function getTimeout(model: Model): number {
|
||||
if (isSupportedFlexServiceTier(model)) {
|
||||
return 15 * 1000 * 60
|
||||
}
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建系统提示词
|
||||
*/
|
||||
export async function buildSystemPromptWithTools(
|
||||
prompt: string,
|
||||
mcpTools?: MCPTool[],
|
||||
assistant?: Assistant
|
||||
): Promise<string> {
|
||||
return await buildSystemPrompt(prompt, mcpTools, assistant)
|
||||
}
|
||||
|
||||
// /**
|
||||
// * 转换 MCP 工具为 AI SDK 工具格式
|
||||
// * 注意:这里返回通用格式,实际使用时需要根据具体 provider 转换
|
||||
// TODO: 需要使用ai-sdk的mcp
|
||||
// */
|
||||
// export function convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Pick<StreamTextParams, 'tools'> {
|
||||
// return mcpTools.map((tool) => ({
|
||||
// type: 'function',
|
||||
// function: {
|
||||
// name: tool.id,
|
||||
// description: tool.description,
|
||||
// parameters: tool.inputSchema || {}
|
||||
// }
|
||||
// }))
|
||||
// }
|
||||
|
||||
/**
|
||||
* 提取文件内容
|
||||
*/
|
||||
export async function extractFileContent(message: Message): Promise<string> {
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
if (fileBlocks.length > 0) {
|
||||
const textFileBlocks = fileBlocks.filter(
|
||||
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
|
||||
)
|
||||
|
||||
if (textFileBlocks.length > 0) {
|
||||
let text = ''
|
||||
const divider = '\n\n---\n\n'
|
||||
|
||||
for (const fileBlock of textFileBlocks) {
|
||||
const file = fileBlock.file
|
||||
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
|
||||
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
|
||||
text = text + fileNameRow + fileContent + divider
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换消息为 AI SDK 参数格式
|
||||
* 基于 OpenAI 格式的通用转换,支持文本、图片和文件
|
||||
*/
|
||||
export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise<any> {
|
||||
const content = getMainTextContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
|
||||
// 简单消息(无文件无图片)
|
||||
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content
|
||||
}
|
||||
}
|
||||
|
||||
// 复杂消息(包含文件或图片)
|
||||
const parts: any[] = []
|
||||
|
||||
if (content) {
|
||||
parts.push({ type: 'text', text: content })
|
||||
}
|
||||
|
||||
// 处理图片(仅在支持视觉的模型中)
|
||||
if (isVisionModel) {
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (imageBlock.file) {
|
||||
try {
|
||||
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
|
||||
parts.push({
|
||||
type: 'image_url',
|
||||
image_url: { url: image.data }
|
||||
})
|
||||
} catch (error) {
|
||||
console.warn('Failed to load image:', error)
|
||||
}
|
||||
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
|
||||
parts.push({
|
||||
type: 'image_url',
|
||||
image_url: { url: imageBlock.url }
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理文件
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
if (!file) continue
|
||||
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
try {
|
||||
const fileContent = await window.api.file.read(file.id + file.ext)
|
||||
parts.push({
|
||||
type: 'text',
|
||||
text: `${file.origin_name}\n${fileContent.trim()}`
|
||||
})
|
||||
} catch (error) {
|
||||
console.warn('Failed to read file:', error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts.length === 1 && parts[0].type === 'text' ? parts[0].text : parts
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换 Cherry Studio 消息数组为 AI SDK 消息数组
|
||||
*/
|
||||
export async function convertMessagesToSdkMessages(
|
||||
messages: Message[],
|
||||
model: Model
|
||||
): Promise<StreamTextParams['messages']> {
|
||||
const sdkMessages: StreamTextParams['messages'] = []
|
||||
const isVision = model.id.includes('vision') || model.id.includes('gpt-4') // 简单的视觉模型检测
|
||||
|
||||
for (const message of messages) {
|
||||
const sdkMessage = await convertMessageToSdkParam(message, isVision)
|
||||
sdkMessages.push(sdkMessage)
|
||||
}
|
||||
|
||||
return sdkMessages
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 流式参数
|
||||
* 这是主要的参数构建函数,整合所有转换逻辑
|
||||
*/
|
||||
export async function buildStreamTextParams(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
options: {
|
||||
maxTokens?: number
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
} = {}
|
||||
): Promise<StreamTextParams> {
|
||||
const { maxTokens, mcpTools, enableTools = false } = options
|
||||
|
||||
// 转换消息
|
||||
const sdkMessages = await convertMessagesToSdkMessages(messages, model)
|
||||
|
||||
// 构建系统提示
|
||||
let systemPrompt = assistant.prompt || ''
|
||||
if (mcpTools && mcpTools.length > 0) {
|
||||
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
|
||||
}
|
||||
|
||||
// 构建基础参数
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
maxTokens: maxTokens || 1000,
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
system: systemPrompt || undefined,
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
// 添加工具(如果启用且有工具)
|
||||
if (enableTools && mcpTools && mcpTools.length > 0) {
|
||||
// TODO: 暂时注释掉工具支持,等类型问题解决后再启用
|
||||
// params.tools = convertMcpToolsToSdkTools(mcpTools)
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建非流式的 generateText 参数
|
||||
*/
|
||||
export async function buildGenerateTextParams(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
options: {
|
||||
maxTokens?: number
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
} = {}
|
||||
): Promise<any> {
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, model, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取自定义参数
|
||||
* 从 assistant 设置中提取自定义参数
|
||||
*/
|
||||
export function getCustomParameters(assistant: Assistant): Record<string, any> {
|
||||
return (
|
||||
assistant?.settings?.customParameters?.reduce((acc, param) => {
|
||||
if (!param.name?.trim()) {
|
||||
return acc
|
||||
}
|
||||
if (param.type === 'json') {
|
||||
const value = param.value as string
|
||||
if (value === 'undefined') {
|
||||
return { ...acc, [param.name]: undefined }
|
||||
}
|
||||
try {
|
||||
return { ...acc, [param.name]: JSON.parse(value) }
|
||||
} catch {
|
||||
return { ...acc, [param.name]: value }
|
||||
}
|
||||
}
|
||||
return {
|
||||
...acc,
|
||||
[param.name]: param.value
|
||||
}
|
||||
}, {}) || {}
|
||||
)
|
||||
}
|
||||
@ -37,7 +37,7 @@ import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils
|
||||
import { findLast, isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import AiProvider from '../aiCore'
|
||||
import store from '../store'
|
||||
import AiProviderNew from '../aiCore/index_new'
|
||||
import {
|
||||
getAssistantProvider,
|
||||
getAssistantSettings,
|
||||
@ -313,7 +313,7 @@ export async function fetchChatCompletion({
|
||||
console.log('fetchChatCompletion', messages, assistant)
|
||||
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const AI = new AiProvider(provider)
|
||||
const AI = new AiProviderNew(provider)
|
||||
|
||||
// Make sure that 'Clear Context' works for all scenarios including external tool and normal chat.
|
||||
messages = filterContextMessages(messages)
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
"src/renderer/src/**/*",
|
||||
"src/preload/*.d.ts",
|
||||
"local/src/renderer/**/*",
|
||||
"packages/shared/**/*"
|
||||
"packages/shared/**/*",
|
||||
"packages/aiCore/src/**/*"
|
||||
],
|
||||
"compilerOptions": {
|
||||
"composite": true,
|
||||
@ -14,7 +15,8 @@
|
||||
"paths": {
|
||||
"@renderer/*": ["src/renderer/src/*"],
|
||||
"@shared/*": ["packages/shared/*"],
|
||||
"@types": ["src/renderer/src/types/index.ts"]
|
||||
"@types": ["src/renderer/src/types/index.ts"],
|
||||
"@cherry-studio/ai-core": ["packages/aiCore/src/"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -960,7 +960,7 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@cherry-studio/ai-core@workspace:packages/aiCore":
|
||||
"@cherry-studio/ai-core@workspace:*, @cherry-studio/ai-core@workspace:packages/aiCore":
|
||||
version: 0.0.0-use.local
|
||||
resolution: "@cherry-studio/ai-core@workspace:packages/aiCore"
|
||||
dependencies:
|
||||
@ -6392,6 +6392,7 @@ __metadata:
|
||||
"@agentic/tavily": "npm:^7.3.3"
|
||||
"@ant-design/v5-patch-for-react-19": "npm:^1.0.3"
|
||||
"@anthropic-ai/sdk": "npm:^0.41.0"
|
||||
"@cherry-studio/ai-core": "workspace:*"
|
||||
"@cherrystudio/embedjs": "npm:^0.1.31"
|
||||
"@cherrystudio/embedjs-libsql": "npm:^0.1.31"
|
||||
"@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user