mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-31 00:10:22 +08:00
fix: migrate to v5-patch2
This commit is contained in:
parent
def685921c
commit
1ea8266280
@ -272,7 +272,6 @@
|
||||
"winston": "^3.17.0",
|
||||
"winston-daily-rotate-file": "^5.0.0",
|
||||
"word-extractor": "^1.0.4",
|
||||
"zhipu-ai-provider": "0.2.0-beta.1",
|
||||
"zipread": "^1.3.3",
|
||||
"zod": "^3.25.74"
|
||||
},
|
||||
|
||||
@ -37,19 +37,6 @@ const configHandlers: {
|
||||
resourceName: azureProvider.resourceName
|
||||
})
|
||||
}
|
||||
// 'google-vertex': (builder, provider) => {
|
||||
// const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
|
||||
// const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
|
||||
// vertexBuilder
|
||||
// .withGoogleVertexConfig({
|
||||
// project: vertexProvider.project,
|
||||
// location: vertexProvider.location
|
||||
// })
|
||||
// .withGoogleCredentials({
|
||||
// clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
|
||||
// privateKey: vertexProvider.googleCredentials?.privateKey || ''
|
||||
// })
|
||||
// }
|
||||
}
|
||||
|
||||
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||
@ -117,42 +104,6 @@ export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Google 特定配置
|
||||
*/
|
||||
// withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
|
||||
// withGoogleVertexConfig(options: any): any {
|
||||
// if (this.providerId === 'google-vertex') {
|
||||
// const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
||||
// if (options.project) {
|
||||
// googleConfig.project = options.project
|
||||
// }
|
||||
// if (options.location) {
|
||||
// googleConfig.location = options.location
|
||||
// if (options.location === 'global') {
|
||||
// googleConfig.baseURL = 'https://aiplatform.googleapis.com'
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// return this
|
||||
// }
|
||||
|
||||
withGoogleCredentials(credentials: {
|
||||
clientEmail: string
|
||||
privateKey: string
|
||||
}): T extends 'google-vertex' ? this : never
|
||||
withGoogleCredentials(): any {
|
||||
// withGoogleCredentials(credentials: any): any {
|
||||
// if (this.providerId === 'google-vertex') {
|
||||
// const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
||||
// vertexConfig.googleCredentials = {
|
||||
// clientEmail: credentials.clientEmail,
|
||||
// privateKey: formatPrivateKey(credentials.privateKey)
|
||||
// }
|
||||
// }
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置自定义参数
|
||||
*/
|
||||
|
||||
@ -16,6 +16,9 @@ import { type ProviderConfig } from './types'
|
||||
export class AiProviderRegistry {
|
||||
private static instance: AiProviderRegistry
|
||||
private registry = new Map<string, ProviderConfig>()
|
||||
// 动态注册扩展
|
||||
private dynamicMappings = new Map<string, string>()
|
||||
private dynamicProviders = new Set<string>()
|
||||
|
||||
private constructor() {
|
||||
this.initializeProviders()
|
||||
@ -30,11 +33,10 @@ export class AiProviderRegistry {
|
||||
|
||||
/**
|
||||
* 初始化所有支持的 Providers
|
||||
* 基于 AI SDK 官方文档: https://ai-sdk.dev/providers/ai-sdk-providers
|
||||
* 基于 AI SDK 官方文档: https://v5.ai-sdk.dev/providers/ai-sdk-providers
|
||||
*/
|
||||
private initializeProviders(): void {
|
||||
const providers: ProviderConfig[] = [
|
||||
// 官方 AI SDK Providers (19个)
|
||||
{
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
@ -67,20 +69,6 @@ export class AiProviderRegistry {
|
||||
creator: createGoogleGenerativeAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
// {
|
||||
// id: 'google-vertex',
|
||||
// name: 'Google Vertex AI',
|
||||
// import: () => import('@ai-sdk/google-vertex/edge'),
|
||||
// creatorFunctionName: 'createVertex',
|
||||
// supportsImageGeneration: true
|
||||
// },
|
||||
// {
|
||||
// id: 'mistral',
|
||||
// name: 'Mistral AI',
|
||||
// import: () => import('@ai-sdk/mistral'),
|
||||
// creatorFunctionName: 'createMistral',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
{
|
||||
id: 'xai',
|
||||
name: 'xAI (Grok)',
|
||||
@ -93,111 +81,12 @@ export class AiProviderRegistry {
|
||||
creator: createAzure,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
// {
|
||||
// id: 'bedrock',
|
||||
// name: 'Amazon Bedrock',
|
||||
// import: () => import('@ai-sdk/amazon-bedrock'),
|
||||
// creatorFunctionName: 'createAmazonBedrock',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'cohere',
|
||||
// name: 'Cohere',
|
||||
// import: () => import('@ai-sdk/cohere'),
|
||||
// creatorFunctionName: 'createCohere',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'groq',
|
||||
// name: 'Groq',
|
||||
// import: () => import('@ai-sdk/groq'),
|
||||
// creatorFunctionName: 'createGroq',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'together',
|
||||
// name: 'Together.ai',
|
||||
// import: () => import('@ai-sdk/togetherai'),
|
||||
// creatorFunctionName: 'createTogetherAI',
|
||||
// supportsImageGeneration: true
|
||||
// },
|
||||
// {
|
||||
// id: 'fireworks',
|
||||
// name: 'Fireworks',
|
||||
// import: () => import('@ai-sdk/fireworks'),
|
||||
// creatorFunctionName: 'createFireworks',
|
||||
// supportsImageGeneration: true
|
||||
// },
|
||||
{
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
creator: createDeepSeek,
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
// {
|
||||
// id: 'cerebras',
|
||||
// name: 'Cerebras',
|
||||
// import: () => import('@ai-sdk/cerebras'),
|
||||
// creatorFunctionName: 'createCerebras',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'deepinfra',
|
||||
// name: 'DeepInfra',
|
||||
// import: () => import('@ai-sdk/deepinfra'),
|
||||
// creatorFunctionName: 'createDeepInfra',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'replicate',
|
||||
// name: 'Replicate',
|
||||
// import: () => import('@ai-sdk/replicate'),
|
||||
// creatorFunctionName: 'createReplicate',
|
||||
// supportsImageGeneration: true
|
||||
// },
|
||||
// {
|
||||
// id: 'perplexity',
|
||||
// name: 'Perplexity',
|
||||
// import: () => import('@ai-sdk/perplexity'),
|
||||
// creatorFunctionName: 'createPerplexity',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'fal',
|
||||
// name: 'Fal AI',
|
||||
// import: () => import('@ai-sdk/fal'),
|
||||
// creatorFunctionName: 'createFal',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'vercel',
|
||||
// name: 'Vercel',
|
||||
// import: () => import('@ai-sdk/vercel'),
|
||||
// creatorFunctionName: 'createVercel'
|
||||
// },
|
||||
|
||||
// 社区 Providers (5个)
|
||||
// {
|
||||
// id: 'ollama',
|
||||
// name: 'Ollama',
|
||||
// import: () => import('ollama-ai-provider'),
|
||||
// creatorFunctionName: 'createOllama',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'anthropic-vertex',
|
||||
// name: 'Anthropic Vertex AI',
|
||||
// import: () => import('anthropic-vertex-ai'),
|
||||
// creatorFunctionName: 'createAnthropicVertex',
|
||||
// supportsImageGeneration: false
|
||||
// },
|
||||
// {
|
||||
// id: 'openrouter',
|
||||
// name: 'OpenRouter',
|
||||
// import: () => import('@openrouter/ai-sdk-provider'),
|
||||
// creatorFunctionName: 'createOpenRouter',
|
||||
// supportsImageGeneration: false
|
||||
// }
|
||||
]
|
||||
|
||||
providers.forEach((config) => {
|
||||
@ -243,28 +132,92 @@ export class AiProviderRegistry {
|
||||
this.registry.set(config.id, config)
|
||||
}
|
||||
|
||||
/**
|
||||
* 动态注册Provider并支持映射关系
|
||||
*/
|
||||
public registerDynamicProvider(
|
||||
config: ProviderConfig & {
|
||||
mappings?: Record<string, string>
|
||||
}
|
||||
): boolean {
|
||||
try {
|
||||
// 验证配置
|
||||
if (!config.id || config.id.trim() === '') {
|
||||
console.error('Provider ID cannot be empty')
|
||||
return false
|
||||
}
|
||||
|
||||
// 注册provider
|
||||
this.registerProvider(config)
|
||||
|
||||
// 记录为动态provider
|
||||
this.dynamicProviders.add(config.id)
|
||||
|
||||
// 添加映射关系(如果提供)
|
||||
if (config.mappings) {
|
||||
Object.entries(config.mappings).forEach(([key, value]) => {
|
||||
this.dynamicMappings.set(key, value)
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register provider ${config.id}:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册多个动态Providers
|
||||
*/
|
||||
public registerMultipleProviders(
|
||||
configs: (ProviderConfig & {
|
||||
mappings?: Record<string, string>
|
||||
})[]
|
||||
): number {
|
||||
let successCount = 0
|
||||
configs.forEach((config) => {
|
||||
if (this.registerDynamicProvider(config)) {
|
||||
successCount++
|
||||
}
|
||||
})
|
||||
return successCount
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取Provider映射(包括动态映射)
|
||||
*/
|
||||
public getProviderMapping(providerId: string): string | undefined {
|
||||
return this.dynamicMappings.get(providerId) || (this.dynamicProviders.has(providerId) ? providerId : undefined)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为动态注册的Provider
|
||||
*/
|
||||
public isDynamicProvider(providerId: string): boolean {
|
||||
return this.dynamicProviders.has(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有动态Provider映射
|
||||
*/
|
||||
public getAllDynamicMappings(): Record<string, string> {
|
||||
return Object.fromEntries(this.dynamicMappings)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有动态注册的Providers
|
||||
*/
|
||||
public getDynamicProviders(): string[] {
|
||||
return Array.from(this.dynamicProviders)
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理资源
|
||||
*/
|
||||
public cleanup(): void {
|
||||
this.registry.clear()
|
||||
}
|
||||
|
||||
// /**
|
||||
// * 获取兼容现有实现的注册表格式
|
||||
// */
|
||||
// public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
|
||||
// const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
|
||||
|
||||
// this.getAllProviders().forEach((provider) => {
|
||||
// compatibleRegistry[provider.id] = {
|
||||
// import: provider.import,
|
||||
// creatorFunctionName: provider.creatorFunctionName
|
||||
// }
|
||||
// })
|
||||
|
||||
// return compatibleRegistry
|
||||
// }
|
||||
}
|
||||
|
||||
// 导出单例实例
|
||||
@ -275,5 +228,16 @@ export const getProvider = (id: string) => aiProviderRegistry.getProvider(id)
|
||||
export const getAllProviders = () => aiProviderRegistry.getAllProviders()
|
||||
export const isProviderSupported = (id: string) => aiProviderRegistry.isSupported(id)
|
||||
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
|
||||
|
||||
// 动态注册相关便捷函数
|
||||
export const registerDynamicProvider = (config: ProviderConfig & { mappings?: Record<string, string> }) =>
|
||||
aiProviderRegistry.registerDynamicProvider(config)
|
||||
export const registerMultipleProviders = (configs: (ProviderConfig & { mappings?: Record<string, string> })[]) =>
|
||||
aiProviderRegistry.registerMultipleProviders(configs)
|
||||
export const getProviderMapping = (providerId: string) => aiProviderRegistry.getProviderMapping(providerId)
|
||||
export const isDynamicProvider = (providerId: string) => aiProviderRegistry.isDynamicProvider(providerId)
|
||||
export const getAllDynamicMappings = () => aiProviderRegistry.getAllDynamicMappings()
|
||||
export const getDynamicProviders = () => aiProviderRegistry.getDynamicProviders()
|
||||
|
||||
// 兼容现有实现的导出
|
||||
// export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
|
||||
|
||||
@ -6,11 +6,30 @@ import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||
|
||||
export interface ExtensibleProviderSettingsMap {
|
||||
// 基础的静态providers
|
||||
openai: OpenAIProviderSettings
|
||||
'openai-responses': OpenAIProviderSettings
|
||||
'openai-compatible': OpenAICompatibleProviderSettings
|
||||
anthropic: AnthropicProviderSettings
|
||||
google: GoogleGenerativeAIProviderSettings
|
||||
xai: XaiProviderSettings
|
||||
azure: AzureOpenAIProviderSettings
|
||||
deepseek: DeepSeekProviderSettings
|
||||
}
|
||||
|
||||
// 动态扩展的provider类型注册表
|
||||
export interface DynamicProviderRegistry {
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
// 合并基础和动态provider类型
|
||||
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
||||
|
||||
/**
|
||||
* Provider 相关核心类型定义
|
||||
* 只定义必要的接口,其他类型直接使用 AI SDK
|
||||
*/
|
||||
export type ProviderId = keyof ProviderSettingsMap & string
|
||||
|
||||
// Provider 配置接口 - 支持灵活的创建方式
|
||||
export interface ProviderConfig {
|
||||
@ -45,57 +64,23 @@ export class ProviderError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
// 类型安全的 Provider Settings 映射
|
||||
export type ProviderSettingsMap = {
|
||||
openai: OpenAIProviderSettings
|
||||
'openai-responses': OpenAIProviderSettings
|
||||
'openai-compatible': OpenAICompatibleProviderSettings
|
||||
// openrouter: OpenRouterProviderSettings
|
||||
anthropic: AnthropicProviderSettings
|
||||
google: GoogleGenerativeAIProviderSettings
|
||||
// 'google-vertex': GoogleVertexProviderSettings
|
||||
// mistral: MistralProviderSettings
|
||||
xai: XaiProviderSettings
|
||||
azure: AzureOpenAIProviderSettings
|
||||
// bedrock: AmazonBedrockProviderSettings
|
||||
// cohere: CohereProviderSettings
|
||||
// groq: GroqProviderSettings
|
||||
// together: TogetherAIProviderSettings
|
||||
// fireworks: FireworksProviderSettings
|
||||
deepseek: DeepSeekProviderSettings
|
||||
// cerebras: CerebrasProviderSettings
|
||||
// deepinfra: DeepInfraProviderSettings
|
||||
// replicate: ReplicateProviderSettings
|
||||
// perplexity: PerplexityProviderSettings
|
||||
// fal: FalProviderSettings
|
||||
// vercel: VercelProviderSettings
|
||||
// ollama: OllamaProviderSettings
|
||||
// 'anthropic-vertex': AnthropicVertexProviderSettings
|
||||
// 动态ProviderId类型 - 支持运行时扩展
|
||||
export type ProviderId = keyof ExtensibleProviderSettingsMap | string
|
||||
|
||||
// Provider类型注册工具
|
||||
export interface ProviderTypeRegistrar {
|
||||
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||
getProviderSettings<T extends string>(providerId: T): any
|
||||
}
|
||||
|
||||
// 重新导出所有类型供外部使用
|
||||
export type {
|
||||
// AmazonBedrockProviderSettings,
|
||||
AnthropicProviderSettings,
|
||||
// AnthropicVertexProviderSettings,
|
||||
AzureOpenAIProviderSettings,
|
||||
// CerebrasProviderSettings,
|
||||
// CohereProviderSettings,
|
||||
// DeepInfraProviderSettings,
|
||||
DeepSeekProviderSettings,
|
||||
// FalProviderSettings,
|
||||
// FireworksProviderSettings,
|
||||
GoogleGenerativeAIProviderSettings,
|
||||
// GoogleVertexProviderSettings,
|
||||
// GroqProviderSettings,
|
||||
// MistralProviderSettings,
|
||||
// OllamaProviderSettings,
|
||||
OpenAICompatibleProviderSettings,
|
||||
OpenAIProviderSettings,
|
||||
// OpenRouterProviderSettings,
|
||||
// PerplexityProviderSettings,
|
||||
// ReplicateProviderSettings,
|
||||
// TogetherAIProviderSettings,
|
||||
// VercelProviderSettings,
|
||||
XaiProviderSettings
|
||||
}
|
||||
// 新的provider类型已经在上面直接export,不需要重复导出
|
||||
|
||||
@ -120,7 +120,19 @@ export {
|
||||
} from './core/options'
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './core/providers/registry'
|
||||
export {
|
||||
getAllDynamicMappings,
|
||||
getAllProviders,
|
||||
getDynamicProviders,
|
||||
getProvider,
|
||||
getProviderMapping,
|
||||
isDynamicProvider,
|
||||
isProviderSupported,
|
||||
// 动态注册功能
|
||||
registerDynamicProvider,
|
||||
registerMultipleProviders,
|
||||
registerProvider
|
||||
} from './core/providers/registry'
|
||||
|
||||
// ==================== Provider 配置工厂 ====================
|
||||
export {
|
||||
|
||||
@ -114,12 +114,12 @@ export class AiSdkToChunkAdapter {
|
||||
}
|
||||
break
|
||||
case 'reasoning-delta':
|
||||
final.reasoningContent += chunk.text || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: final.reasoningContent || '',
|
||||
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||
})
|
||||
final.reasoningContent += chunk.text || ''
|
||||
break
|
||||
case 'reasoning-end':
|
||||
this.onChunk({
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { BaseTool, MCPToolResponse, ToolCallResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { type ProviderMetadata } from 'ai'
|
||||
@ -14,6 +14,8 @@ import { type ProviderMetadata } from 'ai'
|
||||
// WebSearchPluginConfig
|
||||
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||
|
||||
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
@ -84,7 +86,7 @@ export class ToolCallChunkHandler {
|
||||
case 'tool-input-delta': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
if (!toolCall) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
toolCall.args += chunk.delta
|
||||
@ -94,7 +96,7 @@ export class ToolCallChunkHandler {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
this.activeToolCalls.delete(chunk.id)
|
||||
if (!toolCall) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
const toolResponse: ToolCallResponse = {
|
||||
@ -104,7 +106,7 @@ export class ToolCallChunkHandler {
|
||||
status: 'pending',
|
||||
toolCallId: toolCall.toolCallId
|
||||
}
|
||||
console.log('toolResponse', toolResponse)
|
||||
logger.debug('toolResponse', toolResponse)
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
@ -134,12 +136,12 @@ export class ToolCallChunkHandler {
|
||||
public handleToolCall(
|
||||
chunk: {
|
||||
type: 'tool-call'
|
||||
} & ToolCallUnion<ToolSet>
|
||||
} & TypedToolCall<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, toolName, input: args, providerExecuted } = chunk
|
||||
|
||||
if (!toolCallId || !toolName) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||
return
|
||||
}
|
||||
|
||||
@ -148,7 +150,7 @@ export class ToolCallChunkHandler {
|
||||
// 根据 providerExecuted 标志区分处理逻辑
|
||||
if (providerExecuted) {
|
||||
// 如果是 Provider 执行的工具(如 web_search)
|
||||
Logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
|
||||
logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
@ -157,7 +159,7 @@ export class ToolCallChunkHandler {
|
||||
}
|
||||
} else if (toolName.startsWith('builtin_')) {
|
||||
// 如果是内置工具,沿用现有逻辑
|
||||
Logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
|
||||
logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
@ -166,10 +168,10 @@ export class ToolCallChunkHandler {
|
||||
}
|
||||
} else {
|
||||
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
|
||||
Logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||
const mcpTool = this.mcpTools.find((t) => t.name === toolName)
|
||||
if (!mcpTool) {
|
||||
Logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`)
|
||||
logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`)
|
||||
return
|
||||
}
|
||||
tool = mcpTool
|
||||
@ -207,19 +209,19 @@ export class ToolCallChunkHandler {
|
||||
public handleToolResult(
|
||||
chunk: {
|
||||
type: 'tool-result'
|
||||
} & ToolResultUnion<ToolSet>
|
||||
} & TypedToolResult<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, output, input } = chunk
|
||||
|
||||
if (!toolCallId) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
|
||||
return
|
||||
}
|
||||
|
||||
// 查找对应的工具调用信息
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ import {
|
||||
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
|
||||
import {
|
||||
awsBedrockToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAwsBedrockMessage,
|
||||
mcpToolsToAwsBedrockTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
@ -393,7 +393,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
// 3. 处理消息
|
||||
|
||||
@ -69,10 +69,10 @@ function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
// console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||
const actualProviderId = actualProvider.type
|
||||
const actualProviderType = actualProvider.type
|
||||
const openaiResponseOptions =
|
||||
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
|
||||
actualProviderId === 'openai-response'
|
||||
actualProviderType === 'openai-response'
|
||||
? {
|
||||
mode: 'responses'
|
||||
}
|
||||
|
||||
@ -13,14 +13,13 @@ export default definePlugin({
|
||||
return new TransformStream<TextStreamPart<ToolSet>, TextStreamPart<ToolSet>>({
|
||||
transform(chunk: TextStreamPart<ToolSet>, controller: TransformStreamDefaultController<TextStreamPart<ToolSet>>) {
|
||||
// === 处理 reasoning 类型 ===
|
||||
if (chunk.type === 'reasoning') {
|
||||
if (!hasStartedThinking) {
|
||||
hasStartedThinking = true
|
||||
thinkingStartTime = performance.now()
|
||||
reasoningBlockId = chunk.id
|
||||
}
|
||||
if (chunk.type === 'reasoning-start') {
|
||||
controller.enqueue(chunk)
|
||||
hasStartedThinking = true
|
||||
thinkingStartTime = performance.now()
|
||||
reasoningBlockId = chunk.id
|
||||
} else if (chunk.type === 'reasoning-delta') {
|
||||
accumulatedThinkingContent += chunk.text
|
||||
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
providerMetadata: {
|
||||
@ -32,7 +31,7 @@ export default definePlugin({
|
||||
}
|
||||
}
|
||||
})
|
||||
} else if (hasStartedThinking) {
|
||||
} else if (chunk.type === 'reasoning-end' && hasStartedThinking) {
|
||||
controller.enqueue({
|
||||
type: 'reasoning-end',
|
||||
id: reasoningBlockId,
|
||||
@ -47,28 +46,8 @@ export default definePlugin({
|
||||
hasStartedThinking = false
|
||||
thinkingStartTime = 0
|
||||
reasoningBlockId = ''
|
||||
if (chunk.type !== 'reasoning-end') {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
} else {
|
||||
if (chunk.type !== 'reasoning-end') {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
flush(controller) {
|
||||
if (hasStartedThinking) {
|
||||
controller.enqueue({
|
||||
type: 'reasoning-end',
|
||||
id: reasoningBlockId,
|
||||
providerMetadata: {
|
||||
metadata: {
|
||||
thinking_millsec: performance.now() - thinkingStartTime,
|
||||
thinking_content: accumulatedThinkingContent
|
||||
}
|
||||
}
|
||||
})
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -0,0 +1,103 @@
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getAiSdkProviderId } from '../factory'
|
||||
|
||||
// Mock the external dependencies
|
||||
vi.mock('@cherrystudio/ai-core', () => ({
|
||||
registerMultipleProviders: vi.fn(() => 4), // Mock successful registration of 4 providers
|
||||
getProviderMapping: vi.fn((id: string) => {
|
||||
// Mock dynamic mappings
|
||||
const mappings: Record<string, string> = {
|
||||
openrouter: 'openrouter',
|
||||
'google-vertex': 'google-vertex',
|
||||
vertexai: 'google-vertex',
|
||||
bedrock: 'bedrock',
|
||||
'aws-bedrock': 'bedrock',
|
||||
zhipu: 'zhipu'
|
||||
}
|
||||
return mappings[id]
|
||||
}),
|
||||
AiCore: {
|
||||
isSupported: vi.fn(() => true)
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock the provider configs
|
||||
vi.mock('../providerConfigs', () => ({
|
||||
initializeNewProviders: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
function createTestProvider(id: string, type: string): Provider {
|
||||
return {
|
||||
id,
|
||||
type,
|
||||
name: `Test ${id}`,
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'test-host'
|
||||
} as Provider
|
||||
}
|
||||
|
||||
describe('Integrated Provider Registry', () => {
|
||||
describe('Provider ID Resolution', () => {
|
||||
it('should resolve openrouter provider correctly', () => {
|
||||
const provider = createTestProvider('openrouter', 'openrouter')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('openrouter')
|
||||
})
|
||||
|
||||
it('should resolve google-vertex provider correctly', () => {
|
||||
const provider = createTestProvider('google-vertex', 'vertexai')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('google-vertex')
|
||||
})
|
||||
|
||||
it('should resolve bedrock provider correctly', () => {
|
||||
const provider = createTestProvider('bedrock', 'aws-bedrock')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('bedrock')
|
||||
})
|
||||
|
||||
it('should resolve zhipu provider correctly', () => {
|
||||
const provider = createTestProvider('zhipu', 'zhipu')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('zhipu')
|
||||
})
|
||||
|
||||
it('should resolve provider type mapping correctly', () => {
|
||||
const provider = createTestProvider('vertex-test', 'vertexai')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('google-vertex')
|
||||
})
|
||||
|
||||
it('should handle static provider mappings', () => {
|
||||
const geminiProvider = createTestProvider('gemini', 'gemini')
|
||||
const result = getAiSdkProviderId(geminiProvider)
|
||||
expect(result).toBe('google')
|
||||
})
|
||||
|
||||
it('should fallback to provider.id for unknown providers', () => {
|
||||
const unknownProvider = createTestProvider('unknown-provider', 'unknown-type')
|
||||
const result = getAiSdkProviderId(unknownProvider)
|
||||
expect(result).toBe('unknown-provider')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Backward Compatibility', () => {
|
||||
it('should maintain compatibility with existing providers', () => {
|
||||
const grokProvider = createTestProvider('grok', 'grok')
|
||||
const result = getAiSdkProviderId(grokProvider)
|
||||
expect(result).toBe('xai')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -10,9 +10,7 @@ export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'opena
|
||||
return 'anthropic'
|
||||
}
|
||||
// TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑
|
||||
// if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
|
||||
if (id.startsWith('gemini') || id.startsWith('imagen')) {
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
|
||||
@ -1,35 +1,50 @@
|
||||
import { AiCore, ProviderId } from '@cherrystudio/ai-core'
|
||||
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
const PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
import { initializeNewProviders } from './providerConfigs'
|
||||
|
||||
// 初始化新的Provider注册系统
|
||||
initializeNewProviders()
|
||||
|
||||
// 静态Provider映射 - 核心providers
|
||||
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
// anthropic: 'anthropic',
|
||||
gemini: 'google',
|
||||
vertexai: 'google-vertex',
|
||||
'azure-openai': 'azure',
|
||||
'openai-response': 'openai',
|
||||
grok: 'xai'
|
||||
}
|
||||
|
||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
||||
const providerId = PROVIDER_MAPPING[provider.id]
|
||||
|
||||
if (providerId) {
|
||||
return providerId
|
||||
// 1. 首先检查静态映射
|
||||
const staticProviderId = STATIC_PROVIDER_MAPPING[provider.id]
|
||||
if (staticProviderId) {
|
||||
return staticProviderId
|
||||
}
|
||||
|
||||
const providerType = PROVIDER_MAPPING[provider.type] // 有些第三方需要映射到aicore对应sdk
|
||||
|
||||
if (providerType) {
|
||||
return providerType
|
||||
// 2. 检查动态注册的provider映射(使用aiCore的函数)
|
||||
const dynamicProviderId = getProviderMapping(provider.id)
|
||||
if (dynamicProviderId) {
|
||||
return dynamicProviderId as ProviderId
|
||||
}
|
||||
|
||||
// 3. 检查provider.type的静态映射
|
||||
const staticProviderType = STATIC_PROVIDER_MAPPING[provider.type]
|
||||
if (staticProviderType) {
|
||||
return staticProviderType
|
||||
}
|
||||
|
||||
// 4. 检查provider.type的动态映射
|
||||
const dynamicProviderType = getProviderMapping(provider.type)
|
||||
if (dynamicProviderType) {
|
||||
return dynamicProviderType as ProviderId
|
||||
}
|
||||
|
||||
// 5. 检查AiCore是否直接支持
|
||||
if (AiCore.isSupported(provider.id)) {
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
// 先注释掉,会影响获取providerOptions
|
||||
// if (AiCore.isSupported(provider.type)) {
|
||||
// return provider.type as ProviderId
|
||||
// }
|
||||
|
||||
// 6. 最后的fallback
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
|
||||
63
src/renderer/src/aiCore/provider/providerConfigs.ts
Normal file
63
src/renderer/src/aiCore/provider/providerConfigs.ts
Normal file
@ -0,0 +1,63 @@
|
||||
import type { ProviderConfig } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
const logger = loggerService.withContext('ProviderConfigs')
|
||||
|
||||
/**
|
||||
* 新Provider配置定义
|
||||
* 定义了需要动态注册的AI Providers
|
||||
*/
|
||||
export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
mappings?: Record<string, string>
|
||||
})[] = [
|
||||
{
|
||||
id: 'openrouter',
|
||||
name: 'OpenRouter',
|
||||
import: () => import('@openrouter/ai-sdk-provider'),
|
||||
creatorFunctionName: 'createOpenRouter',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
openrouter: 'openrouter'
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'google-vertex',
|
||||
name: 'Google Vertex AI',
|
||||
import: () => import('@ai-sdk/google-vertex'),
|
||||
creatorFunctionName: 'createGoogleVertex',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
'google-vertex': 'google-vertex',
|
||||
vertexai: 'google-vertex'
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'bedrock',
|
||||
name: 'Amazon Bedrock',
|
||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
||||
creatorFunctionName: 'createAmazonBedrock',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
'aws-bedrock': 'bedrock'
|
||||
}
|
||||
}
|
||||
] as const
|
||||
|
||||
/**
|
||||
* 初始化新的Providers
|
||||
* 使用aiCore的动态注册功能
|
||||
*/
|
||||
export async function initializeNewProviders(): Promise<void> {
|
||||
try {
|
||||
// 动态导入以避免循环依赖
|
||||
const { registerMultipleProviders } = await import('@cherrystudio/ai-core')
|
||||
|
||||
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
|
||||
|
||||
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
||||
logger.warn('Some providers failed to register. Check previous error logs.')
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize new providers:', error as Error)
|
||||
}
|
||||
}
|
||||
@ -46,7 +46,6 @@ import {
|
||||
findThinkingBlocks,
|
||||
getMainTextContent
|
||||
} from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
@ -79,17 +78,6 @@ export function getTimeout(model: Model): number {
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建系统提示词
|
||||
*/
|
||||
export async function buildSystemPromptWithTools(
|
||||
prompt: string,
|
||||
mcpTools?: MCPTool[],
|
||||
assistant?: Assistant
|
||||
): Promise<string> {
|
||||
return await buildSystemPrompt(prompt, mcpTools, assistant)
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取文件内容
|
||||
*/
|
||||
|
||||
@ -5,7 +5,6 @@ export const SYSTEM_PROMPT_THRESHOLD = 128
|
||||
export const DEFAULT_KNOWLEDGE_DOCUMENT_COUNT = 6
|
||||
export const DEFAULT_KNOWLEDGE_THRESHOLD = 0.0
|
||||
export const DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT = 1
|
||||
export const SYSTEM_PROMPT_THRESHOLD = 128
|
||||
|
||||
export const platform = window.electron?.process?.platform
|
||||
export const isMac = platform === 'darwin'
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
*/
|
||||
import { StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
||||
@ -26,7 +27,7 @@ import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils
|
||||
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import AiProvider from '../aiCore'
|
||||
import AiProviderNew from '../aiCore/index_new'
|
||||
import {
|
||||
// getAssistantProvider,
|
||||
// getAssistantSettings,
|
||||
@ -395,11 +396,8 @@ export async function fetchChatCompletion({
|
||||
}) {
|
||||
console.log('fetchChatCompletion', messages, assistant)
|
||||
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
// Make sure that 'Clear Context' works for all scenarios including external tool and normal chat.
|
||||
messages = filterContextMessages(messages)
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
const mcpTools: MCPTool[] = []
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import store from '@renderer/store'
|
||||
import { addMCPServer } from '@renderer/store/mcp'
|
||||
import {
|
||||
Assistant,
|
||||
BaseTool,
|
||||
MCPCallToolResponse,
|
||||
MCPServer,
|
||||
MCPTool,
|
||||
@ -492,12 +493,13 @@ export function getMcpServerByTool(tool: MCPTool) {
|
||||
return servers.find((s) => s.id === tool.serverId)
|
||||
}
|
||||
|
||||
export function isToolAutoApproved(tool: MCPTool, server?: MCPServer): boolean {
|
||||
if (tool.isBuiltIn) {
|
||||
export function isToolAutoApproved(tool: BaseTool, server?: MCPServer): boolean {
|
||||
if (tool.type === 'builtin') {
|
||||
return true
|
||||
}
|
||||
const effectiveServer = server ?? getMcpServerByTool(tool)
|
||||
return effectiveServer ? !effectiveServer.disabledAutoApproveTools?.includes(tool.name) : false
|
||||
const mcpTool = tool as MCPTool
|
||||
const effectiveServer = server ?? getMcpServerByTool(mcpTool)
|
||||
return effectiveServer ? !effectiveServer.disabledAutoApproveTools?.includes(mcpTool.name) : false
|
||||
}
|
||||
|
||||
export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: number = 0): ToolUseResponse[] {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user