fix: migrate to v5-patch2

This commit is contained in:
suyao 2025-08-03 23:09:19 +08:00
parent def685921c
commit 1ea8266280
No known key found for this signature in database
18 changed files with 372 additions and 314 deletions

View File

@ -272,7 +272,6 @@
"winston": "^3.17.0",
"winston-daily-rotate-file": "^5.0.0",
"word-extractor": "^1.0.4",
"zhipu-ai-provider": "0.2.0-beta.1",
"zipread": "^1.3.3",
"zod": "^3.25.74"
},

View File

@ -37,19 +37,6 @@ const configHandlers: {
resourceName: azureProvider.resourceName
})
}
// 'google-vertex': (builder, provider) => {
// const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
// const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
// vertexBuilder
// .withGoogleVertexConfig({
// project: vertexProvider.project,
// location: vertexProvider.location
// })
// .withGoogleCredentials({
// clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
// privateKey: vertexProvider.googleCredentials?.privateKey || ''
// })
// }
}
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
@ -117,42 +104,6 @@ export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
return this
}
/**
* Google
*/
// withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
// withGoogleVertexConfig(options: any): any {
// if (this.providerId === 'google-vertex') {
// const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
// if (options.project) {
// googleConfig.project = options.project
// }
// if (options.location) {
// googleConfig.location = options.location
// if (options.location === 'global') {
// googleConfig.baseURL = 'https://aiplatform.googleapis.com'
// }
// }
// }
// return this
// }
withGoogleCredentials(credentials: {
clientEmail: string
privateKey: string
}): T extends 'google-vertex' ? this : never
withGoogleCredentials(): any {
// withGoogleCredentials(credentials: any): any {
// if (this.providerId === 'google-vertex') {
// const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
// vertexConfig.googleCredentials = {
// clientEmail: credentials.clientEmail,
// privateKey: formatPrivateKey(credentials.privateKey)
// }
// }
return this
}
/**
*
*/

View File

