fix(aws-bedrock): add auto get model list (#9052)

* fix(aws-bedrock): add auto get model list

* fix(aws-bedrock): fix type definition
This commit is contained in:
陈天寒 2025-08-11 16:20:11 +08:00 committed by GitHub
parent 30b7028dd8
commit bfd2f9d156
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 876 additions and 3 deletions

View File

@ -88,6 +88,7 @@
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
"@aws-sdk/client-bedrock": "^3.840.0",
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
"@aws-sdk/client-s3": "^3.840.0",
"@cherrystudio/embedjs": "^0.1.31",

View File

@ -1,3 +1,4 @@
import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesCommand } from '@aws-sdk/client-bedrock'
import {
BedrockRuntimeClient,
ConverseCommand,
@ -87,7 +88,15 @@ export class AwsBedrockAPIClient extends BaseApiClient<
}
})
this.sdkInstance = { client, region }
const bedrockClient = new BedrockClient({
region,
credentials: {
accessKeyId,
secretAccessKey
}
})
this.sdkInstance = { client, bedrockClient, region }
return this.sdkInstance
}
@ -132,6 +141,8 @@ export class AwsBedrockAPIClient extends BaseApiClient<
})
}))
logger.info('Creating completions with model ID:', { modelId: payload.modelId })
const commonParams = {
modelId: payload.modelId,
messages: awsMessages as any,
@ -295,9 +306,76 @@ export class AwsBedrockAPIClient extends BaseApiClient<
}
}
// @ts-ignore sdk未提供
override async listModels(): Promise<SdkModel[]> {
return []
try {
const sdk = await this.getSdkInstance()
// 获取支持ON_DEMAND的基础模型列表
const modelsCommand = new ListFoundationModelsCommand({
byInferenceType: 'ON_DEMAND',
byOutputModality: 'TEXT'
})
const modelsResponse = await sdk.bedrockClient.send(modelsCommand)
// 获取推理配置文件列表
const profilesCommand = new ListInferenceProfilesCommand({})
const profilesResponse = await sdk.bedrockClient.send(profilesCommand)
logger.info('Found ON_DEMAND foundation models:', { count: modelsResponse.modelSummaries?.length || 0 })
logger.info('Found inference profiles:', { count: profilesResponse.inferenceProfileSummaries?.length || 0 })
const models: any[] = []
// 处理ON_DEMAND基础模型
if (modelsResponse.modelSummaries) {
for (const model of modelsResponse.modelSummaries) {
if (!model.modelId || !model.modelName) continue
logger.info('Adding ON_DEMAND model', { modelId: model.modelId })
models.push({
id: model.modelId,
name: model.modelName,
display_name: model.modelName,
description: `${model.providerName || 'AWS'} - ${model.modelName}`,
owned_by: model.providerName || 'AWS',
provider: this.provider.id,
group: 'AWS Bedrock',
isInferenceProfile: false
})
}
}
// 处理推理配置文件
if (profilesResponse.inferenceProfileSummaries) {
for (const profile of profilesResponse.inferenceProfileSummaries) {
if (!profile.inferenceProfileArn || !profile.inferenceProfileName) continue
logger.info('Adding inference profile', {
profileArn: profile.inferenceProfileArn,
profileName: profile.inferenceProfileName
})
models.push({
id: profile.inferenceProfileArn,
name: `${profile.inferenceProfileName} (Profile)`,
display_name: `${profile.inferenceProfileName} (Profile)`,
description: `AWS Inference Profile - ${profile.inferenceProfileName}`,
owned_by: 'AWS',
provider: this.provider.id,
group: 'AWS Bedrock Profiles',
isInferenceProfile: true,
inferenceProfileId: profile.inferenceProfileId,
inferenceProfileArn: profile.inferenceProfileArn
})
}
}
logger.info('Total models added to list', { count: models.length })
return models
} catch (error) {
logger.error('Failed to list AWS Bedrock models:', error as Error)
return []
}
}
public async convertMessageToSdkParam(message: Message): Promise<AwsBedrockSdkMessageParam> {

View File

@ -9,6 +9,7 @@ import {
} from '@anthropic-ai/sdk/resources'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import type { BedrockClient } from '@aws-sdk/client-bedrock'
import type { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'
import {
Content,
@ -146,6 +147,7 @@ export interface NewApiModel extends OpenAI.Models.Model {
*/
export interface AwsBedrockSdkInstance {
client: BedrockRuntimeClient
bedrockClient: BedrockClient
region: string
}

792
yarn.lock

File diff suppressed because it is too large Load Diff