fix(aws-bedrock): support thinking mode (#9172)

* fix(aws-bedrock): support thinking mode

* fix(aws-bedrock): fix code review suggestions

* fix(aws-bedrock): Add thinking processing for other models
This commit is contained in:
陈天寒 2025-08-15 15:13:48 +08:00 committed by GitHub
parent c2561726e0
commit 748ac600fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 426 additions and 48 deletions

View File

@ -5,6 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { AihubmixAPIClient } from '../AihubmixAPIClient'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '../ApiClientFactory'
import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { VertexAPIClient } from '../gemini/VertexAPIClient'
import { NewAPIClient } from '../NewAPIClient'
@ -54,6 +55,19 @@ vi.mock('../openai/OpenAIResponseAPIClient', () => ({
vi.mock('../ppio/PPIOAPIClient', () => ({
PPIOAPIClient: vi.fn().mockImplementation(() => ({}))
}))
// Replace the Bedrock client module with a mock constructor that returns an
// empty object, so the factory test can assert on how the constructor was called.
vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
// Mock the models config to prevent circular dependency issues.
// Only the exports listed here are stubbed: two mock functions and a
// SYSTEM_MODELS table with empty `silicon` and `defaultModel` lists.
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []
}
}))
describe('ApiClientFactory', () => {
beforeEach(() => {
@ -144,6 +158,15 @@ describe('ApiClientFactory', () => {
expect(client).toBeDefined()
})
// The factory must hand the provider straight to the (mocked) Bedrock client constructor.
it('should create AwsBedrockAPIClient for aws-bedrock type', () => {
  const bedrockProvider = createTestProvider('aws-bedrock', 'aws-bedrock')

  const createdClient = ApiClientFactory.create(bedrockProvider)

  expect(AwsBedrockAPIClient).toHaveBeenCalledWith(bedrockProvider)
  expect(createdClient).toBeDefined()
})
// Test the default case
it('should create OpenAIAPIClient as default for unknown type', () => {
const provider = createTestProvider('unknown', 'unknown-type')

View File

@ -2,19 +2,23 @@ import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesComman
import {
BedrockRuntimeClient,
ConverseCommand,
ConverseStreamCommand,
InvokeModelCommand
InvokeModelCommand,
InvokeModelWithResponseStreamCommand
} from '@aws-sdk/client-bedrock-runtime'
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
import {
getAwsBedrockAccessKeyId,
getAwsBedrockRegion,
getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
Assistant,
EFFORT_RATIO,
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
@ -23,7 +27,13 @@ import {
Provider,
ToolCallResponse
} from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
import {
ChunkType,
MCPToolCreatedChunk,
TextDeltaChunk,
ThinkingDeltaChunk,
ThinkingStartChunk
} from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
AwsBedrockSdkInstance,
@ -33,6 +43,7 @@ import {
AwsBedrockSdkRawOutput,
AwsBedrockSdkTool,
AwsBedrockSdkToolCall,
AwsBedrockStreamChunk,
SdkModel
} from '@renderer/types/sdk'
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
@ -103,46 +114,65 @@ export class AwsBedrockAPIClient extends BaseApiClient<
override async createCompletions(payload: AwsBedrockSdkParams): Promise<AwsBedrockSdkRawOutput> {
const sdk = await this.getSdkInstance()
// 转换消息格式到AWS SDK原生格式
// 转换消息格式(用于 InvokeModelWithResponseStreamCommand
const awsMessages = payload.messages.map((msg) => ({
role: msg.role,
content: msg.content.map((content) => {
if (content.text) {
return { text: content.text }
return { type: 'text', text: content.text }
}
if (content.image) {
// 处理图片数据,将 Uint8Array 或数字数组转换为 base64 字符串
let base64Data = ''
if (content.image.source.bytes) {
if (typeof content.image.source.bytes === 'string') {
// 如果已经是字符串,直接使用
base64Data = content.image.source.bytes
} else {
// 如果是数组或 Uint8Array转换为 base64
const uint8Array = new Uint8Array(Object.values(content.image.source.bytes))
const binaryString = Array.from(uint8Array)
.map((byte) => String.fromCharCode(byte))
.join('')
base64Data = btoa(binaryString)
}
}
return {
image: {
format: content.image.format,
source: content.image.source
type: 'image',
source: {
type: 'base64',
media_type: `image/${content.image.format}`,
data: base64Data
}
}
}
if (content.toolResult) {
return {
toolResult: {
toolUseId: content.toolResult.toolUseId,
content: content.toolResult.content,
status: content.toolResult.status
}
type: 'tool_result',
tool_use_id: content.toolResult.toolUseId,
content: content.toolResult.content
}
}
if (content.toolUse) {
return {
toolUse: {
toolUseId: content.toolUse.toolUseId,
name: content.toolUse.name,
input: content.toolUse.input
}
type: 'tool_use',
id: content.toolUse.toolUseId,
name: content.toolUse.name,
input: content.toolUse.input
}
}
// 返回符合AWS SDK ContentBlock类型的对象
return { text: 'Unknown content type' }
return { type: 'text', text: 'Unknown content type' }
})
}))
logger.info('Creating completions with model ID:', { modelId: payload.modelId })
const excludeKeys = ['modelId', 'messages', 'system', 'maxTokens', 'temperature', 'topP', 'stream', 'tools']
const additionalParams = Object.keys(payload)
.filter((key) => !excludeKeys.includes(key))
.reduce((acc, key) => ({ ...acc, [key]: payload[key] }), {})
const commonParams = {
modelId: payload.modelId,
messages: awsMessages as any,
@ -162,10 +192,18 @@ export class AwsBedrockAPIClient extends BaseApiClient<
try {
if (payload.stream) {
const command = new ConverseStreamCommand(commonParams)
// 根据模型类型选择正确的 API 格式
const requestBody = this.createRequestBodyForModel(commonParams, additionalParams)
const command = new InvokeModelWithResponseStreamCommand({
modelId: commonParams.modelId,
body: JSON.stringify(requestBody),
contentType: 'application/json',
accept: 'application/json'
})
const response = await sdk.client.send(command)
// 直接返回AWS Bedrock流式响应的异步迭代器
return this.createStreamIterator(response)
return this.createInvokeModelStreamIterator(response)
} else {
const command = new ConverseCommand(commonParams)
const response = await sdk.client.send(command)
@ -177,32 +215,236 @@ export class AwsBedrockAPIClient extends BaseApiClient<
}
}
private async *createStreamIterator(response: any): AsyncIterable<AwsBedrockSdkRawChunk> {
try {
if (response.stream) {
for await (const chunk of response.stream) {
logger.debug('AWS Bedrock chunk received:', chunk)
/**
 * Builds the JSON request body for a model, choosing the body layout from
 * substrings of the lower-cased model ID (claude, gpt/openai, llama, titan,
 * cohere/command, otherwise a generic prompt layout).
 *
 * In every branch the remaining `additionalParams` are spread last, so they
 * override same-named keys of the generated body.
 *
 * Side effect: thinking-related keys that a branch relocates are removed from
 * the caller's `additionalParams` object with `delete`.
 *
 * @param commonParams - Object providing `modelId`, `messages`, `inferenceConfig`
 *   (`maxTokens`, `temperature`, `topP`) and optionally `system` and `toolConfig.tools`
 * @param additionalParams - Extra request parameters (for example thinking settings)
 * @returns A plain object describing the request body for the matched model family
 */
private createRequestBodyForModel(commonParams: any, additionalParams: any): any {
const modelId = commonParams.modelId.toLowerCase()
// NOTE(review): the next five lines reference `chunk` and `yield`, neither of which
// belongs to this method; they look like removed-side diff residue from the old
// stream iterator — confirm before relying on this listing.
// Convert the AWS Bedrock streaming response format into the standard format
if (chunk.contentBlockDelta?.delta?.text) {
yield {
contentBlockDelta: {
delta: { text: chunk.contentBlockDelta.delta.text }
// Claude-family models use the Anthropic API format
if (modelId.includes('claude')) {
return {
anthropic_version: 'bedrock-2023-05-31',
max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP,
messages: commonParams.messages,
// Only the text of the first system entry is forwarded
...(commonParams.system && commonParams.system[0]?.text ? { system: commonParams.system[0].text } : {}),
...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {}),
...additionalParams
}
}
// OpenAI-family models
if (modelId.includes('gpt') || modelId.includes('openai')) {
const messages: any[] = []
// Add the system message
if (commonParams.system && commonParams.system[0]?.text) {
messages.push({
role: 'system',
content: commonParams.system[0].text
})
}
// Convert the message format
for (const message of commonParams.messages) {
const content: any[] = []
for (const part of message.content) {
if (part.text) {
content.push({ type: 'text', text: part.text })
} else if (part.image) {
// NOTE(review): createCompletions builds image parts as { type: 'image', source: { data } }
// without an `image` key, so this branch may never match — confirm.
content.push({
type: 'image_url',
image_url: {
url: `data:image/${part.image.format};base64,${part.image.source.bytes}`
}
})
}
}
messages.push({
role: message.role,
// A message consisting of a single text part is collapsed to a plain string
content: content.length === 1 && content[0].type === 'text' ? content[0].text : content
})
}
const baseBody: any = {
model: commonParams.modelId,
messages: messages,
max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP,
stream: true,
...(commonParams.toolConfig?.tools ? { tools: commonParams.toolConfig.tools } : {})
}
// Thinking parameter format for OpenAI models
if (additionalParams.reasoning_effort) {
baseBody.reasoning_effort = additionalParams.reasoning_effort
delete additionalParams.reasoning_effort
}
return {
...baseBody,
...additionalParams
}
}
// Llama-family models
if (modelId.includes('llama')) {
const baseBody: any = {
prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
max_gen_len: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP
}
// Thinking parameter format for Llama models
if (additionalParams.thinking_mode) {
baseBody.thinking_mode = additionalParams.thinking_mode
delete additionalParams.thinking_mode
}
return {
...baseBody,
...additionalParams
}
}
// Amazon Titan-family models
if (modelId.includes('titan')) {
const textGenerationConfig: any = {
maxTokenCount: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
topP: commonParams.inferenceConfig.topP
}
// Move thinking-related parameters into textGenerationConfig
if (additionalParams.thinking) {
textGenerationConfig.thinking = additionalParams.thinking
delete additionalParams.thinking
}
return {
inputText: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
textGenerationConfig: {
...textGenerationConfig,
// Relocate `thinking_tokens` / `reasoning_mode` into the config. The keys are deleted
// while reducing, and because this property is evaluated before the trailing
// `...additionalParams` spread, they do not reappear at the top level.
...Object.keys(additionalParams).reduce((acc, key) => {
if (['thinking_tokens', 'reasoning_mode'].includes(key)) {
acc[key] = additionalParams[key]
delete additionalParams[key]
}
return acc
}, {} as any)
},
...additionalParams
}
}
// Cohere Command-family models
if (modelId.includes('cohere') || modelId.includes('command')) {
const baseBody: any = {
message: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
p: commonParams.inferenceConfig.topP
}
// Thinking parameter format for Cohere models
if (additionalParams.thinking) {
baseBody.thinking = additionalParams.thinking
delete additionalParams.thinking
}
if (additionalParams.reasoning_tokens) {
baseBody.reasoning_tokens = additionalParams.reasoning_tokens
delete additionalParams.reasoning_tokens
}
return {
...baseBody,
...additionalParams
}
}
// Fall back to a generic prompt-based format by default
const baseBody: any = {
prompt: this.convertMessagesToPrompt(commonParams.messages, commonParams.system),
max_tokens: commonParams.inferenceConfig.maxTokens,
temperature: commonParams.inferenceConfig.temperature,
top_p: commonParams.inferenceConfig.topP
}
return {
...baseBody,
...additionalParams
}
}
/**
 * Flattens a chat transcript into one plain-text prompt for models that take
 * a single prompt string instead of a structured message list.
 *
 * Layout: an optional `System: …` paragraph, then one `Human: …` or
 * `Assistant: …` paragraph per message, and finally a trailing `Assistant:`
 * cue for the model to continue from.
 *
 * Text parts are concatenated as-is. Image parts are replaced by the
 * placeholder `[Image]`; both the `{ image: … }` shape and the
 * `{ type: 'image', … }` shape produced by `createCompletions` are recognized.
 * Any other part (tool calls, tool results) is omitted.
 *
 * @param messages - Conversation messages, each with a `role` and a `content` array of parts
 * @param system - Optional system blocks; only the `text` of the first entry is used
 * @returns The assembled prompt string
 */
private convertMessagesToPrompt(messages: any[], system?: any[]): string {
  let prompt = ''

  // Prepend the system instruction when one is present
  if (system && system[0]?.text) {
    prompt += `System: ${system[0].text}\n\n`
  }

  // Append every conversation turn; any role other than 'assistant' is rendered as 'Human'
  for (const message of messages) {
    const role = message.role === 'assistant' ? 'Assistant' : 'Human'
    let content = ''

    for (const part of message.content) {
      if (part.text) {
        content += part.text
      } else if (part.image || part.type === 'image') {
        // Images cannot be embedded in a text prompt; keep a visible marker instead.
        // The `type === 'image'` check covers parts already converted by createCompletions.
        content += '[Image]'
      }
    }

    prompt += `${role}: ${content}\n\n`
  }

  prompt += 'Assistant:'
  return prompt
}
private async *createInvokeModelStreamIterator(response: any): AsyncIterable<AwsBedrockSdkRawChunk> {
try {
if (response.body) {
for await (const event of response.body) {
if (event.chunk) {
const chunk: AwsBedrockStreamChunk = JSON.parse(new TextDecoder().decode(event.chunk.bytes))
// 转换为标准格式
if (chunk.type === 'content_block_delta') {
yield {
contentBlockDelta: {
delta: chunk.delta,
contentBlockIndex: chunk.index
}
}
} else if (chunk.type === 'message_start') {
yield { messageStart: chunk }
} else if (chunk.type === 'message_stop') {
yield { messageStop: chunk }
} else if (chunk.type === 'content_block_start') {
yield {
contentBlockStart: {
start: chunk.content_block,
contentBlockIndex: chunk.index
}
}
} else if (chunk.type === 'content_block_stop') {
yield {
contentBlockStop: {
contentBlockIndex: chunk.index
}
}
}
}
if (chunk.messageStart) {
yield { messageStart: chunk.messageStart }
}
if (chunk.messageStop) {
yield { messageStop: chunk.messageStop }
}
if (chunk.metadata) {
yield { metadata: chunk.metadata }
}
}
}
} catch (error) {
@ -485,6 +727,38 @@ export class AwsBedrockAPIClient extends BaseApiClient<
}
}
// 获取推理预算token对所有支持推理的模型
const budgetTokens = this.getBudgetToken(assistant, model)
// 构建基础自定义参数
const customParams: Record<string, any> =
coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {}
// 根据模型类型添加 thinking 参数
if (budgetTokens) {
const modelId = model.id.toLowerCase()
if (modelId.includes('claude')) {
// Claude 模型使用 Anthropic 格式
customParams.thinking = { type: 'enabled', budget_tokens: budgetTokens }
} else if (modelId.includes('gpt') || modelId.includes('openai')) {
// OpenAI 模型格式
customParams.reasoning_effort = assistant?.settings?.reasoning_effort
} else if (modelId.includes('llama')) {
// Llama 模型格式
customParams.thinking_mode = true
customParams.thinking_tokens = budgetTokens
} else if (modelId.includes('titan')) {
// Titan 模型格式
customParams.thinking = { enabled: true }
customParams.thinking_tokens = budgetTokens
} else if (modelId.includes('cohere') || modelId.includes('command')) {
// Cohere 模型格式
customParams.thinking = { enabled: true }
customParams.reasoning_tokens = budgetTokens
}
}
const payload: AwsBedrockSdkParams = {
modelId: model.id,
messages:
@ -497,9 +771,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
topP: this.getTopP(assistant, model),
stream: streamOutput !== false,
tools: tools.length > 0 ? tools : undefined,
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
...customParams
}
const timeout = this.getTimeout(model)
@ -511,6 +783,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
getResponseChunkTransformer(): ResponseChunkTransformer<AwsBedrockSdkRawChunk> {
return () => {
let hasStartedText = false
let hasStartedThinking = false
let accumulatedJson = ''
const toolCalls: Record<number, AwsBedrockSdkToolCall> = {}
@ -570,6 +843,24 @@ export class AwsBedrockAPIClient extends BaseApiClient<
} as TextDeltaChunk)
}
// 处理thinking增量
if (
rawChunk.contentBlockDelta?.delta?.type === 'thinking_delta' &&
rawChunk.contentBlockDelta?.delta?.thinking
) {
if (!hasStartedThinking) {
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
hasStartedThinking = true
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: rawChunk.contentBlockDelta.delta.thinking
} as ThinkingDeltaChunk)
}
// 处理内容块停止事件 - 参考 Anthropic 的 content_block_stop 处理
if (rawChunk.contentBlockStop) {
const blockIndex = rawChunk.contentBlockStop.contentBlockIndex || 0
@ -708,4 +999,49 @@ export class AwsBedrockAPIClient extends BaseApiClient<
extractMessagesFromSdkPayload(sdkPayload: AwsBedrockSdkParams): AwsBedrockSdkMessageParam[] {
return sdkPayload.messages || []
}
/**
 * Computes the reasoning ("thinking") token budget for a request.
 *
 * No budget is produced when the model is not a reasoning model or when the
 * assistant has no `reasoning_effort` setting. Otherwise the budget is the
 * assistant's max-token setting (or `DEFAULT_MAX_TOKENS`) scaled by the effort
 * ratio; when the model has known token limits, it is additionally capped by
 * the effort-scaled position inside the model's [min, max] range. The result
 * is floored and never below 1024.
 *
 * @param assistant - The assistant whose settings supply `maxTokens` and `reasoning_effort`
 * @param model - The model the request targets
 * @returns The budget in tokens, or `undefined` when no budget applies or the calculation throws
 */
private getBudgetToken(assistant: Assistant, model: Model): number | undefined {
  try {
    if (!isReasoningModel(model)) return undefined

    const { maxTokens } = getAssistantSettings(assistant)
    const effort = assistant?.settings?.reasoning_effort
    if (effort === undefined) return undefined

    const ratio = EFFORT_RATIO[effort]
    const limits = findTokenLimit(model.id)

    // Share of the output allowance that the chosen effort level grants
    const scaledAllowance = (maxTokens || DEFAULT_MAX_TOKENS) * ratio

    // With model-specific limits, also cap by the effort-scaled point in [min, max];
    // without them, the scaled allowance is used directly.
    const uncapped = limits ? Math.min((limits.max - limits.min) * ratio + limits.min, scaledAllowance) : scaledAllowance

    return Math.max(1024, Math.floor(uncapped))
  } catch (error) {
    logger.warn('Failed to calculate budget tokens for reasoning effort:', error as Error)
    return undefined
  }
}
}

View File

@ -162,6 +162,7 @@ export interface AwsBedrockSdkParams {
topP?: number
stream?: boolean
tools?: AwsBedrockSdkTool[]
[key: string]: any // Allow any additional custom parameters
}
export interface AwsBedrockSdkMessageParam {
@ -206,6 +207,22 @@ export interface AwsBedrockSdkMessageParam {
}>
}
/**
 * One JSON event decoded from the byte payload of an InvokeModel response stream.
 * The stream iterator in AwsBedrockAPIClient parses each event into this shape
 * and maps it onto an AwsBedrockSdkRawChunk.
 */
export interface AwsBedrockStreamChunk {
// Event discriminator; the iterator handles 'content_block_delta', 'message_start',
// 'message_stop', 'content_block_start' and 'content_block_stop'
type: string
// Incremental payload of a 'content_block_delta' event, forwarded unchanged as contentBlockDelta.delta
delta?: {
text?: string
toolUse?: { input?: string }
// Delta kind; 'thinking_delta' marks reasoning output
type?: string
// Reasoning text carried by a 'thinking_delta'
thinking?: string
}
// Content block position, forwarded as contentBlockIndex
index?: number
// Block descriptor of a 'content_block_start' event, forwarded as contentBlockStart.start
content_block?: any
// NOTE(review): not read by the stream iterator; Anthropic's native events report
// snake_case usage keys (input_tokens / output_tokens) — confirm these names.
usage?: {
inputTokens?: number
outputTokens?: number
}
}
export interface AwsBedrockSdkRawChunk {
contentBlockStart?: {
start?: {
@ -222,6 +239,8 @@ export interface AwsBedrockSdkRawChunk {
toolUse?: {
input?: string
}
type?: string // 支持 'thinking_delta' 等类型
thinking?: string // 支持 thinking 内容
}
contentBlockIndex?: number
}