@ -16,6 +16,9 @@ import { type ProviderConfig } from './types'
export class AiProviderRegistry {
private static instance: AiProviderRegistry
private registry = new Map<string, ProviderConfig>()
// 动态注册扩展
private dynamicMappings = new Map<string, string>()
private dynamicProviders = new Set<string>()
private constructor() {
this.initializeProviders()
@ -30,11 +33,10 @@ export class AiProviderRegistry {
/**
* Providers
* AI SDK 官方文档: https://ai-sdk.dev/providers/ai-sdk-providers
* AI SDK 官方文档: https://v5.ai-sdk.dev/providers/ai-sdk-providers
*/
private initializeProviders(): void {
const providers: ProviderConfig[] = [
// 官方 AI SDK Providers (19个)
{
id: 'openai',
name: 'OpenAI',
@ -67,20 +69,6 @@ export class AiProviderRegistry {
creator: createGoogleGenerativeAI,
supportsImageGeneration: true
},
// {
// id: 'google-vertex',
// name: 'Google Vertex AI',
// import: () => import('@ai-sdk/google-vertex/edge'),
// creatorFunctionName: 'createVertex',
// supportsImageGeneration: true
// },
// {
// id: 'mistral',
// name: 'Mistral AI',
// import: () => import('@ai-sdk/mistral'),
// creatorFunctionName: 'createMistral',
// supportsImageGeneration: false
// },
{
id: 'xai',
name: 'xAI (Grok)',
@ -93,111 +81,12 @@ export class AiProviderRegistry {
creator: createAzure,
supportsImageGeneration: true
},
// {
// id: 'bedrock',
// name: 'Amazon Bedrock',
// import: () => import('@ai-sdk/amazon-bedrock'),
// creatorFunctionName: 'createAmazonBedrock',
// supportsImageGeneration: false
// },
// {
// id: 'cohere',
// name: 'Cohere',
// import: () => import('@ai-sdk/cohere'),
// creatorFunctionName: 'createCohere',
// supportsImageGeneration: false
// },
// {
// id: 'groq',
// name: 'Groq',
// import: () => import('@ai-sdk/groq'),
// creatorFunctionName: 'createGroq',
// supportsImageGeneration: false
// },
// {
// id: 'together',
// name: 'Together.ai',
// import: () => import('@ai-sdk/togetherai'),
// creatorFunctionName: 'createTogetherAI',
// supportsImageGeneration: true
// },
// {
// id: 'fireworks',
// name: 'Fireworks',
// import: () => import('@ai-sdk/fireworks'),
// creatorFunctionName: 'createFireworks',
// supportsImageGeneration: true
// },
{
id: 'deepseek',
name: 'DeepSeek',
creator: createDeepSeek,
supportsImageGeneration: false
}
// {
// id: 'cerebras',
// name: 'Cerebras',
// import: () => import('@ai-sdk/cerebras'),
// creatorFunctionName: 'createCerebras',
// supportsImageGeneration: false
// },
// {
// id: 'deepinfra',
// name: 'DeepInfra',
// import: () => import('@ai-sdk/deepinfra'),
// creatorFunctionName: 'createDeepInfra',
// supportsImageGeneration: false
// },
// {
// id: 'replicate',
// name: 'Replicate',
// import: () => import('@ai-sdk/replicate'),
// creatorFunctionName: 'createReplicate',
// supportsImageGeneration: true
// },
// {
// id: 'perplexity',
// name: 'Perplexity',
// import: () => import('@ai-sdk/perplexity'),
// creatorFunctionName: 'createPerplexity',
// supportsImageGeneration: false
// },
// {
// id: 'fal',
// name: 'Fal AI',
// import: () => import('@ai-sdk/fal'),
// creatorFunctionName: 'createFal',
// supportsImageGeneration: false
// },
// {
// id: 'vercel',
// name: 'Vercel',
// import: () => import('@ai-sdk/vercel'),
// creatorFunctionName: 'createVercel'
// },
// 社区 Providers (5个)
// {
// id: 'ollama',
// name: 'Ollama',
// import: () => import('ollama-ai-provider'),
// creatorFunctionName: 'createOllama',
// supportsImageGeneration: false
// },
// {
// id: 'anthropic-vertex',
// name: 'Anthropic Vertex AI',
// import: () => import('anthropic-vertex-ai'),
// creatorFunctionName: 'createAnthropicVertex',
// supportsImageGeneration: false
// },
// {
// id: 'openrouter',
// name: 'OpenRouter',
// import: () => import('@openrouter/ai-sdk-provider'),
// creatorFunctionName: 'createOpenRouter',
// supportsImageGeneration: false
// }
]
providers.forEach((config) => {
@ -243,28 +132,92 @@ export class AiProviderRegistry {
this.registry.set(config.id, config)
}
/**
* Provider并支持映射关系
*/
public registerDynamicProvider(
config: ProviderConfig & {
mappings?: Record<string, string>
}
): boolean {
try {
// 验证配置
if (!config.id || config.id.trim() === '') {
console.error('Provider ID cannot be empty')
return false
}
// 注册provider
this.registerProvider(config)
// 记录为动态provider
this.dynamicProviders.add(config.id)
// 添加映射关系(如果提供)
if (config.mappings) {
Object.entries(config.mappings).forEach(([key, value]) => {
this.dynamicMappings.set(key, value)
})
}
return true
} catch (error) {
console.error(`Failed to register provider ${config.id}:`, error)
return false
}
}
/**
* Providers
*/
public registerMultipleProviders(
configs: (ProviderConfig & {
mappings?: Record<string, string>
})[]
): number {
let successCount = 0
configs.forEach((config) => {
if (this.registerDynamicProvider(config)) {
successCount++
}
})
return successCount
}
/**
* Provider映射
*/
public getProviderMapping(providerId: string): string | undefined {
return this.dynamicMappings.get(providerId) || (this.dynamicProviders.has(providerId) ? providerId : undefined)
}
/**
* Provider
*/
public isDynamicProvider(providerId: string): boolean {
return this.dynamicProviders.has(providerId)
}
/**
* Provider映射
*/
public getAllDynamicMappings(): Record<string, string> {
return Object.fromEntries(this.dynamicMappings)
}
/**
* Providers
*/
public getDynamicProviders(): string[] {
return Array.from(this.dynamicProviders)
}
/**
*
*/
public cleanup(): void {
this.registry.clear()
}
// /**
// * 获取兼容现有实现的注册表格式
// */
// public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
// const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
// this.getAllProviders().forEach((provider) => {
// compatibleRegistry[provider.id] = {
// import: provider.import,
// creatorFunctionName: provider.creatorFunctionName
// }
// })
// return compatibleRegistry
// }
}
// 导出单例实例
@ -275,5 +228,16 @@ export const getProvider = (id: string) => aiProviderRegistry.getProvider(id)
export const getAllProviders = () => aiProviderRegistry.getAllProviders()
export const isProviderSupported = (id: string) => aiProviderRegistry.isSupported(id)
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
// 动态注册相关便捷函数
export const registerDynamicProvider = (config: ProviderConfig & { mappings?: Record<string, string> }) =>
aiProviderRegistry.registerDynamicProvider(config)
export const registerMultipleProviders = (configs: (ProviderConfig & { mappings?: Record<string, string> })[]) =>
aiProviderRegistry.registerMultipleProviders(configs)
export const getProviderMapping = (providerId: string) => aiProviderRegistry.getProviderMapping(providerId)
export const isDynamicProvider = (providerId: string) => aiProviderRegistry.isDynamicProvider(providerId)
export const getAllDynamicMappings = () => aiProviderRegistry.getAllDynamicMappings()
export const getDynamicProviders = () => aiProviderRegistry.getDynamicProviders()
// 兼容现有实现的导出
// export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()

View File

@ -6,11 +6,30 @@ import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type XaiProviderSettings } from '@ai-sdk/xai'
export interface ExtensibleProviderSettingsMap {
// 基础的静态providers
openai: OpenAIProviderSettings
'openai-responses': OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
deepseek: DeepSeekProviderSettings
}
// 动态扩展的provider类型注册表
export interface DynamicProviderRegistry {
[key: string]: any
}
// 合并基础和动态provider类型
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
/**
* Provider
* 使 AI SDK
*/
export type ProviderId = keyof ProviderSettingsMap & string
// Provider 配置接口 - 支持灵活的创建方式
export interface ProviderConfig {
@ -45,57 +64,23 @@ export class ProviderError extends Error {
}
}
// 类型安全的 Provider Settings 映射
export type ProviderSettingsMap = {
openai: OpenAIProviderSettings
'openai-responses': OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
// openrouter: OpenRouterProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
// 'google-vertex': GoogleVertexProviderSettings
// mistral: MistralProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
// bedrock: AmazonBedrockProviderSettings
// cohere: CohereProviderSettings
// groq: GroqProviderSettings
// together: TogetherAIProviderSettings
// fireworks: FireworksProviderSettings
deepseek: DeepSeekProviderSettings
// cerebras: CerebrasProviderSettings
// deepinfra: DeepInfraProviderSettings
// replicate: ReplicateProviderSettings
// perplexity: PerplexityProviderSettings
// fal: FalProviderSettings
// vercel: VercelProviderSettings
// ollama: OllamaProviderSettings
// 'anthropic-vertex': AnthropicVertexProviderSettings
// 动态ProviderId类型 - 支持运行时扩展
export type ProviderId = keyof ExtensibleProviderSettingsMap | string
// Provider类型注册工具
export interface ProviderTypeRegistrar {
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
getProviderSettings<T extends string>(providerId: T): any
}
// 重新导出所有类型供外部使用
export type {
// AmazonBedrockProviderSettings,
AnthropicProviderSettings,
// AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
// CerebrasProviderSettings,
// CohereProviderSettings,
// DeepInfraProviderSettings,
DeepSeekProviderSettings,
// FalProviderSettings,
// FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
// GoogleVertexProviderSettings,
// GroqProviderSettings,
// MistralProviderSettings,
// OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
// OpenRouterProviderSettings,
// PerplexityProviderSettings,
// ReplicateProviderSettings,
// TogetherAIProviderSettings,
// VercelProviderSettings,
XaiProviderSettings
}
// 新的provider类型已经在上面直接export不需要重复导出

View File

@ -120,7 +120,19 @@ export {
} from './core/options'
// ==================== 工具函数 ====================
export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './core/providers/registry'
export {
getAllDynamicMappings,
getAllProviders,
getDynamicProviders,
getProvider,
getProviderMapping,
isDynamicProvider,
isProviderSupported,
// 动态注册功能
registerDynamicProvider,
registerMultipleProviders,
registerProvider
} from './core/providers/registry'
// ==================== Provider 配置工厂 ====================
export {

View File

@ -114,12 +114,12 @@ export class AiSdkToChunkAdapter {
}
break
case 'reasoning-delta':
final.reasoningContent += chunk.text || ''
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: final.reasoningContent || '',
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
})
final.reasoningContent += chunk.text || ''
break
case 'reasoning-end':
this.onChunk({

View File

@ -4,8 +4,8 @@
* API使
*/
import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core'
import Logger from '@renderer/config/logger'
import { ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { BaseTool, MCPToolResponse, ToolCallResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { type ProviderMetadata } from 'ai'
@ -14,6 +14,8 @@ import { type ProviderMetadata } from 'ai'
// WebSearchPluginConfig
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
const logger = loggerService.withContext('ToolCallChunkHandler')
/**
*
*/
@ -84,7 +86,7 @@ export class ToolCallChunkHandler {
case 'tool-input-delta': {
const toolCall = this.activeToolCalls.get(chunk.id)
if (!toolCall) {
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
toolCall.args += chunk.delta
@ -94,7 +96,7 @@ export class ToolCallChunkHandler {
const toolCall = this.activeToolCalls.get(chunk.id)
this.activeToolCalls.delete(chunk.id)
if (!toolCall) {
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
const toolResponse: ToolCallResponse = {
@ -104,7 +106,7 @@ export class ToolCallChunkHandler {
status: 'pending',
toolCallId: toolCall.toolCallId
}
console.log('toolResponse', toolResponse)
logger.debug('toolResponse', toolResponse)
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse]
@ -134,12 +136,12 @@ export class ToolCallChunkHandler {
public handleToolCall(
chunk: {
type: 'tool-call'
} & ToolCallUnion<ToolSet>
} & TypedToolCall<ToolSet>
): void {
const { toolCallId, toolName, input: args, providerExecuted } = chunk
if (!toolCallId || !toolName) {
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
return
}
@ -148,7 +150,7 @@ export class ToolCallChunkHandler {
// 根据 providerExecuted 标志区分处理逻辑
if (providerExecuted) {
// 如果是 Provider 执行的工具(如 web_search
Logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
tool = {
id: toolCallId,
name: toolName,
@ -157,7 +159,7 @@ export class ToolCallChunkHandler {
}
} else if (toolName.startsWith('builtin_')) {
// 如果是内置工具,沿用现有逻辑
Logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
tool = {
id: toolCallId,
name: toolName,
@ -166,10 +168,10 @@ export class ToolCallChunkHandler {
}
} else {
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
Logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
const mcpTool = this.mcpTools.find((t) => t.name === toolName)
if (!mcpTool) {
Logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`)
logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`)
return
}
tool = mcpTool
@ -207,19 +209,19 @@ export class ToolCallChunkHandler {
public handleToolResult(
chunk: {
type: 'tool-result'
} & ToolResultUnion<ToolSet>
} & TypedToolResult<ToolSet>
): void {
const { toolCallId, output, input } = chunk
if (!toolCallId) {
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
return
}
// 查找对应的工具调用信息
const toolCallInfo = this.activeToolCalls.get(toolCallId)
if (!toolCallInfo) {
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
return
}

View File

@ -37,7 +37,7 @@ import {
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
import {
awsBedrockToolUseToMcpTool,
isEnabledToolUse,
isSupportedToolUse,
mcpToolCallResponseToAwsBedrockMessage,
mcpToolsToAwsBedrockTools
} from '@renderer/utils/mcp-tools'
@ -393,7 +393,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
enableToolUse: isSupportedToolUse(assistant)
})
// 3. 处理消息

View File

@ -69,10 +69,10 @@ function providerToAiSdkConfig(actualProvider: Provider): {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// console.log('aiSdkProviderId', aiSdkProviderId)
// 如果provider是openai则使用strict模式并且默认responses api
const actualProviderId = actualProvider.type
const actualProviderType = actualProvider.type
const openaiResponseOptions =
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
actualProviderId === 'openai-response'
actualProviderType === 'openai-response'
? {
mode: 'responses'
}

View File

@ -13,14 +13,13 @@ export default definePlugin({
return new TransformStream<TextStreamPart<ToolSet>, TextStreamPart<ToolSet>>({
transform(chunk: TextStreamPart<ToolSet>, controller: TransformStreamDefaultController<TextStreamPart<ToolSet>>) {
// === 处理 reasoning 类型 ===
if (chunk.type === 'reasoning') {
if (!hasStartedThinking) {
hasStartedThinking = true
thinkingStartTime = performance.now()
reasoningBlockId = chunk.id
}
if (chunk.type === 'reasoning-start') {
controller.enqueue(chunk)
hasStartedThinking = true
thinkingStartTime = performance.now()
reasoningBlockId = chunk.id
} else if (chunk.type === 'reasoning-delta') {
accumulatedThinkingContent += chunk.text
controller.enqueue({
...chunk,
providerMetadata: {
@ -32,7 +31,7 @@ export default definePlugin({
}
}
})
} else if (hasStartedThinking) {
} else if (chunk.type === 'reasoning-end' && hasStartedThinking) {
controller.enqueue({
type: 'reasoning-end',
id: reasoningBlockId,
@ -47,28 +46,8 @@ export default definePlugin({
hasStartedThinking = false
thinkingStartTime = 0
reasoningBlockId = ''
if (chunk.type !== 'reasoning-end') {
controller.enqueue(chunk)
}
} else {
if (chunk.type !== 'reasoning-end') {
controller.enqueue(chunk)
}
}
},
flush(controller) {
if (hasStartedThinking) {
controller.enqueue({
type: 'reasoning-end',
id: reasoningBlockId,
providerMetadata: {
metadata: {
thinking_millsec: performance.now() - thinkingStartTime,
thinking_content: accumulatedThinkingContent
}
}
})
controller.enqueue(chunk)
}
}
})

View File

@ -0,0 +1,103 @@
import type { Provider } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { getAiSdkProviderId } from '../factory'
// Mock the external dependencies
vi.mock('@cherrystudio/ai-core', () => ({
registerMultipleProviders: vi.fn(() => 4), // Mock successful registration of 4 providers
getProviderMapping: vi.fn((id: string) => {
// Mock dynamic mappings
const mappings: Record<string, string> = {
openrouter: 'openrouter',
'google-vertex': 'google-vertex',
vertexai: 'google-vertex',
bedrock: 'bedrock',
'aws-bedrock': 'bedrock',
zhipu: 'zhipu'
}
return mappings[id]
}),
AiCore: {
isSupported: vi.fn(() => true)
}
}))
// Mock the provider configs
vi.mock('../providerConfigs', () => ({
initializeNewProviders: vi.fn()
}))
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
info: vi.fn(),
warn: vi.fn(),
error: vi.fn()
})
}
}))
function createTestProvider(id: string, type: string): Provider {
return {
id,
type,
name: `Test ${id}`,
apiKey: 'test-key',
apiHost: 'test-host'
} as Provider
}
describe('Integrated Provider Registry', () => {
describe('Provider ID Resolution', () => {
it('should resolve openrouter provider correctly', () => {
const provider = createTestProvider('openrouter', 'openrouter')
const result = getAiSdkProviderId(provider)
expect(result).toBe('openrouter')
})
it('should resolve google-vertex provider correctly', () => {
const provider = createTestProvider('google-vertex', 'vertexai')
const result = getAiSdkProviderId(provider)
expect(result).toBe('google-vertex')
})
it('should resolve bedrock provider correctly', () => {
const provider = createTestProvider('bedrock', 'aws-bedrock')
const result = getAiSdkProviderId(provider)
expect(result).toBe('bedrock')
})
it('should resolve zhipu provider correctly', () => {
const provider = createTestProvider('zhipu', 'zhipu')
const result = getAiSdkProviderId(provider)
expect(result).toBe('zhipu')
})
it('should resolve provider type mapping correctly', () => {
const provider = createTestProvider('vertex-test', 'vertexai')
const result = getAiSdkProviderId(provider)
expect(result).toBe('google-vertex')
})
it('should handle static provider mappings', () => {
const geminiProvider = createTestProvider('gemini', 'gemini')
const result = getAiSdkProviderId(geminiProvider)
expect(result).toBe('google')
})
it('should fallback to provider.id for unknown providers', () => {
const unknownProvider = createTestProvider('unknown-provider', 'unknown-type')
const result = getAiSdkProviderId(unknownProvider)
expect(result).toBe('unknown-provider')
})
})
describe('Backward Compatibility', () => {
it('should maintain compatibility with existing providers', () => {
const grokProvider = createTestProvider('grok', 'grok')
const result = getAiSdkProviderId(grokProvider)
expect(result).toBe('xai')
})
})
})

View File

@ -10,9 +10,7 @@ export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'opena
return 'anthropic'
}
// TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑
// if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
if (id.startsWith('gemini') || id.startsWith('imagen')) {
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
return 'google'
}

View File

@ -1,35 +1,50 @@
import { AiCore, ProviderId } from '@cherrystudio/ai-core'
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
import { Provider } from '@renderer/types'
const PROVIDER_MAPPING: Record<string, ProviderId> = {
import { initializeNewProviders } from './providerConfigs'
// 初始化新的Provider注册系统
initializeNewProviders()
// 静态Provider映射 - 核心providers
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
// anthropic: 'anthropic',
gemini: 'google',
vertexai: 'google-vertex',
'azure-openai': 'azure',
'openai-response': 'openai',
grok: 'xai'
}
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
const providerId = PROVIDER_MAPPING[provider.id]
if (providerId) {
return providerId
// 1. 首先检查静态映射
const staticProviderId = STATIC_PROVIDER_MAPPING[provider.id]
if (staticProviderId) {
return staticProviderId
}
const providerType = PROVIDER_MAPPING[provider.type] // 有些第三方需要映射到aicore对应sdk
if (providerType) {
return providerType
// 2. 检查动态注册的provider映射使用aiCore的函数
const dynamicProviderId = getProviderMapping(provider.id)
if (dynamicProviderId) {
return dynamicProviderId as ProviderId
}
// 3. 检查provider.type的静态映射
const staticProviderType = STATIC_PROVIDER_MAPPING[provider.type]
if (staticProviderType) {
return staticProviderType
}
// 4. 检查provider.type的动态映射
const dynamicProviderType = getProviderMapping(provider.type)
if (dynamicProviderType) {
return dynamicProviderType as ProviderId
}
// 5. 检查AiCore是否直接支持
if (AiCore.isSupported(provider.id)) {
return provider.id as ProviderId
}
// 先注释掉会影响获取providerOptions
// if (AiCore.isSupported(provider.type)) {
// return provider.type as ProviderId
// }
// 6. 最后的fallback
return provider.id as ProviderId
}

View File

@ -0,0 +1,63 @@
import type { ProviderConfig } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
const logger = loggerService.withContext('ProviderConfigs')
/**
* Provider配置定义
* AI Providers
*/
export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
mappings?: Record<string, string>
})[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
mappings: {
openrouter: 'openrouter'
}
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex'),
creatorFunctionName: 'createGoogleVertex',
supportsImageGeneration: true,
mappings: {
'google-vertex': 'google-vertex',
vertexai: 'google-vertex'
}
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
mappings: {
'aws-bedrock': 'bedrock'
}
}
] as const
/**
* Providers
* 使aiCore的动态注册功能
*/
export async function initializeNewProviders(): Promise<void> {
try {
// 动态导入以避免循环依赖
const { registerMultipleProviders } = await import('@cherrystudio/ai-core')
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
if (successCount < NEW_PROVIDER_CONFIGS.length) {
logger.warn('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
}

View File

@ -46,7 +46,6 @@ import {
findThinkingBlocks,
getMainTextContent
} from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
import { isEmpty } from 'lodash'
@ -79,17 +78,6 @@ export function getTimeout(model: Model): number {
return defaultTimeout
}
/**
*
*/
export async function buildSystemPromptWithTools(
prompt: string,
mcpTools?: MCPTool[],
assistant?: Assistant
): Promise<string> {
return await buildSystemPrompt(prompt, mcpTools, assistant)
}
/**
*
*/

View File

@ -5,7 +5,6 @@ export const SYSTEM_PROMPT_THRESHOLD = 128
export const DEFAULT_KNOWLEDGE_DOCUMENT_COUNT = 6
export const DEFAULT_KNOWLEDGE_THRESHOLD = 0.0
export const DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT = 1
export const SYSTEM_PROMPT_THRESHOLD = 128
export const platform = window.electron?.process?.platform
export const isMac = platform === 'darwin'

View File

@ -3,6 +3,7 @@
*/
import { StreamTextParams } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import AiProvider from '@renderer/aiCore'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
@ -26,7 +27,7 @@ import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
import { isEmpty, takeRight } from 'lodash'
import AiProvider from '../aiCore'
import AiProviderNew from '../aiCore/index_new'
import {
// getAssistantProvider,
// getAssistantSettings,
@ -395,11 +396,8 @@ export async function fetchChatCompletion({
}) {
console.log('fetchChatCompletion', messages, assistant)
const provider = getAssistantProvider(assistant)
const AI = new AiProvider(provider)
// Make sure that 'Clear Context' works for all scenarios including external tool and normal chat.
messages = filterContextMessages(messages)
const AI = new AiProviderNew(assistant.model || getDefaultModel())
const provider = AI.getActualProvider()
const mcpTools: MCPTool[] = []

View File

@ -8,6 +8,7 @@ import store from '@renderer/store'
import { addMCPServer } from '@renderer/store/mcp'
import {
Assistant,
BaseTool,
MCPCallToolResponse,
MCPServer,
MCPTool,
@ -492,12 +493,13 @@ export function getMcpServerByTool(tool: MCPTool) {
return servers.find((s) => s.id === tool.serverId)
}
export function isToolAutoApproved(tool: MCPTool, server?: MCPServer): boolean {
if (tool.isBuiltIn) {
export function isToolAutoApproved(tool: BaseTool, server?: MCPServer): boolean {
if (tool.type === 'builtin') {
return true
}
const effectiveServer = server ?? getMcpServerByTool(tool)
return effectiveServer ? !effectiveServer.disabledAutoApproveTools?.includes(tool.name) : false
const mcpTool = tool as MCPTool
const effectiveServer = server ?? getMcpServerByTool(mcpTool)
return effectiveServer ? !effectiveServer.disabledAutoApproveTools?.includes(mcpTool.name) : false
}
export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: number = 0): ToolUseResponse[] {