refactor: restructure aiCore for improved modularity and legacy support

- Introduced a new `index_new.ts` entry point for the modern AI provider while keeping the legacy `index.ts` export backward compatible (see the import sketch below).
- Created a `legacy` directory to house the existing clients and middleware, keeping them clearly separated from the new implementations.
- Updated import paths across the affected modules to match the new structure, improving code organization and maintainability.
- Added supporting middleware and utility modules for the new architecture, improving extensibility.
- Moved plugin assembly into a dedicated `PluginBuilder` for cleaner integration and configuration of AI plugins.
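For orientation, a minimal sketch of the import surface this restructuring leaves consumers with, assuming `ModernAiProvider` keeps a constructor compatible with the legacy one (the `createProviders` helper is hypothetical):

```ts
import type { Provider } from '@renderer/types'
// Existing code keeps working: the default export is still the legacy provider.
import AiProvider, { ModernAiProvider } from '@renderer/aiCore'

// Hypothetical helper showing both entry points side by side.
function createProviders(provider: Provider) {
  const legacy = new AiProvider(provider) // backed by legacy/index.ts
  const modern = new ModernAiProvider(provider) // backed by index_new.ts
  return { legacy, modern }
}
```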
MyPrototypeWhat 2025-08-05 19:42:57 +08:00
parent 71959f577d
commit eeafb99059
49 changed files with 355 additions and 339 deletions

View File

@@ -1,143 +1,16 @@
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
/**
* Cherry Studio AI Core
*
* Re-exports the legacy AiProvider to keep existing code compatible.
*/
import { OpenAIAPIClient } from './clients'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { NewAPIClient } from './clients/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
// Export the legacy AiProvider as the default export to keep backward compatibility
export { default } from './legacy/index'
export default class AiProvider {
private apiClient: BaseApiClient
// Also export the modern AiProvider for new code
export { default as ModernAiProvider } from './index_new'
constructor(provider: Provider) {
// Use the new ApiClientFactory to get a BaseApiClient instance
this.apiClient = ApiClientFactory.create(provider)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// 1. Identify the correct client for the model
const model = params.assistant.model
if (!model) {
return Promise.reject(new Error('Model is required'))
}
// Choose the handling path based on the client type
let client: BaseApiClient
if (this.apiClient instanceof AihubmixAPIClient) {
// AihubmixAPIClient: pick the right sub-client for the model
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof NewAPIClient) {
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: choose the API flavor based on the model's characteristics
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// All other clients are used directly
client = this.apiClient
}
// 2. Build the middleware chain
const builder = CompletionsMiddlewareBuilder.withDefaults()
// images api
if (isDedicatedImageGenerationModel(model)) {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
// Existing logic for other models
if (!params.enableReasoning) {
// Leaving this commented out does not break disabling reasoning; the performance cost is negligible
// builder.remove(ThinkingTagExtractionMiddlewareName)
builder.remove(ThinkChunkMiddlewareName)
}
// Note: using `client` in these instanceof checks would narrow its TypeScript type, so check this.apiClient instead
if (!(this.apiClient instanceof OpenAIAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
builder.remove(ThinkingTagExtractionMiddlewareName)
}
if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
builder.remove(WebSearchMiddlewareName)
}
if (!params.mcpTools?.length) {
builder.remove(ToolUseExtractionMiddlewareName)
builder.remove(McpToolChunkMiddlewareName)
}
if (!isPromptToolUse(params.assistant)) {
builder.remove(ToolUseExtractionMiddlewareName)
}
if (params.callType !== 'chat') {
builder.remove(AbortHandlerMiddlewareName)
}
if (params.callType === 'test') {
builder.remove(ErrorHandlerMiddlewareName)
builder.remove(FinalChunkConsumerMiddlewareName)
}
}
const middlewares = builder.build()
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
// 4. Execute the wrapped method with the original params
return wrappedCompletionMethod(params, options)
}
public async models(): Promise<SdkModel[]> {
return this.apiClient.listModels()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {
console.error('Error getting embedding dimensions:', error)
throw error
}
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.apiClient.generateImage(params)
}
public getBaseURL(): string {
return this.apiClient.getBaseURL()
}
public getApiKey(): string {
return this.apiClient.getApiKey()
}
}
// Export commonly used types and utilities
export * from './legacy/clients/types'
export * from './legacy/middleware/schemas'
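For reference, a minimal sketch of driving the (now legacy) provider's completions pipeline. The `assistant`, `messages`, and abort wiring are hypothetical; the parameter names (`callType`, `enableReasoning`, `enableWebSearch`) come from `CompletionsParams` as used above:

```ts
const aiProvider = new AiProvider(provider)
const abortController = new AbortController()

// callType 'chat' keeps the AbortHandler middleware in the chain; the
// enableReasoning / enableWebSearch flags prune their middlewares when false.
const result = await aiProvider.completions(
  {
    assistant, // must carry assistant.model, or completions rejects
    messages, // hypothetical: the message payload shape is not shown in this diff
    callType: 'chat',
    enableReasoning: true,
    enableWebSearch: false
  } as CompletionsParams,
  { signal: abortController.signal } // RequestOptions shape assumed
)
```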

View File

@@ -8,136 +8,17 @@
* 3.
*/
import {
AiCore,
AiPlugin,
createExecutor,
generateImage,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap,
StreamTextParams
} from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { isDedicatedImageGenerationModel, isNotSupportedImageSizeModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
import { searchOrchestrationPlugin } from './plugins/searchOrchestrationPlugin'
import { createAihubmixProvider } from './provider/aihubmix'
import { getAiSdkProviderId } from './provider/factory'
function getActualProvider(model: Model): Provider {
const provider = getProviderByModel(model)
// If the provider type is vertexai and googleCredentials are missing, convert it to a VertexProvider
let actualProvider = cloneDeep(provider)
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
actualProvider = createVertexProvider(provider)
}
if (provider.id === 'aihubmix') {
actualProvider = createAihubmixProvider(model, actualProvider)
}
if (actualProvider.type === 'gemini') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
} else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
}
return actualProvider
}
/**
* Converts a Provider into an AI SDK provider config.
*/
function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
// console.log('actualProvider', actualProvider)
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// console.log('aiSdkProviderId', aiSdkProviderId)
// If the provider is openai, use strict mode and default to the responses API
const actualProviderId = actualProvider.type
const openaiResponseOptions =
// Providers that are actually openai should go through the responses API; aiCore decides internally whether the model supports it
actualProviderId === 'openai-response'
? {
mode: 'responses'
}
: aiSdkProviderId === 'openai'
? {
mode: 'chat'
}
: undefined
console.log('openaiResponseOptions', openaiResponseOptions)
console.log('actualProvider', actualProvider)
console.log('aiSdkProviderId', aiSdkProviderId)
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(
aiSdkProviderId,
{
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
},
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
)
return {
providerId: aiSdkProviderId as ProviderId,
options
}
} else {
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id
}
}
}
}
/**
* Checks whether the modern AI SDK can be used for the given provider/model.
*/
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// Currently only the major providers are supported
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// Check the provider type
if (!supportedProviders.includes(provider.type)) {
return false
}
// For vertexai, also check that the configuration is complete
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
return false
}
// Dedicated image-generation models are now supported by the new AI SDK
// (provided the provider itself is supported)
if (model && isDedicatedImageGenerationModel(model)) {
return true
}
return true
}
import LegacyAiProvider from './legacy/index'
import { CompletionsResult } from './legacy/middleware/schemas'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
@@ -156,62 +37,6 @@ export default class ModernAiProvider {
return this.actualProvider
}
/**
* Builds the plugin list based on the middleware config.
*/
private buildPlugins(middlewareConfig: AiSdkMiddlewareConfig) {
const plugins: AiPlugin[] = []
// 1. Always add the common plugins
// plugins.push(textPlugin)
if (middlewareConfig.enableWebSearch) {
// Default search parameters are built in; pass a config here to override them
plugins.push(webSearchPlugin())
}
// 2. Add the search orchestration plugin when tool use is supported
if (middlewareConfig.isSupportedToolUse) {
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
}
// 3. Add the reasoning-time plugin for reasoning models
if (middlewareConfig.enableReasoning) {
plugins.push(reasoningTimePlugin)
}
// 4. Add the tool-use plugin when prompt-based tool calling is enabled
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
plugins.push(
createPromptToolUsePlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
params.messages = [
{
role: 'assistant',
content: systemPrompt
},
...params.messages
]
return null
}
return systemPrompt
}
})
)
}
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
// plugins.push(createNativeToolUsePlugin())
// }
console.log(
'Final plugin list:',
plugins.map((p) => p.name)
)
return plugins
}
public async completions(
modelId: string,
params: StreamTextParams,
@@ -236,7 +61,7 @@
): Promise<CompletionsResult> {
// try {
// Build the plugin array based on the current conditions
const plugins = this.buildPlugins(middlewareConfig)
const plugins = buildPlugins(middlewareConfig)
console.log('this.config.providerId', this.config.providerId)
console.log('this.config.options', this.config.options)
console.log('plugins', plugins)

View File

@@ -24,7 +24,6 @@ import {
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import Logger from '@renderer/config/logger'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
@@ -71,6 +70,7 @@ import {
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'

View File

@@ -17,7 +17,6 @@ import {
Tool
} from '@google/genai'
import { nanoid } from '@reduxjs/toolkit'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
@@ -61,6 +60,7 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout, MB } from '@shared/config/constant'
import { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'

View File

@@ -1,5 +1,3 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
import {
isOpenAIChatCompletionOnlyModel,
isOpenAILLMModel,
@@ -42,6 +40,8 @@ import { isEmpty } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
import { ResponseInput } from 'openai/resources/responses/responses'
import { GenericChunk } from '../../middleware/schemas'
import { CompletionsContext } from '../../middleware/types'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'

View File

@@ -0,0 +1,143 @@
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
import { OpenAIAPIClient } from './clients'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from './clients/ApiClientFactory'
import { BaseApiClient } from './clients/BaseApiClient'
import { NewAPIClient } from './clients/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
export default class AiProvider {
private apiClient: BaseApiClient
constructor(provider: Provider) {
// Use the new ApiClientFactory to get a BaseApiClient instance
this.apiClient = ApiClientFactory.create(provider)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// 1. Identify the correct client for the model
const model = params.assistant.model
if (!model) {
return Promise.reject(new Error('Model is required'))
}
// Choose the handling path based on the client type
let client: BaseApiClient
if (this.apiClient instanceof AihubmixAPIClient) {
// AihubmixAPIClient: pick the right sub-client for the model
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof NewAPIClient) {
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: choose the API flavor based on the model's characteristics
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// All other clients are used directly
client = this.apiClient
}
// 2. Build the middleware chain
const builder = CompletionsMiddlewareBuilder.withDefaults()
// images api
if (isDedicatedImageGenerationModel(model)) {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
// Existing logic for other models
if (!params.enableReasoning) {
// Leaving this commented out does not break disabling reasoning; the performance cost is negligible
// builder.remove(ThinkingTagExtractionMiddlewareName)
builder.remove(ThinkChunkMiddlewareName)
}
// Note: using `client` in these instanceof checks would narrow its TypeScript type, so check this.apiClient instead
if (!(this.apiClient instanceof OpenAIAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
builder.remove(ThinkingTagExtractionMiddlewareName)
}
if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
builder.remove(WebSearchMiddlewareName)
}
if (!params.mcpTools?.length) {
builder.remove(ToolUseExtractionMiddlewareName)
builder.remove(McpToolChunkMiddlewareName)
}
if (!isPromptToolUse(params.assistant)) {
builder.remove(ToolUseExtractionMiddlewareName)
}
if (params.callType !== 'chat') {
builder.remove(AbortHandlerMiddlewareName)
}
if (params.callType === 'test') {
builder.remove(ErrorHandlerMiddlewareName)
builder.remove(FinalChunkConsumerMiddlewareName)
}
}
const middlewares = builder.build()
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
// 4. Execute the wrapped method with the original params
return wrappedCompletionMethod(params, options)
}
public async models(): Promise<SdkModel[]> {
return this.apiClient.listModels()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {
console.error('Error getting embedding dimensions:', error)
throw error
}
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.apiClient.generateImage(params)
}
public getBaseURL(): string {
return this.apiClient.getBaseURL()
}
public getApiKey(): string {
return this.apiClient.getApiKey()
}
}
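Condensed, the builder logic above reduces to removing middlewares a request does not need. A sketch of the chain for a plain chat call (reasoning off, web search off, no MCP tools), assuming `withDefaults()` seeds the full registry order and with `client`, `params`, and `options` as in `completions` above:

```ts
const builder = CompletionsMiddlewareBuilder.withDefaults()
builder.remove(ThinkChunkMiddlewareName) // reasoning disabled
builder.remove(WebSearchMiddlewareName) // web search disabled
builder.remove(ToolUseExtractionMiddlewareName) // no MCP tools
builder.remove(McpToolChunkMiddlewareName)

// The surviving middlewares wrap the raw SDK call into a pipeline.
const wrapped = applyCompletionsMiddlewares(client, client.createCompletions, builder.build())
const result = await wrapped(params, options)
```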

View File

@@ -1,4 +1,4 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { AnthropicAPIClient } from '../../clients/anthropic/AnthropicAPIClient'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
import { AnthropicStreamListener } from '../../clients/types'

View File

@@ -1,4 +1,4 @@
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { BaseApiClient } from '../../clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import FileManager from '@renderer/services/FileManager'
import { ChunkType } from '@renderer/types/chunk'

View File

@@ -0,0 +1,62 @@
import { AiPlugin } from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { AiSdkMiddlewareConfig } from '../middleware/AiSdkMiddlewareBuilder'
import reasoningTimePlugin from './reasoningTimePlugin'
import { searchOrchestrationPlugin } from './searchOrchestrationPlugin'
/**
* Builds the plugin list based on the middleware config.
*/
export function buildPlugins(middlewareConfig: AiSdkMiddlewareConfig): AiPlugin[] {
const plugins: AiPlugin[] = []
// 1. Always add the common plugins
// plugins.push(textPlugin)
if (middlewareConfig.enableWebSearch) {
// Default search parameters are built in; pass a config here to override them
plugins.push(webSearchPlugin())
}
// 2. Add the search orchestration plugin when tool use is supported
if (middlewareConfig.isSupportedToolUse) {
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
}
// 3. Add the reasoning-time plugin for reasoning models
if (middlewareConfig.enableReasoning) {
plugins.push(reasoningTimePlugin)
}
// 4. Add the tool-use plugin when prompt-based tool calling is enabled
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
plugins.push(
createPromptToolUsePlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
params.messages = [
{
role: 'assistant',
content: systemPrompt
},
...params.messages
]
return null
}
return systemPrompt
}
})
)
}
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
// plugins.push(createNativeToolUsePlugin())
// }
console.log(
'Final plugin list:',
plugins.map((p) => p.name)
)
return plugins
}
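`buildPlugins` is now the single place where the plugin array is assembled; `ModernAiProvider.completions` simply calls it with its middleware config, as the earlier hunk shows. A minimal sketch, with the config values and the `assistant` object assumed:

```ts
const middlewareConfig = {
  enableWebSearch: false,
  isSupportedToolUse: true, // adds searchOrchestrationPlugin(assistant)
  enableReasoning: true, // adds reasoningTimePlugin
  isPromptToolUse: false,
  mcpTools: [],
  assistant // hypothetical assistant object
} as AiSdkMiddlewareConfig

// Yields [searchOrchestrationPlugin, reasoningTimePlugin] for this config.
const plugins = buildPlugins(middlewareConfig)
```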

View File

@@ -0,0 +1,113 @@
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash'
import { createAihubmixProvider } from './aihubmix'
import { getAiSdkProviderId } from './factory'
export function getActualProvider(model: Model): Provider {
const provider = getProviderByModel(model)
// If the provider type is vertexai and googleCredentials are missing, convert it to a VertexProvider
let actualProvider = cloneDeep(provider)
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
actualProvider = createVertexProvider(provider)
}
if (provider.id === 'aihubmix') {
actualProvider = createAihubmixProvider(model, actualProvider)
}
if (actualProvider.type === 'gemini') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
} else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
}
return actualProvider
}
/**
* Converts a Provider into an AI SDK provider config.
*/
export function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
// console.log('actualProvider', actualProvider)
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// console.log('aiSdkProviderId', aiSdkProviderId)
// If the provider is openai, use strict mode and default to the responses API
const actualProviderId = actualProvider.type
const openaiResponseOptions =
// Providers that are actually openai should go through the responses API; aiCore decides internally whether the model supports it
actualProviderId === 'openai-response'
? {
mode: 'responses'
}
: aiSdkProviderId === 'openai'
? {
mode: 'chat'
}
: undefined
console.log('openaiResponseOptions', openaiResponseOptions)
console.log('actualProvider', actualProvider)
console.log('aiSdkProviderId', aiSdkProviderId)
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(
aiSdkProviderId,
{
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
},
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
)
return {
providerId: aiSdkProviderId as ProviderId,
options
}
} else {
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id
}
}
}
}
/**
* Checks whether the modern AI SDK can be used for the given provider/model.
*/
export function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// Currently only the major providers are supported
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// Check the provider type
if (!supportedProviders.includes(provider.type)) {
return false
}
// For vertexai, also check that the configuration is complete
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
return false
}
// Dedicated image-generation models are now supported by the new AI SDK
// (provided the provider itself is supported)
if (model && isDedicatedImageGenerationModel(model)) {
return true
}
return true
}
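Taken together, the processor gives `ModernAiProvider` a three-step setup: resolve the actual provider, check modern-SDK support, then translate to an AI SDK config. A sketch of the expected flow; the `createExecutor` call is assumed from the imports in the earlier hunk, not confirmed by this diff:

```ts
const actualProvider = getActualProvider(model) // vertexai/aihubmix handling + apiHost normalization
if (!isModernSdkSupported(actualProvider, model)) {
  // fall back to the legacy pipeline (LegacyAiProvider)
}

const { providerId, options } = providerToAiSdkConfig(actualProvider)
// Signature assumed: the executor binds provider config and plugins together.
const executor = createExecutor(providerId, options, plugins)
```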

View File

@@ -13,8 +13,6 @@ import {
TextPart,
UserModelMessage
} from '@cherrystudio/ai-core'
import AiProvider from '@renderer/aiCore'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
isGenerateImageModel,
@@ -50,6 +48,8 @@ import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
import { isEmpty } from 'lodash'
import AiProvider from './legacy/index'
import { CompletionsParams } from './legacy/middleware/schemas'
// import { webSearchTool } from './tools/WebSearchTool'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'

View File

@@ -1,5 +1,5 @@
import { CheckOutlined, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons'
import { isOpenAIProvider } from '@renderer/aiCore/clients/ApiClientFactory'
import { isOpenAIProvider } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import OpenAIAlert from '@renderer/components/Alert/OpenAIAlert'
import { StreamlineGoodHealthAndWellBeing } from '@renderer/components/Icons/SVGIcon'
import { HStack } from '@renderer/components/Layout'

View File

@@ -3,8 +3,8 @@
*/
import { StreamTextParams } from '@cherrystudio/ai-core'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import {
isDedicatedImageGenerationModel,

View File

@@ -1,10 +1,10 @@
import { FinishReason, MediaModality } from '@google/genai'
import { FunctionCall } from '@google/genai'
import AiProvider from '@renderer/aiCore'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import { Assistant, Provider, WebSearchSource } from '@renderer/types'
import {
ChunkType,

View File

@@ -26,7 +26,7 @@ import {
ChatCompletionTool
} from 'openai/resources'
import { CompletionsParams } from '../aiCore/middleware/schemas'
import { CompletionsParams } from '../aiCore/legacy/middleware/schemas'
import { confirmSameNameTools, requestToolConfirmation, setToolIdToNameMapping } from './userConfirmation'
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'