mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-11 16:39:15 +08:00
fix(aws-bedrock): add auto get model list (#9052)
* fix(aws-bedrock): add auto get model list * fix(aws-bedrock): fix type definition
This commit is contained in:
parent
30b7028dd8
commit
bfd2f9d156
@ -88,6 +88,7 @@
|
|||||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||||
"@anthropic-ai/sdk": "^0.41.0",
|
"@anthropic-ai/sdk": "^0.41.0",
|
||||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||||
|
"@aws-sdk/client-bedrock": "^3.840.0",
|
||||||
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
||||||
"@aws-sdk/client-s3": "^3.840.0",
|
"@aws-sdk/client-s3": "^3.840.0",
|
||||||
"@cherrystudio/embedjs": "^0.1.31",
|
"@cherrystudio/embedjs": "^0.1.31",
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import { BedrockClient, ListFoundationModelsCommand, ListInferenceProfilesCommand } from '@aws-sdk/client-bedrock'
|
||||||
import {
|
import {
|
||||||
BedrockRuntimeClient,
|
BedrockRuntimeClient,
|
||||||
ConverseCommand,
|
ConverseCommand,
|
||||||
@ -87,7 +88,15 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
this.sdkInstance = { client, region }
|
const bedrockClient = new BedrockClient({
|
||||||
|
region,
|
||||||
|
credentials: {
|
||||||
|
accessKeyId,
|
||||||
|
secretAccessKey
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
this.sdkInstance = { client, bedrockClient, region }
|
||||||
return this.sdkInstance
|
return this.sdkInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,6 +141,8 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
|||||||
})
|
})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
logger.info('Creating completions with model ID:', { modelId: payload.modelId })
|
||||||
|
|
||||||
const commonParams = {
|
const commonParams = {
|
||||||
modelId: payload.modelId,
|
modelId: payload.modelId,
|
||||||
messages: awsMessages as any,
|
messages: awsMessages as any,
|
||||||
@ -295,9 +306,76 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// @ts-ignore sdk未提供
|
|
||||||
override async listModels(): Promise<SdkModel[]> {
|
override async listModels(): Promise<SdkModel[]> {
|
||||||
return []
|
try {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
|
||||||
|
// 获取支持ON_DEMAND的基础模型列表
|
||||||
|
const modelsCommand = new ListFoundationModelsCommand({
|
||||||
|
byInferenceType: 'ON_DEMAND',
|
||||||
|
byOutputModality: 'TEXT'
|
||||||
|
})
|
||||||
|
const modelsResponse = await sdk.bedrockClient.send(modelsCommand)
|
||||||
|
|
||||||
|
// 获取推理配置文件列表
|
||||||
|
const profilesCommand = new ListInferenceProfilesCommand({})
|
||||||
|
const profilesResponse = await sdk.bedrockClient.send(profilesCommand)
|
||||||
|
|
||||||
|
logger.info('Found ON_DEMAND foundation models:', { count: modelsResponse.modelSummaries?.length || 0 })
|
||||||
|
logger.info('Found inference profiles:', { count: profilesResponse.inferenceProfileSummaries?.length || 0 })
|
||||||
|
|
||||||
|
const models: any[] = []
|
||||||
|
|
||||||
|
// 处理ON_DEMAND基础模型
|
||||||
|
if (modelsResponse.modelSummaries) {
|
||||||
|
for (const model of modelsResponse.modelSummaries) {
|
||||||
|
if (!model.modelId || !model.modelName) continue
|
||||||
|
|
||||||
|
logger.info('Adding ON_DEMAND model', { modelId: model.modelId })
|
||||||
|
models.push({
|
||||||
|
id: model.modelId,
|
||||||
|
name: model.modelName,
|
||||||
|
display_name: model.modelName,
|
||||||
|
description: `${model.providerName || 'AWS'} - ${model.modelName}`,
|
||||||
|
owned_by: model.providerName || 'AWS',
|
||||||
|
provider: this.provider.id,
|
||||||
|
group: 'AWS Bedrock',
|
||||||
|
isInferenceProfile: false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理推理配置文件
|
||||||
|
if (profilesResponse.inferenceProfileSummaries) {
|
||||||
|
for (const profile of profilesResponse.inferenceProfileSummaries) {
|
||||||
|
if (!profile.inferenceProfileArn || !profile.inferenceProfileName) continue
|
||||||
|
|
||||||
|
logger.info('Adding inference profile', {
|
||||||
|
profileArn: profile.inferenceProfileArn,
|
||||||
|
profileName: profile.inferenceProfileName
|
||||||
|
})
|
||||||
|
|
||||||
|
models.push({
|
||||||
|
id: profile.inferenceProfileArn,
|
||||||
|
name: `${profile.inferenceProfileName} (Profile)`,
|
||||||
|
display_name: `${profile.inferenceProfileName} (Profile)`,
|
||||||
|
description: `AWS Inference Profile - ${profile.inferenceProfileName}`,
|
||||||
|
owned_by: 'AWS',
|
||||||
|
provider: this.provider.id,
|
||||||
|
group: 'AWS Bedrock Profiles',
|
||||||
|
isInferenceProfile: true,
|
||||||
|
inferenceProfileId: profile.inferenceProfileId,
|
||||||
|
inferenceProfileArn: profile.inferenceProfileArn
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Total models added to list', { count: models.length })
|
||||||
|
return models
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to list AWS Bedrock models:', error as Error)
|
||||||
|
return []
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async convertMessageToSdkParam(message: Message): Promise<AwsBedrockSdkMessageParam> {
|
public async convertMessageToSdkParam(message: Message): Promise<AwsBedrockSdkMessageParam> {
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import {
|
|||||||
} from '@anthropic-ai/sdk/resources'
|
} from '@anthropic-ai/sdk/resources'
|
||||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||||
|
import type { BedrockClient } from '@aws-sdk/client-bedrock'
|
||||||
import type { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'
|
import type { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'
|
||||||
import {
|
import {
|
||||||
Content,
|
Content,
|
||||||
@ -146,6 +147,7 @@ export interface NewApiModel extends OpenAI.Models.Model {
|
|||||||
*/
|
*/
|
||||||
export interface AwsBedrockSdkInstance {
|
export interface AwsBedrockSdkInstance {
|
||||||
client: BedrockRuntimeClient
|
client: BedrockRuntimeClient
|
||||||
|
bedrockClient: BedrockClient
|
||||||
region: string
|
region: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user