mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 15:49:29 +08:00
fix: Add AWS Bedrock reasoning extraction middleware (#10231)
* Add AWS Bedrock reasoning extraction middleware - Add 'reasoning' tag to tagNameArray for broader reasoning support - Add AWS Bedrock case with gpt-oss model-specific reasoning extraction - Add openai-chat and openrouter cases to provider options switch - Remove unused zod import * Add OpenRouter provider support Updates ai-core to version alpha.18 with OpenRouter integration and improves provider ID resolution for OpenAI API hosts.
This commit is contained in:
parent
6c9fc598d4
commit
89d5bd817b
@ -108,7 +108,7 @@
|
|||||||
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
||||||
"@aws-sdk/client-s3": "^3.840.0",
|
"@aws-sdk/client-s3": "^3.840.0",
|
||||||
"@biomejs/biome": "2.2.4",
|
"@biomejs/biome": "2.2.4",
|
||||||
"@cherrystudio/ai-core": "workspace:^1.0.0-alpha.17",
|
"@cherrystudio/ai-core": "workspace:^1.0.0-alpha.18",
|
||||||
"@cherrystudio/embedjs": "^0.1.31",
|
"@cherrystudio/embedjs": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@cherrystudio/ai-core",
|
"name": "@cherrystudio/ai-core",
|
||||||
"version": "1.0.0-alpha.17",
|
"version": "1.0.0-alpha.18",
|
||||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"module": "dist/index.mjs",
|
"module": "dist/index.mjs",
|
||||||
|
|||||||
@ -9,7 +9,9 @@ import { createDeepSeek } from '@ai-sdk/deepseek'
|
|||||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||||
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||||
|
import { LanguageModelV2 } from '@ai-sdk/provider'
|
||||||
import { createXai } from '@ai-sdk/xai'
|
import { createXai } from '@ai-sdk/xai'
|
||||||
|
import { createOpenRouter } from '@openrouter/ai-sdk-provider'
|
||||||
import { customProvider, Provider } from 'ai'
|
import { customProvider, Provider } from 'ai'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
@ -46,7 +48,7 @@ export const isBaseProvider = (id: ProviderId): id is BaseProviderId => {
|
|||||||
type BaseProvider = {
|
type BaseProvider = {
|
||||||
id: BaseProviderId
|
id: BaseProviderId
|
||||||
name: string
|
name: string
|
||||||
creator: (options: any) => Provider
|
creator: (options: any) => Provider | LanguageModelV2
|
||||||
supportsImageGeneration: boolean
|
supportsImageGeneration: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,6 +126,12 @@ export const baseProviders = [
|
|||||||
name: 'DeepSeek',
|
name: 'DeepSeek',
|
||||||
creator: createDeepSeek,
|
creator: createDeepSeek,
|
||||||
supportsImageGeneration: false
|
supportsImageGeneration: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'openrouter',
|
||||||
|
name: 'OpenRouter',
|
||||||
|
creator: createOpenRouter,
|
||||||
|
supportsImageGeneration: true
|
||||||
}
|
}
|
||||||
] as const satisfies BaseProvider[]
|
] as const satisfies BaseProvider[]
|
||||||
|
|
||||||
|
|||||||
@ -140,7 +140,7 @@ export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageMo
|
|||||||
return builder.build()
|
return builder.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
const tagNameArray = ['think', 'thought']
|
const tagNameArray = ['think', 'thought', 'reasoning']
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 添加provider特定的中间件
|
* 添加provider特定的中间件
|
||||||
@ -167,6 +167,16 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
|
|||||||
case 'gemini':
|
case 'gemini':
|
||||||
// Gemini特定中间件
|
// Gemini特定中间件
|
||||||
break
|
break
|
||||||
|
case 'aws-bedrock': {
|
||||||
|
if (config.model?.id.includes('gpt-oss')) {
|
||||||
|
const tagName = tagNameArray[2]
|
||||||
|
builder.add({
|
||||||
|
name: 'thinking-tag-extraction',
|
||||||
|
middleware: extractReasoningMiddleware({ tagName })
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// 其他provider的通用处理
|
// 其他provider的通用处理
|
||||||
break
|
break
|
||||||
|
|||||||
@ -69,6 +69,9 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
|||||||
return resolvedFromType
|
return resolvedFromType
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (provider.apiHost.includes('api.openai.com')) {
|
||||||
|
return 'openai-chat'
|
||||||
|
}
|
||||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
// 3. 最后的fallback(通常会成为openai-compatible)
|
||||||
return provider.id as ProviderId
|
return provider.id as ProviderId
|
||||||
}
|
}
|
||||||
|
|||||||
@ -82,6 +82,7 @@ export function buildProviderOptions(
|
|||||||
// 应该覆盖所有类型
|
// 应该覆盖所有类型
|
||||||
switch (baseProviderId) {
|
switch (baseProviderId) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
|
case 'openai-chat':
|
||||||
case 'azure':
|
case 'azure':
|
||||||
providerSpecificOptions = {
|
providerSpecificOptions = {
|
||||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
||||||
@ -101,13 +102,15 @@ export function buildProviderOptions(
|
|||||||
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
|
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
|
||||||
break
|
break
|
||||||
case 'deepseek':
|
case 'deepseek':
|
||||||
case 'openai-compatible':
|
case 'openrouter':
|
||||||
|
case 'openai-compatible': {
|
||||||
// 对于其他 provider,使用通用的构建逻辑
|
// 对于其他 provider,使用通用的构建逻辑
|
||||||
providerSpecificOptions = {
|
providerSpecificOptions = {
|
||||||
...buildGenericProviderOptions(assistant, model, capabilities),
|
...buildGenericProviderOptions(assistant, model, capabilities),
|
||||||
serviceTier: serviceTierSetting
|
serviceTier: serviceTierSetting
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unsupported base provider ${baseProviderId}`)
|
throw new Error(`Unsupported base provider ${baseProviderId}`)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2309,7 +2309,7 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"@cherrystudio/ai-core@workspace:^1.0.0-alpha.17, @cherrystudio/ai-core@workspace:packages/aiCore":
|
"@cherrystudio/ai-core@workspace:^1.0.0-alpha.18, @cherrystudio/ai-core@workspace:packages/aiCore":
|
||||||
version: 0.0.0-use.local
|
version: 0.0.0-use.local
|
||||||
resolution: "@cherrystudio/ai-core@workspace:packages/aiCore"
|
resolution: "@cherrystudio/ai-core@workspace:packages/aiCore"
|
||||||
dependencies:
|
dependencies:
|
||||||
@ -13195,7 +13195,7 @@ __metadata:
|
|||||||
"@aws-sdk/client-bedrock-runtime": "npm:^3.840.0"
|
"@aws-sdk/client-bedrock-runtime": "npm:^3.840.0"
|
||||||
"@aws-sdk/client-s3": "npm:^3.840.0"
|
"@aws-sdk/client-s3": "npm:^3.840.0"
|
||||||
"@biomejs/biome": "npm:2.2.4"
|
"@biomejs/biome": "npm:2.2.4"
|
||||||
"@cherrystudio/ai-core": "workspace:^1.0.0-alpha.17"
|
"@cherrystudio/ai-core": "workspace:^1.0.0-alpha.18"
|
||||||
"@cherrystudio/embedjs": "npm:^0.1.31"
|
"@cherrystudio/embedjs": "npm:^0.1.31"
|
||||||
"@cherrystudio/embedjs-libsql": "npm:^0.1.31"
|
"@cherrystudio/embedjs-libsql": "npm:^0.1.31"
|
||||||
"@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"
|
"@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user