diff --git a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts index 172798dc38..5ec3bf6404 100644 --- a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts +++ b/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts @@ -5,6 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { AihubmixAPIClient } from '../AihubmixAPIClient' import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient' import { ApiClientFactory } from '../ApiClientFactory' +import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient' import { GeminiAPIClient } from '../gemini/GeminiAPIClient' import { VertexAPIClient } from '../gemini/VertexAPIClient' import { NewAPIClient } from '../NewAPIClient' @@ -54,6 +55,19 @@ vi.mock('../openai/OpenAIResponseAPIClient', () => ({ vi.mock('../ppio/PPIOAPIClient', () => ({ PPIOAPIClient: vi.fn().mockImplementation(() => ({})) })) +vi.mock('../aws/AwsBedrockAPIClient', () => ({ + AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({})) +})) + +// Mock the models config to prevent circular dependency issues +vi.mock('@renderer/config/models', () => ({ + findTokenLimit: vi.fn(), + isReasoningModel: vi.fn(), + SYSTEM_MODELS: { + silicon: [], + defaultModel: [] + } +})) describe('ApiClientFactory', () => { beforeEach(() => { @@ -144,6 +158,15 @@ describe('ApiClientFactory', () => { expect(client).toBeDefined() }) + it('should create AwsBedrockAPIClient for aws-bedrock type', () => { + const provider = createTestProvider('aws-bedrock', 'aws-bedrock') + + const client = ApiClientFactory.create(provider) + + expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + // 测试默认情况 it('should create OpenAIAPIClient as default for unknown type', () => { const provider = createTestProvider('unknown', 'unknown-type') diff --git a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts index 4e29a0cb5c..d9bd9af9c8 100644 --- a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts +++ b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts @@ -2,19 +2,23 @@ import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesComman import { BedrockRuntimeClient, ConverseCommand, - ConverseStreamCommand, - InvokeModelCommand + InvokeModelCommand, + InvokeModelWithResponseStreamCommand } from '@aws-sdk/client-bedrock-runtime' import { loggerService } from '@logger' import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import { findTokenLimit, isReasoningModel } from '@renderer/config/models' import { getAwsBedrockAccessKeyId, getAwsBedrockRegion, getAwsBedrockSecretAccessKey } from '@renderer/hooks/useAwsBedrock' +import { getAssistantSettings } from '@renderer/services/AssistantService' import { estimateTextTokens } from '@renderer/services/TokenService' import { + Assistant, + EFFORT_RATIO, GenerateImageParams, MCPCallToolResponse, MCPTool, @@ -23,7 +27,13 @@ import { Provider, ToolCallResponse } from '@renderer/types' -import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk' +import { + ChunkType, + MCPToolCreatedChunk, + TextDeltaChunk, + ThinkingDeltaChunk, + ThinkingStartChunk +} from '@renderer/types/chunk' import { Message } from '@renderer/types/newMessage' import { AwsBedrockSdkInstance, @@ -33,6 +43,7 @@ import { AwsBedrockSdkRawOutput, AwsBedrockSdkTool, AwsBedrockSdkToolCall, + AwsBedrockStreamChunk, SdkModel } from '@renderer/types/sdk' import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils' @@ -103,46 +114,65 @@ export class AwsBedrockAPIClient extends BaseApiClient< override async createCompletions(payload: AwsBedrockSdkParams): Promise { const sdk = await this.getSdkInstance() - // 转换消息格式到AWS SDK原生格式 + // 转换消息格式(用于 InvokeModelWithResponseStreamCommand) const awsMessages = payload.messages.map((msg) => ({ role: msg.role, content: msg.content.map((content) => { if (content.text) { - return { text: content.text } + return { type: 'text', text: content.text } } if (content.image) { + // 处理图片数据,将 Uint8Array 或数字数组转换为 base64 字符串 + let base64Data = '' + if (content.image.source.bytes) { + if (typeof content.image.source.bytes === 'string') { + // 如果已经是字符串,直接使用 + base64Data = content.image.source.bytes + } else { + // 如果是数组或 Uint8Array,转换为 base64 + const uint8Array = new Uint8Array(Object.values(content.image.source.bytes)) + const binaryString = Array.from(uint8Array) + .map((byte) => String.fromCharCode(byte)) + .join('') + base64Data = btoa(binaryString) + } + } + return { - image: { - format: content.image.format, - source: content.image.source + type: 'image', + source: { + type: 'base64', + media_type: `image/${content.image.format}`, + data: base64Data } } } if (content.toolResult) { return { - toolResult: { - toolUseId: content.toolResult.toolUseId, - content: content.toolResult.content, - status: content.toolResult.status - } + type: 'tool_result', + tool_use_id: content.toolResult.toolUseId, + content: content.toolResult.content } } if (content.toolUse) { return { - toolUse: { - toolUseId: content.toolUse.toolUseId, - name: content.toolUse.name, - input: content.toolUse.input - } + type: 'tool_use', + id: content.toolUse.toolUseId, + name: content.toolUse.name, + input: content.toolUse.input } } - // 返回符合AWS SDK ContentBlock类型的对象 - return { text: 'Unknown content type' } + return { type: 'text', text: 'Unknown content type' } }) })) logger.info('Creating completions with model ID:', { modelId: payload.modelId }) + const excludeKeys = ['modelId', 'messages', 'system', 'maxTokens', 'temperature', 'topP', 'stream', 'tools'] + const additionalParams = Object.keys(payload) + .filter((key) => !excludeKeys.includes(key)) + .reduce((acc, key) => ({ ...acc, [key]: payload[key] }), {}) + const commonParams = { modelId: payload.modelId, messages: awsMessages as any, @@ -162,10 +192,18 @@ export class AwsBedrockAPIClient extends BaseApiClient< try { if (payload.stream) { - const command = new ConverseStreamCommand(commonParams) + // 根据模型类型选择正确的 API 格式 + const requestBody = this.createRequestBodyForModel(commonParams, additionalParams) + + const command = new InvokeModelWithResponseStreamCommand({ + modelId: commonParams.modelId, + body: JSON.stringify(requestBody), + contentType: 'application/json', + accept: 'application/json' + }) + const response = await sdk.client.send(command) - // 直接返回AWS Bedrock流式响应的异步迭代器 - return this.createStreamIterator(response) + return this.createInvokeModelStreamIterator(response) } else { const command = new ConverseCommand(commonParams) const response = await sdk.client.send(command) @@ -177,32 +215,236 @@ export class AwsBedrockAPIClient extends BaseApiClient< } } - private async *createStreamIterator(response: any): AsyncIterable { - try { - if (response.stream) { - for await (const chunk of response.stream) { - logger.debug('AWS Bedrock chunk received:', chunk) + /** + * 根据模型类型创建请求体 + */ + private createRequestBodyForModel(commonParams: any, additionalParams: any): any { + const modelId = commonParams.modelId.toLowerCase() - // AWS Bedrock的流式响应格式转换为标准格式 - if (chunk.contentBlockDelta?.delta?.text) { - yield { - contentBlockDelta: { - delta: { text: chunk.contentBlockDelta.delta.text } + // Claude 系列模型使用 Anthropic API 格式 + if (modelId.includes('claude')) { + return { + anthropic_version: 'bedrock-2023-05-31', + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP, + messages: commonParams.messages, + ...(commonParams.system && commonParams.system[0]?.text ? { system: commonParams.system[0].text } : {}), + ...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {}), + ...additionalParams + } + } + + // OpenAI 系列模型 + if (modelId.includes('gpt') || modelId.includes('openai')) { + const messages: any[] = [] + + // 添加系统消息 + if (commonParams.system && commonParams.system[0]?.text) { + messages.push({ + role: 'system', + content: commonParams.system[0].text + }) + } + + // 转换消息格式 + for (const message of commonParams.messages) { + const content: any[] = [] + for (const part of message.content) { + if (part.text) { + content.push({ type: 'text', text: part.text }) + } else if (part.image) { + content.push({ + type: 'image_url', + image_url: { + url: `data:image/${part.image.format};base64,${part.image.source.bytes}` + } + }) + } + } + messages.push({ + role: message.role, + content: content.length === 1 && content[0].type === 'text' ? content[0].text : content + }) + } + + const baseBody: any = { + model: commonParams.modelId, + messages: messages, + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP, + stream: true, + ...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {}) + } + + // OpenAI 模型的 thinking 参数格式 + if (additionalParams.reasoning_effort) { + baseBody.reasoning_effort = additionalParams.reasoning_effort + delete additionalParams.reasoning_effort + } + + return { + ...baseBody, + ...additionalParams + } + } + + // Llama 系列模型 + if (modelId.includes('llama')) { + const baseBody: any = { + prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + max_gen_len: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP + } + + // Llama 模型的 thinking 参数格式 + if (additionalParams.thinking_mode) { + baseBody.thinking_mode = additionalParams.thinking_mode + delete additionalParams.thinking_mode + } + + return { + ...baseBody, + ...additionalParams + } + } + + // Amazon Titan 系列模型 + if (modelId.includes('titan')) { + const textGenerationConfig: any = { + maxTokenCount: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + topP: commonParams.inferenceConfig.topP + } + + // 将 thinking 相关参数添加到 textGenerationConfig 中 + if (additionalParams.thinking) { + textGenerationConfig.thinking = additionalParams.thinking + delete additionalParams.thinking + } + + return { + inputText: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + textGenerationConfig: { + ...textGenerationConfig, + ...Object.keys(additionalParams).reduce((acc, key) => { + if (['thinking_tokens', 'reasoning_mode'].includes(key)) { + acc[key] = additionalParams[key] + delete additionalParams[key] + } + return acc + }, {} as any) + }, + ...additionalParams + } + } + + // Cohere Command 系列模型 + if (modelId.includes('cohere') || modelId.includes('command')) { + const baseBody: any = { + message: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + p: commonParams.inferenceConfig.topP + } + + // Cohere 模型的 thinking 参数格式 + if (additionalParams.thinking) { + baseBody.thinking = additionalParams.thinking + delete additionalParams.thinking + } + if (additionalParams.reasoning_tokens) { + baseBody.reasoning_tokens = additionalParams.reasoning_tokens + delete additionalParams.reasoning_tokens + } + + return { + ...baseBody, + ...additionalParams + } + } + + // 默认使用通用格式 + const baseBody: any = { + prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system), + max_tokens: commonParams.inferenceConfig.maxTokens, + temperature: commonParams.inferenceConfig.temperature, + top_p: commonParams.inferenceConfig.topP + } + + return { + ...baseBody, + ...additionalParams + } + } + + /** + * 将消息转换为简单的 prompt 格式 + */ + private convertMessagesToPrompt(messages: any[], system?: any[]): string { + let prompt = '' + + // 添加系统消息 + if (system && system[0]?.text) { + prompt += `System: ${system[0].text}\n\n` + } + + // 添加对话消息 + for (const message of messages) { + const role = message.role === 'assistant' ? 'Assistant' : 'Human' + let content = '' + + for (const part of message.content) { + if (part.text) { + content += part.text + } else if (part.image) { + content += '[Image]' + } + } + + prompt += `${role}: ${content}\n\n` + } + + prompt += 'Assistant:' + return prompt + } + + private async *createInvokeModelStreamIterator(response: any): AsyncIterable { + try { + if (response.body) { + for await (const event of response.body) { + if (event.chunk) { + const chunk: AwsBedrockStreamChunk = JSON.parse(new TextDecoder().decode(event.chunk.bytes)) + + // 转换为标准格式 + if (chunk.type === 'content_block_delta') { + yield { + contentBlockDelta: { + delta: chunk.delta, + contentBlockIndex: chunk.index + } + } + } else if (chunk.type === 'message_start') { + yield { messageStart: chunk } + } else if (chunk.type === 'message_stop') { + yield { messageStop: chunk } + } else if (chunk.type === 'content_block_start') { + yield { + contentBlockStart: { + start: chunk.content_block, + contentBlockIndex: chunk.index + } + } + } else if (chunk.type === 'content_block_stop') { + yield { + contentBlockStop: { + contentBlockIndex: chunk.index + } } } } - - if (chunk.messageStart) { - yield { messageStart: chunk.messageStart } - } - - if (chunk.messageStop) { - yield { messageStop: chunk.messageStop } - } - - if (chunk.metadata) { - yield { metadata: chunk.metadata } - } } } } catch (error) { @@ -485,6 +727,38 @@ export class AwsBedrockAPIClient extends BaseApiClient< } } + // 获取推理预算token(对所有支持推理的模型) + const budgetTokens = this.getBudgetToken(assistant, model) + + // 构建基础自定义参数 + const customParams: Record = + coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {} + + // 根据模型类型添加 thinking 参数 + if (budgetTokens) { + const modelId = model.id.toLowerCase() + + if (modelId.includes('claude')) { + // Claude 模型使用 Anthropic 格式 + customParams.thinking = { type: 'enabled', budget_tokens: budgetTokens } + } else if (modelId.includes('gpt') || modelId.includes('openai')) { + // OpenAI 模型格式 + customParams.reasoning_effort = assistant?.settings?.reasoning_effort + } else if (modelId.includes('llama')) { + // Llama 模型格式 + customParams.thinking_mode = true + customParams.thinking_tokens = budgetTokens + } else if (modelId.includes('titan')) { + // Titan 模型格式 + customParams.thinking = { enabled: true } + customParams.thinking_tokens = budgetTokens + } else if (modelId.includes('cohere') || modelId.includes('command')) { + // Cohere 模型格式 + customParams.thinking = { enabled: true } + customParams.reasoning_tokens = budgetTokens + } + } + const payload: AwsBedrockSdkParams = { modelId: model.id, messages: @@ -497,9 +771,7 @@ export class AwsBedrockAPIClient extends BaseApiClient< topP: this.getTopP(assistant, model), stream: streamOutput !== false, tools: tools.length > 0 ? tools : undefined, - // 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑 - // 注意:用户自定义参数总是应该覆盖其他参数 - ...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}) + ...customParams } const timeout = this.getTimeout(model) @@ -511,6 +783,7 @@ export class AwsBedrockAPIClient extends BaseApiClient< getResponseChunkTransformer(): ResponseChunkTransformer { return () => { let hasStartedText = false + let hasStartedThinking = false let accumulatedJson = '' const toolCalls: Record = {} @@ -570,6 +843,24 @@ export class AwsBedrockAPIClient extends BaseApiClient< } as TextDeltaChunk) } + // 处理thinking增量 + if ( + rawChunk.contentBlockDelta?.delta?.type === 'thinking_delta' && + rawChunk.contentBlockDelta?.delta?.thinking + ) { + if (!hasStartedThinking) { + controller.enqueue({ + type: ChunkType.THINKING_START + } as ThinkingStartChunk) + hasStartedThinking = true + } + + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: rawChunk.contentBlockDelta.delta.thinking + } as ThinkingDeltaChunk) + } + // 处理内容块停止事件 - 参考 Anthropic 的 content_block_stop 处理 if (rawChunk.contentBlockStop) { const blockIndex = rawChunk.contentBlockStop.contentBlockIndex || 0 @@ -708,4 +999,49 @@ export class AwsBedrockAPIClient extends BaseApiClient< extractMessagesFromSdkPayload(sdkPayload: AwsBedrockSdkParams): AwsBedrockSdkMessageParam[] { return sdkPayload.messages || [] } + + /** + * 获取 AWS Bedrock 的推理工作量预算token + * @param assistant - The assistant + * @param model - The model + * @returns The budget tokens for reasoning effort + */ + private getBudgetToken(assistant: Assistant, model: Model): number | undefined { + try { + if (!isReasoningModel(model)) { + return undefined + } + + const { maxTokens } = getAssistantSettings(assistant) + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (reasoningEffort === undefined) { + return undefined + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + const tokenLimits = findTokenLimit(model.id) + + if (tokenLimits) { + // 使用模型特定的 token 限制 + const budgetTokens = Math.max( + 1024, + Math.floor( + Math.min( + (tokenLimits.max - tokenLimits.min) * effortRatio + tokenLimits.min, + (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio + ) + ) + ) + return budgetTokens + } else { + // 对于没有特定限制的模型,使用简化计算 + const budgetTokens = Math.max(1024, Math.floor((maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) + return budgetTokens + } + } catch (error) { + logger.warn('Failed to calculate budget tokens for reasoning effort:', error as Error) + return undefined + } + } } diff --git a/src/renderer/src/types/sdk.ts b/src/renderer/src/types/sdk.ts index a5413e54fb..36608ab9fe 100644 --- a/src/renderer/src/types/sdk.ts +++ b/src/renderer/src/types/sdk.ts @@ -162,6 +162,7 @@ export interface AwsBedrockSdkParams { topP?: number stream?: boolean tools?: AwsBedrockSdkTool[] + [key: string]: any // Allow any additional custom parameters } export interface AwsBedrockSdkMessageParam { @@ -206,6 +207,22 @@ export interface AwsBedrockSdkMessageParam { }> } +export interface AwsBedrockStreamChunk { + type: string + delta?: { + text?: string + toolUse?: { input?: string } + type?: string + thinking?: string + } + index?: number + content_block?: any + usage?: { + inputTokens?: number + outputTokens?: number + } +} + export interface AwsBedrockSdkRawChunk { contentBlockStart?: { start?: { @@ -222,6 +239,8 @@ export interface AwsBedrockSdkRawChunk { toolUse?: { input?: string } + type?: string // 支持 'thinking_delta' 等类型 + thinking?: string // 支持 thinking 内容 } contentBlockIndex?: number }