diff --git a/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts b/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts index ffd2140d54..c3d57f7e0f 100644 --- a/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts +++ b/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts @@ -1,43 +1,23 @@ import { isOpenAILLMModel } from '@renderer/config/models' -import { - GenerateImageParams, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - Provider, - ToolCallResponse -} from '@renderer/types' -import { - RequestOptions, - SdkInstance, - SdkMessageParam, - SdkModel, - SdkParams, - SdkRawChunk, - SdkRawOutput, - SdkTool, - SdkToolCall -} from '@renderer/types/sdk' +import { Model, Provider } from '@renderer/types' -import { CompletionsContext } from '../middleware/types' import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' import { BaseApiClient } from './BaseApiClient' import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { MixedBaseAPIClient } from './MixedBaseApiClient' import { OpenAIAPIClient } from './openai/OpenAIApiClient' import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' -import { RequestTransformer, ResponseChunkTransformer } from './types' /** * AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient * 使用装饰器模式实现,在ApiClient层面进行模型路由 */ -export class AihubmixAPIClient extends BaseApiClient { +export class AihubmixAPIClient extends MixedBaseAPIClient { // 使用联合类型而不是any,保持类型安全 - private clients: Map = + protected clients: Map = new Map() - private defaultClient: OpenAIAPIClient - private currentClient: BaseApiClient + protected defaultClient: OpenAIAPIClient + protected currentClient: BaseApiClient constructor(provider: Provider) { super(provider) @@ -73,24 +53,10 @@ export class AihubmixAPIClient extends BaseApiClient { return this.currentClient.getBaseURL() } - /** - * 类型守卫:确保client是BaseApiClient的实例 - */ - private isValidClient(client: unknown): client is BaseApiClient { - return ( - client !== null && - client !== undefined && - typeof client === 'object' && - 'createCompletions' in client && - 'getRequestTransformer' in client && - 'getResponseChunkTransformer' in client - ) - } - /** * 根据模型获取合适的client */ - private getClient(model: Model): BaseApiClient { + protected getClient(model: Model): BaseApiClient { const id = model.id.toLowerCase() // claude开头 @@ -127,114 +93,4 @@ export class AihubmixAPIClient extends BaseApiClient { return this.defaultClient as BaseApiClient } - - /** - * 根据模型选择合适的client并委托调用 - */ - public getClientForModel(model: Model): BaseApiClient { - this.currentClient = this.getClient(model) - return this.currentClient - } - - /** - * 重写基类方法,返回内部实际使用的客户端类型 - */ - public override getClientCompatibilityType(model?: Model): string[] { - if (!model) { - return [this.constructor.name] - } - - const actualClient = this.getClient(model) - return actualClient.getClientCompatibilityType(model) - } - - // ============ BaseApiClient 抽象方法实现 ============ - - async createCompletions(payload: SdkParams, options?: RequestOptions): Promise { - // 尝试从payload中提取模型信息来选择client - const modelId = this.extractModelFromPayload(payload) - if (modelId) { - const modelObj = { id: modelId } as Model - const targetClient = this.getClient(modelObj) - return targetClient.createCompletions(payload, options) - } - - // 如果无法从payload中提取模型,使用当前设置的client - return this.currentClient.createCompletions(payload, options) - } - - /** - * 从SDK payload中提取模型ID - */ - private extractModelFromPayload(payload: SdkParams): string | null { - // 不同的SDK可能有不同的字段名 - if ('model' in payload && typeof payload.model === 'string') { - return payload.model - } - return null - } - - async generateImage(params: GenerateImageParams): Promise { - return this.currentClient.generateImage(params) - } - - async getEmbeddingDimensions(model?: Model): Promise { - const client = model ? this.getClient(model) : this.currentClient - return client.getEmbeddingDimensions(model) - } - - async listModels(): Promise { - // 可以聚合所有client的模型,或者使用默认client - return this.defaultClient.listModels() - } - - async getSdkInstance(): Promise { - return this.currentClient.getSdkInstance() - } - - getRequestTransformer(): RequestTransformer { - return this.currentClient.getRequestTransformer() - } - - getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer { - return this.currentClient.getResponseChunkTransformer(ctx) - } - - convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] { - return this.currentClient.convertMcpToolsToSdkTools(mcpTools) - } - - convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { - return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools) - } - - convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse { - return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) - } - - buildSdkMessages( - currentReqMessages: SdkMessageParam[], - output: SdkRawOutput | string, - toolResults: SdkMessageParam[], - toolCalls?: SdkToolCall[] - ): SdkMessageParam[] { - return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) - } - - convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): SdkMessageParam | undefined { - const client = this.getClient(model) - return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) - } - - extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] { - return this.currentClient.extractMessagesFromSdkPayload(sdkPayload) - } - - estimateMessageTokens(message: SdkMessageParam): number { - return this.currentClient.estimateMessageTokens(message) - } } diff --git a/src/renderer/src/aiCore/clients/MixedBaseApiClient.ts b/src/renderer/src/aiCore/clients/MixedBaseApiClient.ts new file mode 100644 index 0000000000..36a207ecb3 --- /dev/null +++ b/src/renderer/src/aiCore/clients/MixedBaseApiClient.ts @@ -0,0 +1,181 @@ +import { + GenerateImageParams, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse +} from '@renderer/types' +import { + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkModel, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { CompletionsContext } from '../middleware/types' +import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' +import { BaseApiClient } from './BaseApiClient' +import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { OpenAIAPIClient } from './openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' +import { RequestTransformer, ResponseChunkTransformer } from './types' + +/** + * MixedAPIClient - 适用于可能含有多种接口类型的Provider + */ +export abstract class MixedBaseAPIClient extends BaseApiClient { + // 使用联合类型而不是any,保持类型安全 + protected abstract clients: Map< + string, + AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient + > + protected abstract defaultClient: OpenAIAPIClient + protected abstract currentClient: BaseApiClient + + constructor(provider: Provider) { + super(provider) + } + + override getBaseURL(): string { + if (!this.currentClient) { + return this.provider.apiHost + } + return this.currentClient.getBaseURL() + } + + /** + * 类型守卫:确保client是BaseApiClient的实例 + */ + protected isValidClient(client: unknown): client is BaseApiClient { + return ( + client !== null && + client !== undefined && + typeof client === 'object' && + 'createCompletions' in client && + 'getRequestTransformer' in client && + 'getResponseChunkTransformer' in client + ) + } + + /** + * 根据模型获取合适的client + */ + protected abstract getClient(model: Model): BaseApiClient + + /** + * 根据模型选择合适的client并委托调用 + */ + public getClientForModel(model: Model): BaseApiClient { + this.currentClient = this.getClient(model) + return this.currentClient + } + + /** + * 重写基类方法,返回内部实际使用的客户端类型 + */ + public override getClientCompatibilityType(model?: Model): string[] { + if (!model) { + return [this.constructor.name] + } + + const actualClient = this.getClient(model) + return actualClient.getClientCompatibilityType(model) + } + + /** + * 从SDK payload中提取模型ID + */ + protected extractModelFromPayload(payload: SdkParams): string | null { + // 不同的SDK可能有不同的字段名 + if ('model' in payload && typeof payload.model === 'string') { + return payload.model + } + return null + } + + // ============ BaseApiClient 的抽象方法 ============ + + async createCompletions(payload: SdkParams, options?: RequestOptions): Promise { + // 尝试从payload中提取模型信息来选择client + const modelId = this.extractModelFromPayload(payload) + if (modelId) { + const modelObj = { id: modelId } as Model + const targetClient = this.getClient(modelObj) + return targetClient.createCompletions(payload, options) + } + + // 如果无法从payload中提取模型,使用当前设置的client + return this.currentClient.createCompletions(payload, options) + } + + async generateImage(params: GenerateImageParams): Promise { + return this.currentClient.generateImage(params) + } + + async getEmbeddingDimensions(model?: Model): Promise { + const client = model ? this.getClient(model) : this.currentClient + return client.getEmbeddingDimensions(model) + } + + async listModels(): Promise { + // 可以聚合所有client的模型,或者使用默认client + return this.defaultClient.listModels() + } + + async getSdkInstance(): Promise { + return this.currentClient.getSdkInstance() + } + + getRequestTransformer(): RequestTransformer { + return this.currentClient.getRequestTransformer() + } + + getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer { + return this.currentClient.getResponseChunkTransformer(ctx) + } + + convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] { + return this.currentClient.convertMcpToolsToSdkTools(mcpTools) + } + + convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools) + } + + convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse { + return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) + } + + buildSdkMessages( + currentReqMessages: SdkMessageParam[], + output: SdkRawOutput | string, + toolResults: SdkMessageParam[], + toolCalls?: SdkToolCall[] + ): SdkMessageParam[] { + return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) + } + + estimateMessageTokens(message: SdkMessageParam): number { + return this.currentClient.estimateMessageTokens(message) + } + + convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): SdkMessageParam | undefined { + const client = this.getClient(model) + return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) + } + + extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] { + return this.currentClient.extractMessagesFromSdkPayload(sdkPayload) + } +} diff --git a/src/renderer/src/aiCore/clients/NewAPIClient.ts b/src/renderer/src/aiCore/clients/NewAPIClient.ts index 6242f6a320..e87d54ae3e 100644 --- a/src/renderer/src/aiCore/clients/NewAPIClient.ts +++ b/src/renderer/src/aiCore/clients/NewAPIClient.ts @@ -1,42 +1,23 @@ import { loggerService } from '@logger' import { isSupportedModel } from '@renderer/config/models' -import { - GenerateImageParams, - MCPCallToolResponse, - MCPTool, - MCPToolResponse, - Model, - Provider, - ToolCallResponse -} from '@renderer/types' -import { - NewApiModel, - RequestOptions, - SdkInstance, - SdkMessageParam, - SdkParams, - SdkRawChunk, - SdkRawOutput, - SdkTool, - SdkToolCall -} from '@renderer/types/sdk' +import { Model, Provider } from '@renderer/types' +import { NewApiModel } from '@renderer/types/sdk' -import { CompletionsContext } from '../middleware/types' import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' import { BaseApiClient } from './BaseApiClient' import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { MixedBaseAPIClient } from './MixedBaseApiClient' import { OpenAIAPIClient } from './openai/OpenAIApiClient' import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' -import { RequestTransformer, ResponseChunkTransformer } from './types' const logger = loggerService.withContext('NewAPIClient') -export class NewAPIClient extends BaseApiClient { +export class NewAPIClient extends MixedBaseAPIClient { // 使用联合类型而不是any,保持类型安全 - private clients: Map = + protected clients: Map = new Map() - private defaultClient: OpenAIAPIClient - private currentClient: BaseApiClient + protected defaultClient: OpenAIAPIClient + protected currentClient: BaseApiClient constructor(provider: Provider) { super(provider) @@ -63,24 +44,10 @@ export class NewAPIClient extends BaseApiClient { return this.currentClient.getBaseURL() } - /** - * 类型守卫:确保client是BaseApiClient的实例 - */ - private isValidClient(client: unknown): client is BaseApiClient { - return ( - client !== null && - client !== undefined && - typeof client === 'object' && - 'createCompletions' in client && - 'getRequestTransformer' in client && - 'getResponseChunkTransformer' in client - ) - } - /** * 根据模型获取合适的client */ - private getClient(model: Model): BaseApiClient { + protected getClient(model: Model): BaseApiClient { if (!model.endpoint_type) { throw new Error('Model endpoint type is not defined') } @@ -120,61 +87,6 @@ export class NewAPIClient extends BaseApiClient { throw new Error('Invalid model endpoint type: ' + model.endpoint_type) } - /** - * 根据模型选择合适的client并委托调用 - */ - public getClientForModel(model: Model): BaseApiClient { - this.currentClient = this.getClient(model) - return this.currentClient - } - - /** - * 重写基类方法,返回内部实际使用的客户端类型 - */ - public override getClientCompatibilityType(model?: Model): string[] { - if (!model) { - return [this.constructor.name] - } - - const actualClient = this.getClient(model) - return actualClient.getClientCompatibilityType(model) - } - - // ============ BaseApiClient 抽象方法实现 ============ - - async createCompletions(payload: SdkParams, options?: RequestOptions): Promise { - // 尝试从payload中提取模型信息来选择client - const modelId = this.extractModelFromPayload(payload) - if (modelId) { - const modelObj = { id: modelId } as Model - const targetClient = this.getClient(modelObj) - return targetClient.createCompletions(payload, options) - } - - // 如果无法从payload中提取模型,使用当前设置的client - return this.currentClient.createCompletions(payload, options) - } - - /** - * 从SDK payload中提取模型ID - */ - private extractModelFromPayload(payload: SdkParams): string | null { - // 不同的SDK可能有不同的字段名 - if ('model' in payload && typeof payload.model === 'string') { - return payload.model - } - return null - } - - async generateImage(params: GenerateImageParams): Promise { - return this.currentClient.generateImage(params) - } - - async getEmbeddingDimensions(model?: Model): Promise { - const client = model ? this.getClient(model) : this.currentClient - return client.getEmbeddingDimensions(model) - } - override async listModels(): Promise { try { const sdk = await this.defaultClient.getSdkInstance() @@ -195,54 +107,4 @@ export class NewAPIClient extends BaseApiClient { return [] } } - - async getSdkInstance(): Promise { - return this.currentClient.getSdkInstance() - } - - getRequestTransformer(): RequestTransformer { - return this.currentClient.getRequestTransformer() - } - - getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer { - return this.currentClient.getResponseChunkTransformer(ctx) - } - - convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] { - return this.currentClient.convertMcpToolsToSdkTools(mcpTools) - } - - convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { - return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools) - } - - convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse { - return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) - } - - buildSdkMessages( - currentReqMessages: SdkMessageParam[], - output: SdkRawOutput | string, - toolResults: SdkMessageParam[], - toolCalls?: SdkToolCall[] - ): SdkMessageParam[] { - return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) - } - - convertMcpToolResponseToSdkMessageParam( - mcpToolResponse: MCPToolResponse, - resp: MCPCallToolResponse, - model: Model - ): SdkMessageParam | undefined { - const client = this.getClient(model) - return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) - } - - extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] { - return this.currentClient.extractMessagesFromSdkPayload(sdkPayload) - } - - estimateMessageTokens(message: SdkMessageParam): number { - return this.currentClient.estimateMessageTokens(message) - } }