mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package
This commit is contained in:
commit
abfec7a228
@ -1,187 +1,16 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
|
||||
/**
|
||||
* Cherry Studio AI Core - 统一入口点
|
||||
*
|
||||
* 这是新的统一入口,保持向后兼容性
|
||||
* 默认导出legacy AiProvider以保持现有代码的兼容性
|
||||
*/
|
||||
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './clients/NewAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
// 导出Legacy AiProvider作为默认导出(保持向后兼容)
|
||||
export { default } from './legacy/index'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
// 同时导出Modern AiProvider供新代码使用
|
||||
export { default as ModernAiProvider } from './index_new'
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof NewAPIClient) {
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
logger.silly('Builder Params', params)
|
||||
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
||||
const clientTypes = client.getClientCompatibilityType(model)
|
||||
const isOpenAICompatible =
|
||||
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isOpenAICompatible) {
|
||||
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
|
||||
const isAnthropicOrOpenAIResponseCompatible =
|
||||
clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isAnthropicOrOpenAIResponseCompatible) {
|
||||
logger.silly('RawStreamListenerMiddleware is removed')
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
logger.silly('WebSearchMiddleware is removed')
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
logger.silly('McpToolChunkMiddleware is removed')
|
||||
}
|
||||
if (!isPromptToolUse(params.assistant)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
}
|
||||
if (params.callType !== 'chat') {
|
||||
logger.silly('AbortHandlerMiddleware is removed')
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
if (params.callType === 'test') {
|
||||
builder.remove(ErrorHandlerMiddlewareName)
|
||||
logger.silly('ErrorHandlerMiddleware is removed')
|
||||
builder.remove(FinalChunkConsumerMiddlewareName)
|
||||
logger.silly('FinalChunkConsumerMiddleware is removed')
|
||||
builder.insertBefore(ThinkChunkMiddlewareName, MiddlewareRegistry[ThinkingTagExtractionMiddlewareName])
|
||||
logger.silly('ThinkingTagExtractionMiddleware is inserted')
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
logger.silly('middlewares', middlewares)
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
const result = wrappedCompletionMethod(params, options)
|
||||
return result
|
||||
}
|
||||
|
||||
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
const traceName = params.assistant.model?.name
|
||||
? `${params.assistant.model?.name}.${params.callType}`
|
||||
: `LLM.${params.callType}`
|
||||
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: params.topicId || '',
|
||||
modelName: params.assistant.model?.name
|
||||
}
|
||||
|
||||
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
||||
}
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
logger.error('Error getting embedding dimensions:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
||||
return client.generateImage(params)
|
||||
}
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
// 导出一些常用的类型和工具
|
||||
export * from './legacy/clients/types'
|
||||
export * from './legacy/middleware/schemas'
|
||||
|
||||
@ -8,136 +8,17 @@
|
||||
* 3. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import {
|
||||
AiCore,
|
||||
AiPlugin,
|
||||
createExecutor,
|
||||
generateImage,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap,
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { isDedicatedImageGenerationModel, isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './index'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
|
||||
import { searchOrchestrationPlugin } from './plugins/searchOrchestrationPlugin'
|
||||
import { createAihubmixProvider } from './provider/aihubmix'
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
|
||||
function getActualProvider(model: Model): Provider {
|
||||
const provider = getProviderByModel(model)
|
||||
// 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider
|
||||
let actualProvider = cloneDeep(provider)
|
||||
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
actualProvider = createVertexProvider(provider)
|
||||
}
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
actualProvider = createAihubmixProvider(model, actualProvider)
|
||||
}
|
||||
if (actualProvider.type === 'gemini') {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
|
||||
} else {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
|
||||
}
|
||||
return actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
*/
|
||||
function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
// console.log('actualProvider', actualProvider)
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
// console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||
const actualProviderType = actualProvider.type
|
||||
const openaiResponseOptions =
|
||||
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
|
||||
actualProviderType === 'openai-response'
|
||||
? {
|
||||
mode: 'responses'
|
||||
}
|
||||
: aiSdkProviderId === 'openai'
|
||||
? {
|
||||
mode: 'chat'
|
||||
}
|
||||
: undefined
|
||||
console.log('openaiResponseOptions', openaiResponseOptions)
|
||||
console.log('actualProvider', actualProvider)
|
||||
console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(
|
||||
aiSdkProviderId,
|
||||
{
|
||||
baseURL: actualProvider.apiHost,
|
||||
apiKey: actualProvider.apiKey
|
||||
},
|
||||
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
|
||||
)
|
||||
|
||||
return {
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
options
|
||||
}
|
||||
} else {
|
||||
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
|
||||
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
options: {
|
||||
...options,
|
||||
name: actualProvider.id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否支持使用新的AI SDK
|
||||
*/
|
||||
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
// 目前支持主要的providers
|
||||
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
|
||||
|
||||
// 检查provider类型
|
||||
if (!supportedProviders.includes(provider.type)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于 vertexai,检查配置是否完整
|
||||
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 图像生成模型现在支持新的 AI SDK
|
||||
// (但需要确保 provider 是支持的
|
||||
|
||||
if (model && isDedicatedImageGenerationModel(model)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import { CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
@ -156,62 +37,6 @@ export default class ModernAiProvider {
|
||||
return this.actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据条件构建插件数组
|
||||
*/
|
||||
private buildPlugins(middlewareConfig: AiSdkMiddlewareConfig) {
|
||||
const plugins: AiPlugin[] = []
|
||||
// 1. 总是添加通用插件
|
||||
// plugins.push(textPlugin)
|
||||
if (middlewareConfig.enableWebSearch) {
|
||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||
plugins.push(webSearchPlugin())
|
||||
}
|
||||
// 2. 支持工具调用时添加搜索插件
|
||||
if (middlewareConfig.isSupportedToolUse) {
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
||||
}
|
||||
|
||||
// 3. 推理模型时添加推理插件
|
||||
if (middlewareConfig.enableReasoning) {
|
||||
plugins.push(reasoningTimePlugin)
|
||||
}
|
||||
|
||||
// 4. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
plugins.push(
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
params.messages = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: systemPrompt
|
||||
},
|
||||
...params.messages
|
||||
]
|
||||
return null
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
// plugins.push(createNativeToolUsePlugin())
|
||||
// }
|
||||
console.log(
|
||||
'最终插件列表:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
return plugins
|
||||
}
|
||||
|
||||
public async completions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
@ -236,7 +61,7 @@ export default class ModernAiProvider {
|
||||
): Promise<CompletionsResult> {
|
||||
// try {
|
||||
// 根据条件构建插件数组
|
||||
const plugins = this.buildPlugins(middlewareConfig)
|
||||
const plugins = buildPlugins(middlewareConfig)
|
||||
console.log('this.config.providerId', this.config.providerId)
|
||||
console.log('this.config.options', this.config.options)
|
||||
console.log('plugins', plugins)
|
||||
|
||||
@ -26,7 +26,6 @@ import {
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
@ -71,6 +70,7 @@ import {
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
@ -18,7 +18,6 @@ import {
|
||||
} from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
@ -61,6 +60,7 @@ import {
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout, MB } from '@shared/config/constant'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
|
||||
import {
|
||||
isOpenAIChatCompletionOnlyModel,
|
||||
isOpenAILLMModel,
|
||||
@ -42,6 +40,8 @@ import { isEmpty } from 'lodash'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { ResponseInput } from 'openai/resources/responses/responses'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { CompletionsContext } from '../../middleware/types'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
import { OpenAIAPIClient } from './OpenAIApiClient'
|
||||
import { OpenAIBaseClient } from './OpenAIBaseClient'
|
||||
187
src/renderer/src/aiCore/legacy/index.ts
Normal file
187
src/renderer/src/aiCore/legacy/index.ts
Normal file
@ -0,0 +1,187 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './clients/NewAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof NewAPIClient) {
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
logger.silly('Builder Params', params)
|
||||
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
||||
const clientTypes = client.getClientCompatibilityType(model)
|
||||
const isOpenAICompatible =
|
||||
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isOpenAICompatible) {
|
||||
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
|
||||
const isAnthropicOrOpenAIResponseCompatible =
|
||||
clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isAnthropicOrOpenAIResponseCompatible) {
|
||||
logger.silly('RawStreamListenerMiddleware is removed')
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
logger.silly('WebSearchMiddleware is removed')
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
logger.silly('McpToolChunkMiddleware is removed')
|
||||
}
|
||||
if (!isPromptToolUse(params.assistant)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
}
|
||||
if (params.callType !== 'chat') {
|
||||
logger.silly('AbortHandlerMiddleware is removed')
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
if (params.callType === 'test') {
|
||||
builder.remove(ErrorHandlerMiddlewareName)
|
||||
logger.silly('ErrorHandlerMiddleware is removed')
|
||||
builder.remove(FinalChunkConsumerMiddlewareName)
|
||||
logger.silly('FinalChunkConsumerMiddleware is removed')
|
||||
builder.insertBefore(ThinkChunkMiddlewareName, MiddlewareRegistry[ThinkingTagExtractionMiddlewareName])
|
||||
logger.silly('ThinkingTagExtractionMiddleware is inserted')
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
logger.silly('middlewares', middlewares)
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
const result = wrappedCompletionMethod(params, options)
|
||||
return result
|
||||
}
|
||||
|
||||
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
const traceName = params.assistant.model?.name
|
||||
? `${params.assistant.model?.name}.${params.callType}`
|
||||
: `LLM.${params.callType}`
|
||||
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: params.topicId || '',
|
||||
modelName: params.assistant.model?.name
|
||||
}
|
||||
|
||||
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
||||
}
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
logger.error('Error getting embedding dimensions:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
||||
return client.generateImage(params)
|
||||
}
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import { isAnthropicModel } from '@renderer/config/models'
|
||||
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
@ -7,6 +6,7 @@ import { defaultTimeout } from '@shared/config/constant'
|
||||
import OpenAI from 'openai'
|
||||
import { toFile } from 'openai/uploads'
|
||||
|
||||
import { BaseApiClient } from '../../clients/BaseApiClient'
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
62
src/renderer/src/aiCore/plugins/PluginBuilder.ts
Normal file
62
src/renderer/src/aiCore/plugins/PluginBuilder.ts
Normal file
@ -0,0 +1,62 @@
|
||||
import { AiPlugin } from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
import { AiSdkMiddlewareConfig } from '../middleware/AiSdkMiddlewareBuilder'
|
||||
import reasoningTimePlugin from './reasoningTimePlugin'
|
||||
import { searchOrchestrationPlugin } from './searchOrchestrationPlugin'
|
||||
|
||||
/**
|
||||
* 根据条件构建插件数组
|
||||
*/
|
||||
export function buildPlugins(middlewareConfig: AiSdkMiddlewareConfig): AiPlugin[] {
|
||||
const plugins: AiPlugin[] = []
|
||||
// 1. 总是添加通用插件
|
||||
// plugins.push(textPlugin)
|
||||
if (middlewareConfig.enableWebSearch) {
|
||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||
plugins.push(webSearchPlugin())
|
||||
}
|
||||
// 2. 支持工具调用时添加搜索插件
|
||||
if (middlewareConfig.isSupportedToolUse) {
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
||||
}
|
||||
|
||||
// 3. 推理模型时添加推理插件
|
||||
if (middlewareConfig.enableReasoning) {
|
||||
plugins.push(reasoningTimePlugin)
|
||||
}
|
||||
|
||||
// 4. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
plugins.push(
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
params.messages = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: systemPrompt
|
||||
},
|
||||
...params.messages
|
||||
]
|
||||
return null
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
// plugins.push(createNativeToolUsePlugin())
|
||||
// }
|
||||
console.log(
|
||||
'最终插件列表:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
return plugins
|
||||
}
|
||||
113
src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts
Normal file
113
src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts
Normal file
@ -0,0 +1,113 @@
|
||||
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
import { createAihubmixProvider } from './aihubmix'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
|
||||
export function getActualProvider(model: Model): Provider {
|
||||
const provider = getProviderByModel(model)
|
||||
// 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider
|
||||
let actualProvider = cloneDeep(provider)
|
||||
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
actualProvider = createVertexProvider(provider)
|
||||
}
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
actualProvider = createAihubmixProvider(model, actualProvider)
|
||||
}
|
||||
if (actualProvider.type === 'gemini') {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
|
||||
} else {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
|
||||
}
|
||||
return actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
*/
|
||||
export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
// console.log('actualProvider', actualProvider)
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
// console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||
const actualProviderType = actualProvider.type
|
||||
const openaiResponseOptions =
|
||||
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
|
||||
actualProviderType === 'openai-response'
|
||||
? {
|
||||
mode: 'responses'
|
||||
}
|
||||
: aiSdkProviderId === 'openai'
|
||||
? {
|
||||
mode: 'chat'
|
||||
}
|
||||
: undefined
|
||||
console.log('openaiResponseOptions', openaiResponseOptions)
|
||||
console.log('actualProvider', actualProvider)
|
||||
console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(
|
||||
aiSdkProviderId,
|
||||
{
|
||||
baseURL: actualProvider.apiHost,
|
||||
apiKey: actualProvider.apiKey
|
||||
},
|
||||
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
|
||||
)
|
||||
|
||||
return {
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
options
|
||||
}
|
||||
} else {
|
||||
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
|
||||
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
options: {
|
||||
...options,
|
||||
name: actualProvider.id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否支持使用新的AI SDK
|
||||
*/
|
||||
export function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
// 目前支持主要的providers
|
||||
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
|
||||
|
||||
// 检查provider类型
|
||||
if (!supportedProviders.includes(provider.type)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于 vertexai,检查配置是否完整
|
||||
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 图像生成模型现在支持新的 AI SDK
|
||||
// (但需要确保 provider 是支持的
|
||||
|
||||
if (model && isDedicatedImageGenerationModel(model)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@ -13,8 +13,6 @@ import {
|
||||
TextPart,
|
||||
UserModelMessage
|
||||
} from '@cherrystudio/ai-core'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
@ -49,6 +47,8 @@ import {
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import AiProvider from './legacy/index'
|
||||
import { CompletionsParams } from './legacy/middleware/schemas'
|
||||
// import { webSearchTool } from './tools/WebSearchTool'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
|
||||
@ -4,8 +4,8 @@
|
||||
import { StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
||||
import {
|
||||
isDedicatedImageGenerationModel,
|
||||
|
||||
@ -9,13 +9,13 @@ import {
|
||||
import { FinishReason, MediaModality } from '@google/genai'
|
||||
import { FunctionCall } from '@google/genai'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/legacy/clients'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
|
||||
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { isVisionModel } from '@renderer/config/models'
|
||||
import { Assistant, MCPCallToolResponse, MCPToolResponse, Model, Provider, WebSearchSource } from '@renderer/types'
|
||||
import {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user