mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-26 03:31:24 +08:00
fix(aws-bedrock): add auto get model list (#9052)
* fix(aws-bedrock): automatically fetch the model list
* fix(aws-bedrock): fix type definitions
This commit is contained in:
parent
30b7028dd8
commit
bfd2f9d156
@ -88,6 +88,7 @@
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||
"@aws-sdk/client-bedrock": "^3.840.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
||||
"@aws-sdk/client-s3": "^3.840.0",
|
||||
"@cherrystudio/embedjs": "^0.1.31",
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesCommand } from '@aws-sdk/client-bedrock'
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
ConverseCommand,
|
||||
@ -87,7 +88,15 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
||||
}
|
||||
})
|
||||
|
||||
this.sdkInstance = { client, region }
|
||||
const bedrockClient = new BedrockClient({
|
||||
region,
|
||||
credentials: {
|
||||
accessKeyId,
|
||||
secretAccessKey
|
||||
}
|
||||
})
|
||||
|
||||
this.sdkInstance = { client, bedrockClient, region }
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
@ -132,6 +141,8 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
||||
})
|
||||
}))
|
||||
|
||||
logger.info('Creating completions with model ID:', { modelId: payload.modelId })
|
||||
|
||||
const commonParams = {
|
||||
modelId: payload.modelId,
|
||||
messages: awsMessages as any,
|
||||
@ -295,9 +306,76 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
||||
}
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
override async listModels(): Promise<SdkModel[]> {
|
||||
return []
|
||||
try {
|
||||
const sdk = await this.getSdkInstance()
|
||||
|
||||
// 获取支持ON_DEMAND的基础模型列表
|
||||
const modelsCommand = new ListFoundationModelsCommand({
|
||||
byInferenceType: 'ON_DEMAND',
|
||||
byOutputModality: 'TEXT'
|
||||
})
|
||||
const modelsResponse = await sdk.bedrockClient.send(modelsCommand)
|
||||
|
||||
// 获取推理配置文件列表
|
||||
const profilesCommand = new ListInferenceProfilesCommand({})
|
||||
const profilesResponse = await sdk.bedrockClient.send(profilesCommand)
|
||||
|
||||
logger.info('Found ON_DEMAND foundation models:', { count: modelsResponse.modelSummaries?.length || 0 })
|
||||
logger.info('Found inference profiles:', { count: profilesResponse.inferenceProfileSummaries?.length || 0 })
|
||||
|
||||
const models: any[] = []
|
||||
|
||||
// 处理ON_DEMAND基础模型
|
||||
if (modelsResponse.modelSummaries) {
|
||||
for (const model of modelsResponse.modelSummaries) {
|
||||
if (!model.modelId || !model.modelName) continue
|
||||
|
||||
logger.info('Adding ON_DEMAND model', { modelId: model.modelId })
|
||||
models.push({
|
||||
id: model.modelId,
|
||||
name: model.modelName,
|
||||
display_name: model.modelName,
|
||||
description: `${model.providerName || 'AWS'} - ${model.modelName}`,
|
||||
owned_by: model.providerName || 'AWS',
|
||||
provider: this.provider.id,
|
||||
group: 'AWS Bedrock',
|
||||
isInferenceProfile: false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 处理推理配置文件
|
||||
if (profilesResponse.inferenceProfileSummaries) {
|
||||
for (const profile of profilesResponse.inferenceProfileSummaries) {
|
||||
if (!profile.inferenceProfileArn || !profile.inferenceProfileName) continue
|
||||
|
||||
logger.info('Adding inference profile', {
|
||||
profileArn: profile.inferenceProfileArn,
|
||||
profileName: profile.inferenceProfileName
|
||||
})
|
||||
|
||||
models.push({
|
||||
id: profile.inferenceProfileArn,
|
||||
name: `${profile.inferenceProfileName} (Profile)`,
|
||||
display_name: `${profile.inferenceProfileName} (Profile)`,
|
||||
description: `AWS Inference Profile - ${profile.inferenceProfileName}`,
|
||||
owned_by: 'AWS',
|
||||
provider: this.provider.id,
|
||||
group: 'AWS Bedrock Profiles',
|
||||
isInferenceProfile: true,
|
||||
inferenceProfileId: profile.inferenceProfileId,
|
||||
inferenceProfileArn: profile.inferenceProfileArn
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.info('Total models added to list', { count: models.length })
|
||||
return models
|
||||
} catch (error) {
|
||||
logger.error('Failed to list AWS Bedrock models:', error as Error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
public async convertMessageToSdkParam(message: Message): Promise<AwsBedrockSdkMessageParam> {
|
||||
|
||||
@ -9,6 +9,7 @@ import {
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import type { BedrockClient } from '@aws-sdk/client-bedrock'
|
||||
import type { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'
|
||||
import {
|
||||
Content,
|
||||
@ -146,6 +147,7 @@ export interface NewApiModel extends OpenAI.Models.Model {
|
||||
*/
|
||||
export interface AwsBedrockSdkInstance {
  /** Runtime-plane client used to invoke models (Converse / streaming calls). */
  client: BedrockRuntimeClient
  /** Control-plane client used to list foundation models and inference profiles. */
  bedrockClient: BedrockClient
  /** AWS region both clients were constructed with. */
  region: string
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user