mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 06:49:02 +08:00
refactor(aiCore): extract MixedBaseAPIClient as abstract class (#8618)
refactor(aiCore): 提取MixedBaseAPIClient基类重构API客户端 将AihubmixAPIClient和NewAPIClient的公共逻辑提取到MixedBaseAPIClient基类 减少代码重复,提高可维护性
This commit is contained in:
parent
3a2a9d26eb
commit
78173ae24e
@ -1,43 +1,23 @@
|
|||||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||||
import {
|
import { Model, Provider } from '@renderer/types'
|
||||||
GenerateImageParams,
|
|
||||||
MCPCallToolResponse,
|
|
||||||
MCPTool,
|
|
||||||
MCPToolResponse,
|
|
||||||
Model,
|
|
||||||
Provider,
|
|
||||||
ToolCallResponse
|
|
||||||
} from '@renderer/types'
|
|
||||||
import {
|
|
||||||
RequestOptions,
|
|
||||||
SdkInstance,
|
|
||||||
SdkMessageParam,
|
|
||||||
SdkModel,
|
|
||||||
SdkParams,
|
|
||||||
SdkRawChunk,
|
|
||||||
SdkRawOutput,
|
|
||||||
SdkTool,
|
|
||||||
SdkToolCall
|
|
||||||
} from '@renderer/types/sdk'
|
|
||||||
|
|
||||||
import { CompletionsContext } from '../middleware/types'
|
|
||||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||||
import { BaseApiClient } from './BaseApiClient'
|
import { BaseApiClient } from './BaseApiClient'
|
||||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||||
|
import { MixedBaseAPIClient } from './MixedBaseApiClient'
|
||||||
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||||
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||||
import { RequestTransformer, ResponseChunkTransformer } from './types'
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
|
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
|
||||||
* 使用装饰器模式实现,在ApiClient层面进行模型路由
|
* 使用装饰器模式实现,在ApiClient层面进行模型路由
|
||||||
*/
|
*/
|
||||||
export class AihubmixAPIClient extends BaseApiClient {
|
export class AihubmixAPIClient extends MixedBaseAPIClient {
|
||||||
// 使用联合类型而不是any,保持类型安全
|
// 使用联合类型而不是any,保持类型安全
|
||||||
private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
|
protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
|
||||||
new Map()
|
new Map()
|
||||||
private defaultClient: OpenAIAPIClient
|
protected defaultClient: OpenAIAPIClient
|
||||||
private currentClient: BaseApiClient
|
protected currentClient: BaseApiClient
|
||||||
|
|
||||||
constructor(provider: Provider) {
|
constructor(provider: Provider) {
|
||||||
super(provider)
|
super(provider)
|
||||||
@ -73,24 +53,10 @@ export class AihubmixAPIClient extends BaseApiClient {
|
|||||||
return this.currentClient.getBaseURL()
|
return this.currentClient.getBaseURL()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 类型守卫:确保client是BaseApiClient的实例
|
|
||||||
*/
|
|
||||||
private isValidClient(client: unknown): client is BaseApiClient {
|
|
||||||
return (
|
|
||||||
client !== null &&
|
|
||||||
client !== undefined &&
|
|
||||||
typeof client === 'object' &&
|
|
||||||
'createCompletions' in client &&
|
|
||||||
'getRequestTransformer' in client &&
|
|
||||||
'getResponseChunkTransformer' in client
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据模型获取合适的client
|
* 根据模型获取合适的client
|
||||||
*/
|
*/
|
||||||
private getClient(model: Model): BaseApiClient {
|
protected getClient(model: Model): BaseApiClient {
|
||||||
const id = model.id.toLowerCase()
|
const id = model.id.toLowerCase()
|
||||||
|
|
||||||
// claude开头
|
// claude开头
|
||||||
@ -127,114 +93,4 @@ export class AihubmixAPIClient extends BaseApiClient {
|
|||||||
|
|
||||||
return this.defaultClient as BaseApiClient
|
return this.defaultClient as BaseApiClient
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 根据模型选择合适的client并委托调用
|
|
||||||
*/
|
|
||||||
public getClientForModel(model: Model): BaseApiClient {
|
|
||||||
this.currentClient = this.getClient(model)
|
|
||||||
return this.currentClient
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 重写基类方法,返回内部实际使用的客户端类型
|
|
||||||
*/
|
|
||||||
public override getClientCompatibilityType(model?: Model): string[] {
|
|
||||||
if (!model) {
|
|
||||||
return [this.constructor.name]
|
|
||||||
}
|
|
||||||
|
|
||||||
const actualClient = this.getClient(model)
|
|
||||||
return actualClient.getClientCompatibilityType(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ BaseApiClient 抽象方法实现 ============
|
|
||||||
|
|
||||||
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
|
|
||||||
// 尝试从payload中提取模型信息来选择client
|
|
||||||
const modelId = this.extractModelFromPayload(payload)
|
|
||||||
if (modelId) {
|
|
||||||
const modelObj = { id: modelId } as Model
|
|
||||||
const targetClient = this.getClient(modelObj)
|
|
||||||
return targetClient.createCompletions(payload, options)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果无法从payload中提取模型,使用当前设置的client
|
|
||||||
return this.currentClient.createCompletions(payload, options)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 从SDK payload中提取模型ID
|
|
||||||
*/
|
|
||||||
private extractModelFromPayload(payload: SdkParams): string | null {
|
|
||||||
// 不同的SDK可能有不同的字段名
|
|
||||||
if ('model' in payload && typeof payload.model === 'string') {
|
|
||||||
return payload.model
|
|
||||||
}
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
async generateImage(params: GenerateImageParams): Promise<string[]> {
|
|
||||||
return this.currentClient.generateImage(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
async getEmbeddingDimensions(model?: Model): Promise<number> {
|
|
||||||
const client = model ? this.getClient(model) : this.currentClient
|
|
||||||
return client.getEmbeddingDimensions(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
async listModels(): Promise<SdkModel[]> {
|
|
||||||
// 可以聚合所有client的模型,或者使用默认client
|
|
||||||
return this.defaultClient.listModels()
|
|
||||||
}
|
|
||||||
|
|
||||||
async getSdkInstance(): Promise<SdkInstance> {
|
|
||||||
return this.currentClient.getSdkInstance()
|
|
||||||
}
|
|
||||||
|
|
||||||
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
|
|
||||||
return this.currentClient.getRequestTransformer()
|
|
||||||
}
|
|
||||||
|
|
||||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
|
|
||||||
return this.currentClient.getResponseChunkTransformer(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
|
|
||||||
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
|
||||||
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
|
||||||
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
|
||||||
}
|
|
||||||
|
|
||||||
buildSdkMessages(
|
|
||||||
currentReqMessages: SdkMessageParam[],
|
|
||||||
output: SdkRawOutput | string,
|
|
||||||
toolResults: SdkMessageParam[],
|
|
||||||
toolCalls?: SdkToolCall[]
|
|
||||||
): SdkMessageParam[] {
|
|
||||||
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertMcpToolResponseToSdkMessageParam(
|
|
||||||
mcpToolResponse: MCPToolResponse,
|
|
||||||
resp: MCPCallToolResponse,
|
|
||||||
model: Model
|
|
||||||
): SdkMessageParam | undefined {
|
|
||||||
const client = this.getClient(model)
|
|
||||||
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
|
|
||||||
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
|
|
||||||
}
|
|
||||||
|
|
||||||
estimateMessageTokens(message: SdkMessageParam): number {
|
|
||||||
return this.currentClient.estimateMessageTokens(message)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
181
src/renderer/src/aiCore/clients/MixedBaseApiClient.ts
Normal file
181
src/renderer/src/aiCore/clients/MixedBaseApiClient.ts
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
import {
|
||||||
|
GenerateImageParams,
|
||||||
|
MCPCallToolResponse,
|
||||||
|
MCPTool,
|
||||||
|
MCPToolResponse,
|
||||||
|
Model,
|
||||||
|
Provider,
|
||||||
|
ToolCallResponse
|
||||||
|
} from '@renderer/types'
|
||||||
|
import {
|
||||||
|
RequestOptions,
|
||||||
|
SdkInstance,
|
||||||
|
SdkMessageParam,
|
||||||
|
SdkModel,
|
||||||
|
SdkParams,
|
||||||
|
SdkRawChunk,
|
||||||
|
SdkRawOutput,
|
||||||
|
SdkTool,
|
||||||
|
SdkToolCall
|
||||||
|
} from '@renderer/types/sdk'
|
||||||
|
|
||||||
|
import { CompletionsContext } from '../middleware/types'
|
||||||
|
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||||
|
import { BaseApiClient } from './BaseApiClient'
|
||||||
|
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||||
|
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||||
|
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||||
|
import { RequestTransformer, ResponseChunkTransformer } from './types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MixedAPIClient - 适用于可能含有多种接口类型的Provider
|
||||||
|
*/
|
||||||
|
export abstract class MixedBaseAPIClient extends BaseApiClient {
|
||||||
|
// 使用联合类型而不是any,保持类型安全
|
||||||
|
protected abstract clients: Map<
|
||||||
|
string,
|
||||||
|
AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient
|
||||||
|
>
|
||||||
|
protected abstract defaultClient: OpenAIAPIClient
|
||||||
|
protected abstract currentClient: BaseApiClient
|
||||||
|
|
||||||
|
constructor(provider: Provider) {
|
||||||
|
super(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
override getBaseURL(): string {
|
||||||
|
if (!this.currentClient) {
|
||||||
|
return this.provider.apiHost
|
||||||
|
}
|
||||||
|
return this.currentClient.getBaseURL()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 类型守卫:确保client是BaseApiClient的实例
|
||||||
|
*/
|
||||||
|
protected isValidClient(client: unknown): client is BaseApiClient {
|
||||||
|
return (
|
||||||
|
client !== null &&
|
||||||
|
client !== undefined &&
|
||||||
|
typeof client === 'object' &&
|
||||||
|
'createCompletions' in client &&
|
||||||
|
'getRequestTransformer' in client &&
|
||||||
|
'getResponseChunkTransformer' in client
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据模型获取合适的client
|
||||||
|
*/
|
||||||
|
protected abstract getClient(model: Model): BaseApiClient
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据模型选择合适的client并委托调用
|
||||||
|
*/
|
||||||
|
public getClientForModel(model: Model): BaseApiClient {
|
||||||
|
this.currentClient = this.getClient(model)
|
||||||
|
return this.currentClient
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重写基类方法,返回内部实际使用的客户端类型
|
||||||
|
*/
|
||||||
|
public override getClientCompatibilityType(model?: Model): string[] {
|
||||||
|
if (!model) {
|
||||||
|
return [this.constructor.name]
|
||||||
|
}
|
||||||
|
|
||||||
|
const actualClient = this.getClient(model)
|
||||||
|
return actualClient.getClientCompatibilityType(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从SDK payload中提取模型ID
|
||||||
|
*/
|
||||||
|
protected extractModelFromPayload(payload: SdkParams): string | null {
|
||||||
|
// 不同的SDK可能有不同的字段名
|
||||||
|
if ('model' in payload && typeof payload.model === 'string') {
|
||||||
|
return payload.model
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ BaseApiClient 的抽象方法 ============
|
||||||
|
|
||||||
|
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
|
||||||
|
// 尝试从payload中提取模型信息来选择client
|
||||||
|
const modelId = this.extractModelFromPayload(payload)
|
||||||
|
if (modelId) {
|
||||||
|
const modelObj = { id: modelId } as Model
|
||||||
|
const targetClient = this.getClient(modelObj)
|
||||||
|
return targetClient.createCompletions(payload, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果无法从payload中提取模型,使用当前设置的client
|
||||||
|
return this.currentClient.createCompletions(payload, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||||
|
return this.currentClient.generateImage(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
async getEmbeddingDimensions(model?: Model): Promise<number> {
|
||||||
|
const client = model ? this.getClient(model) : this.currentClient
|
||||||
|
return client.getEmbeddingDimensions(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
async listModels(): Promise<SdkModel[]> {
|
||||||
|
// 可以聚合所有client的模型,或者使用默认client
|
||||||
|
return this.defaultClient.listModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
async getSdkInstance(): Promise<SdkInstance> {
|
||||||
|
return this.currentClient.getSdkInstance()
|
||||||
|
}
|
||||||
|
|
||||||
|
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
|
||||||
|
return this.currentClient.getRequestTransformer()
|
||||||
|
}
|
||||||
|
|
||||||
|
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
|
||||||
|
return this.currentClient.getResponseChunkTransformer(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
|
||||||
|
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||||
|
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||||
|
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||||
|
}
|
||||||
|
|
||||||
|
buildSdkMessages(
|
||||||
|
currentReqMessages: SdkMessageParam[],
|
||||||
|
output: SdkRawOutput | string,
|
||||||
|
toolResults: SdkMessageParam[],
|
||||||
|
toolCalls?: SdkToolCall[]
|
||||||
|
): SdkMessageParam[] {
|
||||||
|
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
estimateMessageTokens(message: SdkMessageParam): number {
|
||||||
|
return this.currentClient.estimateMessageTokens(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
convertMcpToolResponseToSdkMessageParam(
|
||||||
|
mcpToolResponse: MCPToolResponse,
|
||||||
|
resp: MCPCallToolResponse,
|
||||||
|
model: Model
|
||||||
|
): SdkMessageParam | undefined {
|
||||||
|
const client = this.getClient(model)
|
||||||
|
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
|
||||||
|
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,42 +1,23 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isSupportedModel } from '@renderer/config/models'
|
import { isSupportedModel } from '@renderer/config/models'
|
||||||
import {
|
import { Model, Provider } from '@renderer/types'
|
||||||
GenerateImageParams,
|
import { NewApiModel } from '@renderer/types/sdk'
|
||||||
MCPCallToolResponse,
|
|
||||||
MCPTool,
|
|
||||||
MCPToolResponse,
|
|
||||||
Model,
|
|
||||||
Provider,
|
|
||||||
ToolCallResponse
|
|
||||||
} from '@renderer/types'
|
|
||||||
import {
|
|
||||||
NewApiModel,
|
|
||||||
RequestOptions,
|
|
||||||
SdkInstance,
|
|
||||||
SdkMessageParam,
|
|
||||||
SdkParams,
|
|
||||||
SdkRawChunk,
|
|
||||||
SdkRawOutput,
|
|
||||||
SdkTool,
|
|
||||||
SdkToolCall
|
|
||||||
} from '@renderer/types/sdk'
|
|
||||||
|
|
||||||
import { CompletionsContext } from '../middleware/types'
|
|
||||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||||
import { BaseApiClient } from './BaseApiClient'
|
import { BaseApiClient } from './BaseApiClient'
|
||||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||||
|
import { MixedBaseAPIClient } from './MixedBaseApiClient'
|
||||||
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||||
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||||
import { RequestTransformer, ResponseChunkTransformer } from './types'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('NewAPIClient')
|
const logger = loggerService.withContext('NewAPIClient')
|
||||||
|
|
||||||
export class NewAPIClient extends BaseApiClient {
|
export class NewAPIClient extends MixedBaseAPIClient {
|
||||||
// 使用联合类型而不是any,保持类型安全
|
// 使用联合类型而不是any,保持类型安全
|
||||||
private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
|
protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
|
||||||
new Map()
|
new Map()
|
||||||
private defaultClient: OpenAIAPIClient
|
protected defaultClient: OpenAIAPIClient
|
||||||
private currentClient: BaseApiClient
|
protected currentClient: BaseApiClient
|
||||||
|
|
||||||
constructor(provider: Provider) {
|
constructor(provider: Provider) {
|
||||||
super(provider)
|
super(provider)
|
||||||
@ -63,24 +44,10 @@ export class NewAPIClient extends BaseApiClient {
|
|||||||
return this.currentClient.getBaseURL()
|
return this.currentClient.getBaseURL()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 类型守卫:确保client是BaseApiClient的实例
|
|
||||||
*/
|
|
||||||
private isValidClient(client: unknown): client is BaseApiClient {
|
|
||||||
return (
|
|
||||||
client !== null &&
|
|
||||||
client !== undefined &&
|
|
||||||
typeof client === 'object' &&
|
|
||||||
'createCompletions' in client &&
|
|
||||||
'getRequestTransformer' in client &&
|
|
||||||
'getResponseChunkTransformer' in client
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据模型获取合适的client
|
* 根据模型获取合适的client
|
||||||
*/
|
*/
|
||||||
private getClient(model: Model): BaseApiClient {
|
protected getClient(model: Model): BaseApiClient {
|
||||||
if (!model.endpoint_type) {
|
if (!model.endpoint_type) {
|
||||||
throw new Error('Model endpoint type is not defined')
|
throw new Error('Model endpoint type is not defined')
|
||||||
}
|
}
|
||||||
@ -120,61 +87,6 @@ export class NewAPIClient extends BaseApiClient {
|
|||||||
throw new Error('Invalid model endpoint type: ' + model.endpoint_type)
|
throw new Error('Invalid model endpoint type: ' + model.endpoint_type)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 根据模型选择合适的client并委托调用
|
|
||||||
*/
|
|
||||||
public getClientForModel(model: Model): BaseApiClient {
|
|
||||||
this.currentClient = this.getClient(model)
|
|
||||||
return this.currentClient
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 重写基类方法,返回内部实际使用的客户端类型
|
|
||||||
*/
|
|
||||||
public override getClientCompatibilityType(model?: Model): string[] {
|
|
||||||
if (!model) {
|
|
||||||
return [this.constructor.name]
|
|
||||||
}
|
|
||||||
|
|
||||||
const actualClient = this.getClient(model)
|
|
||||||
return actualClient.getClientCompatibilityType(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ BaseApiClient 抽象方法实现 ============
|
|
||||||
|
|
||||||
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
|
|
||||||
// 尝试从payload中提取模型信息来选择client
|
|
||||||
const modelId = this.extractModelFromPayload(payload)
|
|
||||||
if (modelId) {
|
|
||||||
const modelObj = { id: modelId } as Model
|
|
||||||
const targetClient = this.getClient(modelObj)
|
|
||||||
return targetClient.createCompletions(payload, options)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果无法从payload中提取模型,使用当前设置的client
|
|
||||||
return this.currentClient.createCompletions(payload, options)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 从SDK payload中提取模型ID
|
|
||||||
*/
|
|
||||||
private extractModelFromPayload(payload: SdkParams): string | null {
|
|
||||||
// 不同的SDK可能有不同的字段名
|
|
||||||
if ('model' in payload && typeof payload.model === 'string') {
|
|
||||||
return payload.model
|
|
||||||
}
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
async generateImage(params: GenerateImageParams): Promise<string[]> {
|
|
||||||
return this.currentClient.generateImage(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
async getEmbeddingDimensions(model?: Model): Promise<number> {
|
|
||||||
const client = model ? this.getClient(model) : this.currentClient
|
|
||||||
return client.getEmbeddingDimensions(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
override async listModels(): Promise<NewApiModel[]> {
|
override async listModels(): Promise<NewApiModel[]> {
|
||||||
try {
|
try {
|
||||||
const sdk = await this.defaultClient.getSdkInstance()
|
const sdk = await this.defaultClient.getSdkInstance()
|
||||||
@ -195,54 +107,4 @@ export class NewAPIClient extends BaseApiClient {
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async getSdkInstance(): Promise<SdkInstance> {
|
|
||||||
return this.currentClient.getSdkInstance()
|
|
||||||
}
|
|
||||||
|
|
||||||
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
|
|
||||||
return this.currentClient.getRequestTransformer()
|
|
||||||
}
|
|
||||||
|
|
||||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
|
|
||||||
return this.currentClient.getResponseChunkTransformer(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
|
|
||||||
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
|
||||||
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
|
||||||
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
|
||||||
}
|
|
||||||
|
|
||||||
buildSdkMessages(
|
|
||||||
currentReqMessages: SdkMessageParam[],
|
|
||||||
output: SdkRawOutput | string,
|
|
||||||
toolResults: SdkMessageParam[],
|
|
||||||
toolCalls?: SdkToolCall[]
|
|
||||||
): SdkMessageParam[] {
|
|
||||||
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertMcpToolResponseToSdkMessageParam(
|
|
||||||
mcpToolResponse: MCPToolResponse,
|
|
||||||
resp: MCPCallToolResponse,
|
|
||||||
model: Model
|
|
||||||
): SdkMessageParam | undefined {
|
|
||||||
const client = this.getClient(model)
|
|
||||||
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
|
|
||||||
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
|
|
||||||
}
|
|
||||||
|
|
||||||
estimateMessageTokens(message: SdkMessageParam): number {
|
|
||||||
return this.currentClient.estimateMessageTokens(message)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user