Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package

This commit is contained in:
suyao 2025-08-06 17:01:57 +08:00
commit abfec7a228
No known key found for this signature in database
51 changed files with 400 additions and 384 deletions

View File

@@ -1,187 +1,16 @@
import { loggerService } from '@logger'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { withSpanResult } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
/**
 * Cherry Studio AI Core - entry point.
 *
 * Re-exports the legacy AiProvider to keep existing code compatible.
 */
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
import { NewAPIClient } from './clients/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
// Export the legacy AiProvider as the default export for backward compatibility
export { default } from './legacy/index'
const logger = loggerService.withContext('AiProvider')
// Also export the modern AiProvider for new code to use
export { default as ModernAiProvider } from './index_new'
export default class AiProvider {
private apiClient: BaseApiClient
constructor(provider: Provider) {
// Use the new ApiClientFactory to get a BaseApiClient instance
this.apiClient = ApiClientFactory.create(provider)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// 1. Identify the correct client based on the model
const model = params.assistant.model
if (!model) {
return Promise.reject(new Error('Model is required'))
}
// Choose the appropriate handling based on the client type
let client: BaseApiClient
if (this.apiClient instanceof AihubmixAPIClient) {
// AihubmixAPIClient: select the appropriate sub-client based on the model
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof NewAPIClient) {
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: select the API type based on model characteristics
client = this.apiClient.getClient(model) as BaseApiClient
} else if (this.apiClient instanceof VertexAPIClient) {
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// Other clients are used directly
client = this.apiClient
}
// 2. Build the middleware chain
const builder = CompletionsMiddlewareBuilder.withDefaults()
// images api
if (isDedicatedImageGenerationModel(model)) {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
// Existing logic for other models
logger.silly('Builder Params', params)
// Use compatibility-type checks to avoid problems with TypeScript type narrowing and the decorator pattern
const clientTypes = client.getClientCompatibilityType(model)
const isOpenAICompatible =
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
if (!isOpenAICompatible) {
logger.silly('ThinkingTagExtractionMiddleware is removed')
builder.remove(ThinkingTagExtractionMiddlewareName)
}
const isAnthropicOrOpenAIResponseCompatible =
clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
if (!isAnthropicOrOpenAIResponseCompatible) {
logger.silly('RawStreamListenerMiddleware is removed')
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
logger.silly('WebSearchMiddleware is removed')
builder.remove(WebSearchMiddlewareName)
}
if (!params.mcpTools?.length) {
builder.remove(ToolUseExtractionMiddlewareName)
logger.silly('ToolUseExtractionMiddleware is removed')
builder.remove(McpToolChunkMiddlewareName)
logger.silly('McpToolChunkMiddleware is removed')
}
if (!isPromptToolUse(params.assistant)) {
builder.remove(ToolUseExtractionMiddlewareName)
logger.silly('ToolUseExtractionMiddleware is removed')
}
if (params.callType !== 'chat') {
logger.silly('AbortHandlerMiddleware is removed')
builder.remove(AbortHandlerMiddlewareName)
}
if (params.callType === 'test') {
builder.remove(ErrorHandlerMiddlewareName)
logger.silly('ErrorHandlerMiddleware is removed')
builder.remove(FinalChunkConsumerMiddlewareName)
logger.silly('FinalChunkConsumerMiddleware is removed')
builder.insertBefore(ThinkChunkMiddlewareName, MiddlewareRegistry[ThinkingTagExtractionMiddlewareName])
logger.silly('ThinkingTagExtractionMiddleware is inserted')
}
}
const middlewares = builder.build()
logger.silly('middlewares', middlewares)
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
// 4. Execute the wrapped method with the original params
const result = wrappedCompletionMethod(params, options)
return result
}
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
const traceName = params.assistant.model?.name
? `${params.assistant.model?.name}.${params.callType}`
: `LLM.${params.callType}`
const traceParams: StartSpanParams = {
name: traceName,
tag: 'LLM',
topicId: params.topicId || '',
modelName: params.assistant.model?.name
}
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
}
public async models(): Promise<SdkModel[]> {
return this.apiClient.listModels()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
}
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {
logger.error('Error getting embedding dimensions:', error as Error)
throw error
}
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
if (this.apiClient instanceof AihubmixAPIClient) {
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
return client.generateImage(params)
}
return this.apiClient.generateImage(params)
}
public getBaseURL(): string {
return this.apiClient.getBaseURL()
}
public getApiKey(): string {
return this.apiClient.getApiKey()
}
}
// Export some commonly used types and utilities
export * from './legacy/clients/types'
export * from './legacy/middleware/schemas'
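For orientation, a minimal usage sketch of this facade (hypothetical values; not shown in this diff — it only assumes what completions() reads above, including that the params carry an assistant with a model):
import AiProvider from '@renderer/aiCore'
import type { Provider } from '@renderer/types'
import type { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas'

async function runCompletion(provider: Provider, params: CompletionsParams) {
  const ai = new AiProvider(provider)
  // completions() rejects if params.assistant.model is missing, picks the
  // concrete client (e.g. a sub-client for Aihubmix/NewAPI), trims the default
  // middleware chain based on params (web search, tool use, abort handling),
  // then executes the wrapped createCompletions call.
  return ai.completions(params)
}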

View File

@@ -8,136 +8,17 @@
* 3.
*/
import {
AiCore,
AiPlugin,
createExecutor,
generateImage,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap,
StreamTextParams
} from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { isDedicatedImageGenerationModel, isNotSupportedImageSizeModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
import { searchOrchestrationPlugin } from './plugins/searchOrchestrationPlugin'
import { createAihubmixProvider } from './provider/aihubmix'
import { getAiSdkProviderId } from './provider/factory'
function getActualProvider(model: Model): Provider {
const provider = getProviderByModel(model)
// If the type is vertexai and there are no googleCredentials, convert it to a VertexProvider
let actualProvider = cloneDeep(provider)
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
actualProvider = createVertexProvider(provider)
}
if (provider.id === 'aihubmix') {
actualProvider = createAihubmixProvider(model, actualProvider)
}
if (actualProvider.type === 'gemini') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
} else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
}
return actualProvider
}
/**
 * Convert a Provider into an AI SDK provider config
 */
function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
// console.log('actualProvider', actualProvider)
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// console.log('aiSdkProviderId', aiSdkProviderId)
// If the provider is openai, use strict mode and default to the responses API
const actualProviderType = actualProvider.type
const openaiResponseOptions =
// Providers that are actually openai go through the responses API; aiCore checks internally whether the model supports it
actualProviderType === 'openai-response'
? {
mode: 'responses'
}
: aiSdkProviderId === 'openai'
? {
mode: 'chat'
}
: undefined
console.log('openaiResponseOptions', openaiResponseOptions)
console.log('actualProvider', actualProvider)
console.log('aiSdkProviderId', aiSdkProviderId)
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(
aiSdkProviderId,
{
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
},
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
)
return {
providerId: aiSdkProviderId as ProviderId,
options
}
} else {
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id
}
}
}
}
/**
 * Whether the modern AI SDK can be used for this provider/model
 */
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// Currently only the major providers are supported
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// Check the provider type
if (!supportedProviders.includes(provider.type)) {
return false
}
// For vertexai, check that the configuration is complete
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
return false
}
// Image generation models are now supported by the new AI SDK
// (but the provider itself must be one of the supported ones)
if (model && isDedicatedImageGenerationModel(model)) {
return true
}
return true
}
import LegacyAiProvider from './legacy/index'
import { CompletionsResult } from './legacy/middleware/schemas'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
@@ -156,62 +37,6 @@ export default class ModernAiProvider {
return this.actualProvider
}
/**
 * Build the plugin list from the middleware config
 */
private buildPlugins(middlewareConfig: AiSdkMiddlewareConfig) {
const plugins: AiPlugin[] = []
// 1. Always add the common plugins
// plugins.push(textPlugin)
if (middlewareConfig.enableWebSearch) {
// Default search parameters are built in; pass a config here to override them
plugins.push(webSearchPlugin())
}
// 2. Add the search orchestration plugin when tool calling is supported
if (middlewareConfig.isSupportedToolUse) {
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
}
// 3. Add the reasoning-time plugin for reasoning models
if (middlewareConfig.enableReasoning) {
plugins.push(reasoningTimePlugin)
}
// 4. Add the tool-use plugin when prompt-based tool calling is enabled
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
plugins.push(
createPromptToolUsePlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
params.messages = [
{
role: 'assistant',
content: systemPrompt
},
...params.messages
]
return null
}
return systemPrompt
}
})
)
}
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
// plugins.push(createNativeToolUsePlugin())
// }
console.log(
'Final plugin list:',
plugins.map((p) => p.name)
)
return plugins
}
public async completions(
modelId: string,
params: StreamTextParams,
@@ -236,7 +61,7 @@
): Promise<CompletionsResult> {
// try {
// Build the plugin array based on the configuration
const plugins = this.buildPlugins(middlewareConfig)
const plugins = buildPlugins(middlewareConfig)
console.log('this.config.providerId', this.config.providerId)
console.log('this.config.options', this.config.options)
console.log('plugins', plugins)
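The extracted helper keeps the same selection logic as the removed method. A worked sketch of the mapping, assuming a hypothetical config object (the field names come from the branches in buildPlugins; the values are invented):
import { buildPlugins } from './plugins/PluginBuilder'
import type { AiSdkMiddlewareConfig } from './middleware/AiSdkMiddlewareBuilder'

// Hypothetical config: only the web-search and reasoning branches fire
const config = {
  enableWebSearch: true,   // -> webSearchPlugin()
  isSupportedToolUse: false,
  enableReasoning: true,   // -> reasoningTimePlugin
  isPromptToolUse: false,
  mcpTools: []
} as AiSdkMiddlewareConfig

const plugins = buildPlugins(config)
// plugins.map((p) => p.name) would list just those two plugins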

View File

@@ -26,7 +26,6 @@ import {
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
@@ -71,6 +70,7 @@ import {
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'

View File

@@ -18,7 +18,6 @@ import {
} from '@google/genai'
import { loggerService } from '@logger'
import { nanoid } from '@reduxjs/toolkit'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
@@ -61,6 +60,7 @@ import {
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout, MB } from '@shared/config/constant'
import { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'

View File

@@ -1,5 +1,3 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
import {
isOpenAIChatCompletionOnlyModel,
isOpenAILLMModel,
@@ -42,6 +40,8 @@ import { isEmpty } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
import { ResponseInput } from 'openai/resources/responses/responses'
import { GenericChunk } from '../../middleware/schemas'
import { CompletionsContext } from '../../middleware/types'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'

View File

@@ -0,0 +1,187 @@
import { loggerService } from '@logger'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { withSpanResult } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
import { NewAPIClient } from './clients/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
const logger = loggerService.withContext('AiProvider')
export default class AiProvider {
private apiClient: BaseApiClient
constructor(provider: Provider) {
// Use the new ApiClientFactory to get a BaseApiClient instance
this.apiClient = ApiClientFactory.create(provider)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// 1. Identify the correct client based on the model
const model = params.assistant.model
if (!model) {
return Promise.reject(new Error('Model is required'))
}
// Choose the appropriate handling based on the client type
let client: BaseApiClient
if (this.apiClient instanceof AihubmixAPIClient) {
// AihubmixAPIClient: select the appropriate sub-client based on the model
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof NewAPIClient) {
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: select the API type based on model characteristics
client = this.apiClient.getClient(model) as BaseApiClient
} else if (this.apiClient instanceof VertexAPIClient) {
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// Other clients are used directly
client = this.apiClient
}
// 2. Build the middleware chain
const builder = CompletionsMiddlewareBuilder.withDefaults()
// images api
if (isDedicatedImageGenerationModel(model)) {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
// Existing logic for other models
logger.silly('Builder Params', params)
// Use compatibility-type checks to avoid problems with TypeScript type narrowing and the decorator pattern
const clientTypes = client.getClientCompatibilityType(model)
const isOpenAICompatible =
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
if (!isOpenAICompatible) {
logger.silly('ThinkingTagExtractionMiddleware is removed')
builder.remove(ThinkingTagExtractionMiddlewareName)
}
const isAnthropicOrOpenAIResponseCompatible =
clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
if (!isAnthropicOrOpenAIResponseCompatible) {
logger.silly('RawStreamListenerMiddleware is removed')
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
logger.silly('WebSearchMiddleware is removed')
builder.remove(WebSearchMiddlewareName)
}
if (!params.mcpTools?.length) {
builder.remove(ToolUseExtractionMiddlewareName)
logger.silly('ToolUseExtractionMiddleware is removed')
builder.remove(McpToolChunkMiddlewareName)
logger.silly('McpToolChunkMiddleware is removed')
}
if (!isPromptToolUse(params.assistant)) {
builder.remove(ToolUseExtractionMiddlewareName)
logger.silly('ToolUseExtractionMiddleware is removed')
}
if (params.callType !== 'chat') {
logger.silly('AbortHandlerMiddleware is removed')
builder.remove(AbortHandlerMiddlewareName)
}
if (params.callType === 'test') {
builder.remove(ErrorHandlerMiddlewareName)
logger.silly('ErrorHandlerMiddleware is removed')
builder.remove(FinalChunkConsumerMiddlewareName)
logger.silly('FinalChunkConsumerMiddleware is removed')
builder.insertBefore(ThinkChunkMiddlewareName, MiddlewareRegistry[ThinkingTagExtractionMiddlewareName])
logger.silly('ThinkingTagExtractionMiddleware is inserted')
}
}
const middlewares = builder.build()
logger.silly('middlewares', middlewares)
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
// 4. Execute the wrapped method with the original params
const result = wrappedCompletionMethod(params, options)
return result
}
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
const traceName = params.assistant.model?.name
? `${params.assistant.model?.name}.${params.callType}`
: `LLM.${params.callType}`
const traceParams: StartSpanParams = {
name: traceName,
tag: 'LLM',
topicId: params.topicId || '',
modelName: params.assistant.model?.name
}
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
}
public async models(): Promise<SdkModel[]> {
return this.apiClient.listModels()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
}
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {
logger.error('Error getting embedding dimensions:', error as Error)
throw error
}
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
if (this.apiClient instanceof AihubmixAPIClient) {
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
return client.generateImage(params)
}
return this.apiClient.generateImage(params)
}
public getBaseURL(): string {
return this.apiClient.getBaseURL()
}
public getApiKey(): string {
return this.apiClient.getApiKey()
}
}

View File

@@ -1,4 +1,4 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { isAnthropicModel } from '@renderer/config/models'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'

View File

@@ -1,4 +1,3 @@
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import FileManager from '@renderer/services/FileManager'
import { ChunkType } from '@renderer/types/chunk'
@@ -7,6 +6,7 @@ import { defaultTimeout } from '@shared/config/constant'
import OpenAI from 'openai'
import { toFile } from 'openai/uploads'
import { BaseApiClient } from '../../clients/BaseApiClient'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'

View File

@@ -0,0 +1,62 @@
import { AiPlugin } from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { AiSdkMiddlewareConfig } from '../middleware/AiSdkMiddlewareBuilder'
import reasoningTimePlugin from './reasoningTimePlugin'
import { searchOrchestrationPlugin } from './searchOrchestrationPlugin'
/**
 * Build the plugin list from the middleware config
 */
export function buildPlugins(middlewareConfig: AiSdkMiddlewareConfig): AiPlugin[] {
const plugins: AiPlugin[] = []
// 1. Always add the common plugins
// plugins.push(textPlugin)
if (middlewareConfig.enableWebSearch) {
// Default search parameters are built in; pass a config here to override them
plugins.push(webSearchPlugin())
}
// 2. Add the search orchestration plugin when tool calling is supported
if (middlewareConfig.isSupportedToolUse) {
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
}
// 3. Add the reasoning-time plugin for reasoning models
if (middlewareConfig.enableReasoning) {
plugins.push(reasoningTimePlugin)
}
// 4. Add the tool-use plugin when prompt-based tool calling is enabled
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
plugins.push(
createPromptToolUsePlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
params.messages = [
{
role: 'assistant',
content: systemPrompt
},
...params.messages
]
return null
}
return systemPrompt
}
})
)
}
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
// plugins.push(createNativeToolUsePlugin())
// }
console.log(
'Final plugin list:',
plugins.map((p) => p.name)
)
return plugins
}
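The createSystemMessage hook above exists because o1-mini/o1-preview reject system messages: on the first (non-recursive) call it moves the prompt into a leading assistant message and returns null. A standalone sketch of that transformation (the types and function name are local to this example, not part of the codebase):
type Msg = { role: 'system' | 'assistant' | 'user'; content: string }

function o1SystemPromptWorkaround(
  systemPrompt: string,
  messages: Msg[],
  modelId: string
): { system: string | null; messages: Msg[] } {
  if (modelId.includes('o1-mini') || modelId.includes('o1-preview')) {
    // These models reject system messages, so smuggle the prompt in as an assistant turn
    return { system: null, messages: [{ role: 'assistant', content: systemPrompt }, ...messages] }
  }
  return { system: systemPrompt, messages }
}

// o1SystemPromptWorkaround('Use tools via XML.', [{ role: 'user', content: 'hi' }], 'o1-mini')
// -> { system: null, messages: [assistant prompt, then user 'hi'] }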

View File

@@ -0,0 +1,113 @@
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash'
import { createAihubmixProvider } from './aihubmix'
import { getAiSdkProviderId } from './factory'
export function getActualProvider(model: Model): Provider {
const provider = getProviderByModel(model)
// If the type is vertexai and there are no googleCredentials, convert it to a VertexProvider
let actualProvider = cloneDeep(provider)
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
actualProvider = createVertexProvider(provider)
}
if (provider.id === 'aihubmix') {
actualProvider = createAihubmixProvider(model, actualProvider)
}
if (actualProvider.type === 'gemini') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
} else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
}
return actualProvider
}
/**
 * Convert a Provider into an AI SDK provider config
 */
export function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
// console.log('actualProvider', actualProvider)
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// console.log('aiSdkProviderId', aiSdkProviderId)
// If the provider is openai, use strict mode and default to the responses API
const actualProviderType = actualProvider.type
const openaiResponseOptions =
// Providers that are actually openai go through the responses API; aiCore checks internally whether the model supports it
actualProviderType === 'openai-response'
? {
mode: 'responses'
}
: aiSdkProviderId === 'openai'
? {
mode: 'chat'
}
: undefined
console.log('openaiResponseOptions', openaiResponseOptions)
console.log('actualProvider', actualProvider)
console.log('aiSdkProviderId', aiSdkProviderId)
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(
aiSdkProviderId,
{
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
},
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
)
return {
providerId: aiSdkProviderId as ProviderId,
options
}
} else {
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id
}
}
}
}
/**
 * Whether the modern AI SDK can be used for this provider/model
 */
export function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// Currently only the major providers are supported
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// Check the provider type
if (!supportedProviders.includes(provider.type)) {
return false
}
// For vertexai, check that the configuration is complete
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
return false
}
// Image generation models are now supported by the new AI SDK
// (but the provider itself must be one of the supported ones)
if (model && isDedicatedImageGenerationModel(model)) {
return true
}
return true
}
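End to end, these helpers turn a model into a ready-to-use SDK config. A minimal sketch, assuming a hypothetical Gemini model object (the id/provider values are invented; the actual provider-id mapping comes from getAiSdkProviderId):
import type { Model } from '@renderer/types'
import { getActualProvider, providerToAiSdkConfig } from './ProviderConfigProcessor'

const model = { id: 'gemini-2.0-flash', provider: 'gemini' } as Model

// Resolves the provider, handling the vertexai/aihubmix conversions and
// normalizing apiHost (gemini hosts get the v1beta suffix via formatApiHost)
const actualProvider = getActualProvider(model)

// Supported providers map to a native AI SDK id; anything else falls back
// to 'openai-compatible' with baseURL/apiKey options
const { providerId, options } = providerToAiSdkConfig(actualProvider)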

View File

@@ -13,8 +13,6 @@ import {
TextPart,
UserModelMessage
} from '@cherrystudio/ai-core'
import AiProvider from '@renderer/aiCore'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
isGenerateImageModel,
@@ -49,6 +47,8 @@ import {
import { defaultTimeout } from '@shared/config/constant'
import { isEmpty } from 'lodash'
import AiProvider from './legacy/index'
import { CompletionsParams } from './legacy/middleware/schemas'
// import { webSearchTool } from './tools/WebSearchTool'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'

View File

@@ -4,8 +4,8 @@
import { StreamTextParams } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import AiProvider from '@renderer/aiCore'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import {
isDedicatedImageGenerationModel,

View File

@@ -9,13 +9,13 @@ import {
import { FinishReason, MediaModality } from '@google/genai'
import { FunctionCall } from '@google/genai'
import AiProvider from '@renderer/aiCore'
import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients'
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/legacy/clients'
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import { isVisionModel } from '@renderer/config/models'
import { Assistant, MCPCallToolResponse, MCPToolResponse, Model, Provider, WebSearchSource } from '@renderer/types'
import {