feat: separate the API layer, business logic layer, and orchestration layer

feat: implement middleware support for the plugin system
feat: implement a custom thinking-time middleware

- Updated package.json and related files to reflect the correct naming convention for the @cherrystudio/ai-core package.
- Adjusted import paths in various files to ensure consistency with the new package name.
- Enhanced type resolution in tsconfig.web.json to align with the updated package structure.
suyao 2025-06-20 05:44:44 +08:00 committed by MyPrototypeWhat
parent 43d55b7e45
commit 1bccfd3170
21 changed files with 610 additions and 465 deletions
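
For orientation, a minimal sketch of the call chain this separation produces, with simplified signatures pieced together from the diffs below (the `render` callback is a hypothetical UI consumer):

```typescript
// Orchestration layer: coordinates the services and owns top-level error handling
const orchestrator = new OrchestrationService()
await orchestrator.handleUserMessage(
  { messages, assistant, options: { timeout: 30000 } },
  (chunk) => render(chunk) // hypothetical UI callback
)

// Business-logic layer (inside handleUserMessage): shape the context for the LLM
const llmMessages = await ConversationService.prepareMessagesForLlm(messages, assistant)
await fetchChatCompletion({ messages: llmMessages, assistant, options, onChunkReceived })

// API layer (inside fetchChatCompletion): talk to the provider via the AI SDK
const AI = new AiProviderNew(provider, onChunkReceived)
await AI.completions(modelId, aiSdkParams, onChunkReceived)
```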

View File

@ -73,7 +73,7 @@
"@agentic/tavily": "^7.3.3",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@cherry-studio/ai-core": "workspace:*",
"@cherrystudio/ai-core": "workspace:*",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",

View File

@ -317,7 +317,7 @@ export class AiCoreService {
### 5.1 Multi-Provider Support
```typescript
import { createAiSdkClient, AiCore } from '@cherry-studio/ai-core'
import { createAiSdkClient, AiCore } from '@cherrystudio/ai-core'
// Check which providers are supported
const providers = AiCore.getSupportedProviders()
@ -339,7 +339,7 @@ const xai = await createAiSdkClient('xai', { apiKey: 'xai-key' })
// const anthropicClient = new AnthropicApiClient(config)
// Now:
import { createAiSdkClient } from '@cherry-studio/ai-core'
import { createAiSdkClient } from '@cherrystudio/ai-core'
const createProviderClient = async (provider: CherryProvider) => {
return await createAiSdkClient(provider.id, {
@ -359,7 +359,7 @@ import {
PreRequestMiddleware,
StreamProcessingMiddleware,
PostResponseMiddleware
} from '@cherry-studio/ai-core'
} from '@cherrystudio/ai-core'
// Create a complete workflow
const createEnhancedAiService = async () => {

View File

@ -1,4 +1,4 @@
# @cherry-studio/ai-core
# @cherrystudio/ai-core
Cherry Studio AI Core is a unified AI Provider interface package built on the Vercel AI SDK.
@ -42,7 +42,7 @@ Cherry Studio AI Core is a unified AI Provider interface package built on the Vercel AI SDK
## Installation
```bash
npm install @cherry-studio/ai-core ai
npm install @cherrystudio/ai-core ai
```
You also need to install the AI SDK providers you plan to use:
@ -56,7 +56,7 @@ npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
### Basic Usage
```typescript
import { createAiSdkClient } from '@cherry-studio/ai-core'
import { createAiSdkClient } from '@cherrystudio/ai-core'
// Create an OpenAI client
const client = await createAiSdkClient('openai', {
@ -79,7 +79,7 @@ const response = await client.generate({
### Convenience Functions
```typescript
import { createOpenAIClient, streamGeneration } from '@cherry-studio/ai-core'
import { createOpenAIClient, streamGeneration } from '@cherrystudio/ai-core'
// Quickly create an OpenAI client
const client = await createOpenAIClient({
@ -95,7 +95,7 @@ const result = await streamGeneration('openai', 'gpt-4', [{ role: 'user', conten
### Multi-Provider Support
```typescript
import { createAiSdkClient } from '@cherry-studio/ai-core'
import { createAiSdkClient } from '@cherrystudio/ai-core'
// Multiple AI providers are supported
const openaiClient = await createAiSdkClient('openai', { apiKey: 'openai-key' })

View File

@ -1,5 +1,5 @@
{
"name": "@cherry-studio/ai-core",
"name": "@cherrystudio/ai-core",
"version": "1.0.0",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "src/index.ts",

View File

@ -4,7 +4,7 @@
*/
import type { ImageModelV1 } from '@ai-sdk/provider'
import { type LanguageModelV1, wrapLanguageModel } from 'ai'
import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai'
import { aiProviderRegistry } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from './types'
@ -39,16 +39,23 @@ export class ApiClientFactory {
static async createClient<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T]
options: ProviderSettingsMap[T],
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1>
static async createClient(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible']
options: ProviderSettingsMap['openai-compatible'],
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1>
static async createClient(providerId: string, modelId: string = 'default', options: any): Promise<LanguageModelV1> {
static async createClient(
providerId: string,
modelId: string = 'default',
options: any,
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1> {
try {
// For providers not in the registry, fall back to openai-compatible by default
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
@ -78,10 +85,10 @@ export class ApiClientFactory {
let model = provider(modelId)
// Apply AI SDK middlewares
if (providerConfig.aiSdkMiddlewares) {
if (middlewares && middlewares.length > 0) {
model = wrapLanguageModel({
model: model,
middleware: providerConfig.aiSdkMiddlewares
middleware: middlewares
})
}
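
For context, a hedged sketch of how a caller can pass AI SDK middlewares into the extended factory; the timing middleware here is illustrative and not part of this commit:

```typescript
import type { LanguageModelV1Middleware } from 'ai'

// Illustrative middleware: logs how long each non-streaming generation takes
const timingMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate }) => {
    const start = Date.now()
    const result = await doGenerate()
    console.log(`generate took ${Date.now() - start}ms`)
    return result
  }
}

const model = await ApiClientFactory.createClient('openai', 'gpt-4o', { apiKey: 'sk-...' }, [timingMiddleware])
```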

View File

@ -5,7 +5,7 @@
* ## Usage
*
* ```typescript
* import { AiClient } from '@cherry-studio/ai-core'
* import { AiClient } from '@cherrystudio/ai-core'
*
* // Create a client (plugin system enabled by default)
* const client = AiClient.create('openai', {
@ -19,7 +19,14 @@
* })
* ```
*/
import { generateObject, generateText, streamObject, streamText } from 'ai'
import {
generateObject,
generateText,
LanguageModelV1Middleware,
simulateStreamingMiddleware,
streamObject,
streamText
} from 'ai'
import { AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry'
@ -34,6 +41,7 @@ import { UniversalAiSdkClient } from './UniversalAiSdkClient'
export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
private pluginManager: PluginManager
private baseClient: UniversalAiSdkClient<T>
private middlewares: LanguageModelV1Middleware[] = []
constructor(
private readonly providerId: T,
@ -42,6 +50,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
) {
this.pluginManager = new PluginManager(plugins)
this.baseClient = UniversalAiSdkClient.create(providerId, options)
this.updateMiddlewares()
}
/**
@ -49,6 +58,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
*/
use(plugin: AiPlugin): this {
this.pluginManager.use(plugin)
this.updateMiddlewares()
return this
}
@ -57,6 +67,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
*/
usePlugins(plugins: AiPlugin[]): this {
plugins.forEach((plugin) => this.pluginManager.use(plugin))
this.updateMiddlewares()
return this
}
@ -65,9 +76,19 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
*/
removePlugin(pluginName: string): this {
this.pluginManager.remove(pluginName)
this.updateMiddlewares()
return this
}
/**
* Refresh the cached AI SDK middlewares collected from the registered plugins
*/
private updateMiddlewares(): void {
const pluginMiddlewares = this.pluginManager.collectAiSdkMiddlewares()
this.middlewares = pluginMiddlewares
}
/**
*
*/
@ -164,6 +185,21 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
}
}
/**
* Create a model with the collected AI SDK middlewares applied
*/
private async getModelWithMiddlewares(modelId: string) {
const middlewares = this.middlewares
// If middlewares were collected, create a new model instance with them injected
return await ApiClientFactory.createClient(
this.providerId,
modelId,
this.options,
middlewares.length > 0 ? middlewares : [simulateStreamingMiddleware()] // TODO: simulateStreamingMiddleware() should only be used for non-streaming calls; this is a placeholder for now
)
}
/**
* Stream text generation - with plugin support
*/
@ -176,8 +212,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
modelId,
params,
async (finalModelId, transformedParams, streamTransforms) => {
// For streaming calls, call the AI SDK directly to support stream transformers
const model = await ApiClientFactory.createClient(this.providerId, finalModelId, this.options)
const model = await this.getModelWithMiddlewares(finalModelId)
return await streamText({
model,
...transformedParams,
@ -189,13 +224,15 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
/**
* Text generation - with plugin support
*/
async generateText(
modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>> {
return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => {
return await this.baseClient.generateText(finalModelId, transformedParams)
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateText({ model, ...transformedParams })
})
}
@ -207,7 +244,8 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>> {
return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => {
return await this.baseClient.generateObject(finalModelId, transformedParams)
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateObject({ model, ...transformedParams })
})
}

View File

@ -6,7 +6,7 @@
*
* ### 1. Creating clients
* ```typescript
* import { UniversalAiSdkClient } from '@cherry-studio/ai-core'
* import { UniversalAiSdkClient } from '@cherrystudio/ai-core'
*
* // OpenAI
* const openai = UniversalAiSdkClient.create('openai', {

View File

@ -66,6 +66,8 @@ export type {
GenerateTextResult,
InvalidToolArgumentsError,
LanguageModelUsage, // TokenUsage was renamed to LanguageModelUsage in AI SDK 4.0
LanguageModelV1Middleware,
LanguageModelV1StreamPart,
// Error types
NoSuchToolError,
StreamTextResult,
@ -115,7 +117,7 @@ export { getAllProviders, getProvider, isProviderSupported, registerProvider } f
// ==================== Package Info ====================
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherry-studio/ai-core'
export const AI_CORE_NAME = '@cherrystudio/ai-core'
// ==================== Convenience API ====================
// Main convenience factory class

View File

@ -50,7 +50,7 @@ transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamP
### Basic Usage
```typescript
import { PluginManager, createContext, definePlugin } from '@cherry-studio/ai-core/middleware'
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
// Create a plugin manager
const pluginManager = new PluginManager()
@ -81,7 +81,7 @@ import {
LoggingPlugin,
ParamsValidationPlugin,
createContext
} from '@cherry-studio/ai-core/middleware'
} from '@cherrystudio/ai-core/middleware'
// Create a plugin manager
const manager = new PluginManager([

View File

@ -1,4 +1,4 @@
import type { TextStreamPart, ToolSet } from 'ai'
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
import { AiPlugin, AiRequestContext } from './types'
@ -135,6 +135,13 @@ export class PluginManager {
>
}
/**
* Collect AI SDK middlewares from all registered plugins
*/
collectAiSdkMiddlewares(): LanguageModelV1Middleware[] {
return this.plugins.flatMap((plugin) => plugin.aiSdkMiddlewares || [])
}
/**
*
*/
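
A plugin can now carry native AI SDK middlewares alongside its hooks, and the manager flattens them in one call; a minimal sketch, assuming an `AiPlugin` with the optional `aiSdkMiddlewares` field this commit relies on (`onChunk` and `thinkingTimeMiddleware` come from the surrounding diffs):

```typescript
const thinkingPlugin: AiPlugin = {
  name: 'thinking-time',
  aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)]
}

pluginManager.use(thinkingPlugin)
// Flattened list from every registered plugin, ready for wrapLanguageModel
const middlewares = pluginManager.collectAiSdkMiddlewares()
```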

View File

@ -1,5 +1,3 @@
import type { LanguageModelV1Middleware } from 'ai'
/**
* AI Provider configuration and registry types
@ -73,8 +71,6 @@ export interface ProviderConfig {
creatorFunctionName: string
// Whether image generation is supported
supportsImageGeneration?: boolean
// Native AI SDK middlewares
aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
/**

View File

@ -3,7 +3,7 @@
* Adapts the AI SDK fullStream into Cherry Studio chunks
*/
import { TextStreamPart } from '@cherry-studio/ai-core'
import { TextStreamPart } from '@cherrystudio/ai-core'
import { Chunk, ChunkType } from '@renderer/types/chunk'
export interface CherryStudioChunk {
@ -78,38 +78,14 @@ export class AiSdkToChunkAdapter {
type: ChunkType.TEXT_DELTA,
text: chunk.textDelta || ''
})
if (final.reasoning_content) {
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: final.reasoning_content || ''
})
final.reasoning_content = ''
}
break
// === Reasoning events ===
case 'reasoning':
final.reasoning_content += chunk.textDelta || ''
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.textDelta || ''
})
break
case 'reasoning-signature':
// Reasoning signature; mapped to thinking-complete
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: chunk.signature || ''
})
// No longer handled here; the middleware emits THINKING_COMPLETE
break
case 'redacted-reasoning':
// Redacted reasoning content, also mapped to thinking
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.data || ''
})
// No longer handled here
break
// === Tool-call events ===

View File

@ -1,6 +1,6 @@
/**
* Cherry Studio AI Core -
* @cherry-studio/ai-core
* @cherrystudio/ai-core
*
*
* 1. Built on the AI SDK
@ -13,20 +13,20 @@ import {
AiCore,
createClient,
type OpenAICompatibleProviderSettings,
type ProviderId
} from '@cherry-studio/ai-core'
type ProviderId,
StreamTextParams
} from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { RequestOptions } from '@renderer/types/sdk'
import { Chunk } from '@renderer/types/chunk'
// Import the adapter
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
// Import the legacy AiProvider as a fallback
import LegacyAiProvider from './index'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
import thinkingTimeMiddleware from './middleware/aisdk/ThinkingTimeMiddleware'
import { CompletionsResult } from './middleware/schemas'
// Import the parameter transformation module
import { buildStreamTextParams } from './transformParameters'
/**
* Maps a Cherry Studio Provider to an AI SDK provider ID
@ -108,21 +108,30 @@ export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private provider: Provider
constructor(provider: Provider) {
constructor(provider: Provider, onChunk?: (chunk: Chunk) => void) {
this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider)
const config = providerToAiSdkConfig(provider)
this.modernClient = createClient(config.providerId, config.options)
this.modernClient = createClient(
config.providerId,
config.options,
onChunk ? [{ name: 'thinking-time', aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)] }] : undefined
)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
public async completions(
modelId: string,
params: StreamTextParams,
onChunk?: (chunk: Chunk) => void
): Promise<CompletionsResult> {
// const model = params.assistant.model
// Check whether the modern client should be used
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
// try {
return await this.modernCompletions(params, options)
console.log('completions', modelId, params, onChunk)
return await this.modernCompletions(modelId, params, onChunk)
// } catch (error) {
// console.warn('Modern client failed, falling back to legacy:', error)
// Fall back to the legacy implementation
@ -137,66 +146,33 @@ export default class ModernAiProvider {
* The completions implementation built on the AI SDK
* (uses AiSdkUtils)
*/
private async modernCompletions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
if (!this.modernClient || !params.assistant.model) {
throw new Error('Modern client not available')
private async modernCompletions(
modelId: string,
params: StreamTextParams,
onChunk?: (chunk: Chunk) => void
): Promise<CompletionsResult> {
if (!this.modernClient) {
throw new Error('Modern AI SDK client not initialized')
}
console.log('Modern completions with params:', params, 'options:', options)
const model = params.assistant.model
const assistant = params.assistant
// Check the messages type and convert if needed
const messages = Array.isArray(params.messages) ? params.messages : []
if (typeof params.messages === 'string') {
console.warn('Messages is string, using empty array')
}
// Build parameters via the transformParameters module
const aiSdkParams = await buildStreamTextParams(messages, assistant, model, {
maxTokens: params.maxTokens,
mcpTools: params.mcpTools
})
console.log('Built AI SDK params:', aiSdkParams)
const chunks: Chunk[] = []
try {
if (params.streamOutput && params.onChunk) {
if (onChunk) {
// Streaming - use the adapter
const adapter = new AiSdkToChunkAdapter(params.onChunk)
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
const adapter = new AiSdkToChunkAdapter(onChunk)
const streamResult = await this.modernClient.streamText(modelId, params)
const finalText = await adapter.processStream(streamResult)
return {
getText: () => finalText
}
} else if (params.streamOutput) {
} else {
// Streaming without an onChunk callback
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
const streamResult = await this.modernClient.streamText(modelId, params)
const finalText = await streamResult.text
return {
getText: () => finalText
}
} else {
// Non-streaming
const result = await this.modernClient.generateText(model.id, aiSdkParams)
const cherryChunk: Chunk = {
type: ChunkType.TEXT_COMPLETE,
text: result.text || ''
}
chunks.push(cherryChunk)
if (params.onChunk) {
params.onChunk(cherryChunk)
}
return {
getText: () => result.text || ''
}
}
} catch (error) {
console.error('Modern AI SDK error:', error)

View File

@ -0,0 +1,70 @@
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
import { Chunk, ChunkType, ThinkingCompleteChunk } from '@renderer/types/chunk'
/**
* An AI SDK middleware that measures the LLM's "thinking time" (time to first answer token).
*
* How it works:
* 1. Wraps the `stream` returned by `doStream`
* 2. Pipes it through a `TransformStream`
* 3. Accumulates reasoning chunks and forwards them as THINKING_DELTA
* 4. Emits THINKING_COMPLETE with the measured "thinking time" once non-reasoning output starts
*/
export default function thinkingTimeMiddleware(onChunkReceived: (chunk: Chunk) => void): LanguageModelV1Middleware {
return {
wrapStream: async ({ doStream }) => {
let hasThinkingContent = false
let thinkingStartTime = 0
let accumulatedThinkingContent = ''
const { stream, ...rest } = await doStream()
const transformStream = new TransformStream<LanguageModelV1StreamPart, LanguageModelV1StreamPart>({
transform(chunk, controller) {
if (chunk.type === 'reasoning' || chunk.type === 'redacted-reasoning') {
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
// 'reasoning' parts carry textDelta; 'redacted-reasoning' parts carry data
const delta = (chunk.type === 'reasoning' ? chunk.textDelta : chunk.data) || ''
accumulatedThinkingContent += delta
onChunkReceived({
type: ChunkType.THINKING_DELTA,
text: delta
})
} else {
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
onChunkReceived(thinkingCompleteChunk)
hasThinkingContent = false
thinkingStartTime = 0
accumulatedThinkingContent = ''
}
}
// Pass every chunk through unchanged
controller.enqueue(chunk)
},
flush(controller) {
// If the stream ends while still reasoning, emit the complete event as well
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
onChunkReceived(thinkingCompleteChunk)
}
controller.terminate()
}
})
return {
stream: stream.pipeThrough(transformStream),
...rest
}
}
}
}
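
In this commit the middleware is attached through the plugin system when the modern provider is constructed; roughly:

```typescript
// Mirrors the wiring in index_new.ts: the middleware rides in on an inline plugin
const client = createClient(config.providerId, config.options, [
  { name: 'thinking-time', aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)] }
])
```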

View File

@ -3,8 +3,19 @@
* Parameter transformation utilities used by the apiClient
*/
import type { StreamTextParams } from '@cherry-studio/ai-core'
import { isNotSupportTemperatureAndTopP, isSupportedFlexServiceTier } from '@renderer/config/models'
import type { CoreMessage, StreamTextParams } from '@cherrystudio/ai-core'
import {
isGenerateImageModel,
isNotSupportTemperatureAndTopP,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedDisableGenerationModel,
isSupportedFlexServiceTier,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
@ -183,19 +194,38 @@ export async function convertMessagesToSdkMessages(
* Build streamText parameters
*/
export async function buildStreamTextParams(
messages: Message[],
sdkMessages: StreamTextParams['messages'],
assistant: Assistant,
model: Model,
options: {
maxTokens?: number
mcpTools?: MCPTool[]
enableTools?: boolean
requestOptions?: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
} = {}
): Promise<StreamTextParams> {
const { maxTokens, mcpTools, enableTools = false } = options
): Promise<{ params: StreamTextParams; modelId: string }> {
const { mcpTools, enableTools = false } = options
// Convert messages
const sdkMessages = await convertMessagesToSdkMessages(messages, model)
const model = assistant.model || getDefaultModel()
const { maxTokens, reasoning_effort } = getAssistantSettings(assistant)
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
const enableWebSearch =
(assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar') ||
false
const enableGenerateImage =
isGenerateImageModel(model) &&
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
// Build the system prompt
let systemPrompt = assistant.prompt || ''
@ -210,6 +240,20 @@ export async function buildStreamTextParams(
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
system: systemPrompt || undefined,
abortSignal: options.requestOptions?.signal,
headers: options.requestOptions?.headers,
// Placeholder values; to be revised later
providerOptions: {
reasoning: {
enabled: enableReasoning
},
webSearch: {
enabled: enableWebSearch
},
generateImage: {
enabled: enableGenerateImage
}
},
...getCustomParameters(assistant)
}
@ -219,24 +263,22 @@ export async function buildStreamTextParams(
// params.tools = convertMcpToolsToSdkTools(mcpTools)
}
return params
return { params, modelId: model.id }
}
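
A hedged example of the new call shape, which now returns the model id alongside the params (`abortController` and `mcpTools` are assumed to come from the caller):

```typescript
const { params, modelId } = await buildStreamTextParams(sdkMessages, assistant, {
  mcpTools,
  requestOptions: { signal: abortController.signal, timeout: 30000 }
})
const streamResult = await client.streamText(modelId, params)
```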
/**
* Build generateText parameters
*/
export async function buildGenerateTextParams(
messages: Message[],
messages: CoreMessage[],
assistant: Assistant,
model: Model,
options: {
maxTokens?: number
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
): Promise<any> {
// Reuse the stream-param building logic
return await buildStreamTextParams(messages, assistant, model, options)
return await buildStreamTextParams(messages, assistant, options)
}
/**

View File

@ -1,379 +1,311 @@
/**
* API call functions
*/
import { StreamTextParams } from '@cherrystudio/ai-core'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import Logger from '@renderer/config/logger'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import {
isEmbeddingModel,
isGenerateImageModel,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedDisableGenerationModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isWebSearchModel
isSupportedThinkingTokenModel
} from '@renderer/config/models'
import {
SEARCH_SUMMARY_PROMPT,
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
SEARCH_SUMMARY_PROMPT_WEB_ONLY
} from '@renderer/config/prompts'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import {
Assistant,
ExternalToolResult,
KnowledgeReference,
MCPTool,
Model,
Provider,
WebSearchResponse,
WebSearchSource
} from '@renderer/types'
import { Assistant, MCPTool, Model, Provider } from '@renderer/types'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkModel } from '@renderer/types/sdk'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { isAbortError } from '@renderer/utils/error'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findLast, isEmpty, takeRight } from 'lodash'
import { isEmpty, takeRight } from 'lodash'
import AiProvider from '../aiCore'
import AiProviderNew from '../aiCore/index_new'
import {
getAssistantProvider,
getAssistantSettings,
getDefaultModel,
getProviderByModel,
getTopNamingModel,
getTranslateModel
} from './AssistantService'
import { getDefaultAssistant } from './AssistantService'
import { processKnowledgeSearch } from './KnowledgeService'
import {
filterContextMessages,
filterEmptyMessages,
filterUsefulMessages,
filterUserRoleStartMessages
} from './MessagesService'
import WebSearchService from './WebSearchService'
// TODO: consider splitting this up
async function fetchExternalTool(
lastUserMessage: Message,
assistant: Assistant,
onChunkReceived: (chunk: Chunk) => void,
lastAnswer?: Message
): Promise<ExternalToolResult> {
// Possible duplicates?
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
// // TODO: consider splitting this up
// async function fetchExternalTool(
// lastUserMessage: Message,
// assistant: Assistant,
// onChunkReceived: (chunk: Chunk) => void,
// lastAnswer?: Message
// ) {
// // Possible duplicates?
// const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
// const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
// const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
// const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
// Use external search tools
const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
const shouldKnowledgeSearch = hasKnowledgeBase
// // Use external search tools
// const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
// const shouldKnowledgeSearch = hasKnowledgeBase
// Send a progress notification when the tool chain starts
const willUseTools = shouldWebSearch || shouldKnowledgeSearch
if (willUseTools) {
onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
}
// // Send a progress notification when the tool chain starts
// const willUseTools = shouldWebSearch || shouldKnowledgeSearch
// if (willUseTools) {
// onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
// }
// --- Keyword/Question Extraction Function ---
const extract = async (): Promise<ExtractResults | undefined> => {
if (!lastUserMessage) return undefined
// // --- Keyword/Question Extraction Function ---
// const extract = async (): Promise<ExtractResults | undefined> => {
// if (!lastUserMessage) return undefined
// Decide from the config whether extraction is needed
const needWebExtract = shouldWebSearch
const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on'
// // Decide from the config whether extraction is needed
// const needWebExtract = shouldWebSearch
// const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on'
if (!needWebExtract && !needKnowledgeExtract) return undefined
// if (!needWebExtract && !needKnowledgeExtract) return undefined
let prompt: string
if (needWebExtract && !needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
} else if (!needWebExtract && needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
} else {
prompt = SEARCH_SUMMARY_PROMPT
}
// let prompt: string
// if (needWebExtract && !needKnowledgeExtract) {
// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
// } else if (!needWebExtract && needKnowledgeExtract) {
// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
// } else {
// prompt = SEARCH_SUMMARY_PROMPT
// }
const summaryAssistant = getDefaultAssistant()
summaryAssistant.model = assistant.model || getDefaultModel()
summaryAssistant.prompt = prompt
// const summaryAssistant = getDefaultAssistant()
// summaryAssistant.model = assistant.model || getDefaultModel()
// summaryAssistant.prompt = prompt
// try {
// const result = await fetchSearchSummary({
// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
// assistant: summaryAssistant
// })
// if (!result) return getFallbackResult()
// const extracted = extractInfoFromXML(result.getText())
// // Filter the results as needed
// return {
// websearch: needWebExtract ? extracted?.websearch : undefined,
// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
// }
// } catch (e: any) {
// console.error('extract error', e)
// if (isAbortError(e)) throw e
// return getFallbackResult()
// }
// }
// const getFallbackResult = (): ExtractResults => {
// const fallbackContent = getMainTextContent(lastUserMessage)
// return {
// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
// knowledge: shouldKnowledgeSearch
// ? {
// question: [fallbackContent || 'search'],
// rewrite: fallbackContent
// }
// : undefined
// }
// }
// // --- Web Search Function ---
// const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise<WebSearchResponse | undefined> => {
// if (!shouldWebSearch) return
// // Add check for extractResults existence early
// if (!extractResults?.websearch) {
// console.warn('searchTheWeb called without valid extractResults.websearch')
// return
// }
// if (extractResults.websearch.question[0] === 'not_needed') return
// // Add check for assistant.model before using it
// if (!assistant.model) {
// console.warn('searchTheWeb called without assistant.model')
// return undefined
// }
// try {
// // Use the consolidated processWebsearch function
// WebSearchService.createAbortSignal(lastUserMessage.id)
// return {
// results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults),
// source: WebSearchSource.WEBSEARCH
// }
// } catch (error) {
// if (isAbortError(error)) throw error
// console.error('Web search failed:', error)
// return
// }
// }
// // --- Knowledge Base Search Function ---
// const searchKnowledgeBase = async (
// extractResults: ExtractResults | undefined
// ): Promise<KnowledgeReference[] | undefined> => {
// if (!hasKnowledgeBase) return
// // Knowledge-base search criteria
// let searchCriteria: { question: string[]; rewrite: string }
// if (knowledgeRecognition === 'off') {
// const directContent = getMainTextContent(lastUserMessage)
// searchCriteria = { question: [directContent || 'search'], rewrite: directContent }
// } else {
// // auto mode
// if (!extractResults?.knowledge) {
// console.warn('searchKnowledgeBase: No valid search criteria in auto mode')
// return
// }
// searchCriteria = extractResults.knowledge
// }
// if (searchCriteria.question[0] === 'not_needed') return
// try {
// const tempExtractResults: ExtractResults = {
// websearch: undefined,
// knowledge: searchCriteria
// }
// // Attempt to get knowledgeBaseIds from the main text block
// // NOTE: This assumes knowledgeBaseIds are ONLY on the main text block
// // NOTE: processKnowledgeSearch needs to handle undefined ids gracefully
// // const mainTextBlock = mainTextBlocks
// // ?.map((blockId) => store.getState().messageBlocks.entities[blockId])
// // .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined
// return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds)
// } catch (error) {
// console.error('Knowledge base search failed:', error)
// return
// }
// }
// // --- Execute Extraction and Searches ---
// let extractResults: ExtractResults | undefined
// try {
// // Decide from the config whether extraction is needed
// if (shouldWebSearch || hasKnowledgeBase) {
// extractResults = await extract()
// Logger.log('[fetchExternalTool] Extraction results:', extractResults)
// }
// let webSearchResponseFromSearch: WebSearchResponse | undefined
// let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined
// // Run the searches in parallel
// if (shouldWebSearch || shouldKnowledgeSearch) {
// ;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([
// searchTheWeb(extractResults),
// searchKnowledgeBase(extractResults)
// ])
// }
// // Store the search results
// if (lastUserMessage) {
// if (webSearchResponseFromSearch) {
// window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch)
// }
// if (knowledgeReferencesFromSearch) {
// window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch)
// }
// }
// // Notify that tool execution completed
// if (willUseTools) {
// onChunkReceived({
// type: ChunkType.EXTERNEL_TOOL_COMPLETE,
// external_tool: {
// webSearch: webSearchResponseFromSearch,
// knowledge: knowledgeReferencesFromSearch
// }
// })
// }
// } catch (error) {
// if (isAbortError(error)) throw error
// console.error('Tool execution failed:', error)
// // Send the error state
// if (willUseTools) {
// onChunkReceived({
// type: ChunkType.EXTERNEL_TOOL_COMPLETE,
// external_tool: {
// webSearch: undefined,
// knowledge: undefined
// }
// })
// }
// return { mcpTools: [] }
// }
// }
export async function fetchMcpTools(assistant: Assistant) {
// Get MCP tools (Fix duplicate declaration)
let mcpTools: MCPTool[] = [] // Initialize as empty array
const allMcpServers = store.getState().mcp.servers || []
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
const assistantMcpServers = assistant.mcpServers || []
const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
if (enabledMCPs && enabledMCPs.length > 0) {
try {
const result = await fetchSearchSummary({
messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
assistant: summaryAssistant
})
if (!result) return getFallbackResult()
const extracted = extractInfoFromXML(result.getText())
// Filter the results as needed
return {
websearch: needWebExtract ? extracted?.websearch : undefined,
knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
}
} catch (e: any) {
console.error('extract error', e)
if (isAbortError(e)) throw e
return getFallbackResult()
}
}
const getFallbackResult = (): ExtractResults => {
const fallbackContent = getMainTextContent(lastUserMessage)
return {
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
knowledge: shouldKnowledgeSearch
? {
question: [fallbackContent || 'search'],
rewrite: fallbackContent
}
: undefined
}
}
// --- Web Search Function ---
const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise<WebSearchResponse | undefined> => {
if (!shouldWebSearch) return
// Add check for extractResults existence early
if (!extractResults?.websearch) {
console.warn('searchTheWeb called without valid extractResults.websearch')
return
}
if (extractResults.websearch.question[0] === 'not_needed') return
// Add check for assistant.model before using it
if (!assistant.model) {
console.warn('searchTheWeb called without assistant.model')
return undefined
}
try {
// Use the consolidated processWebsearch function
WebSearchService.createAbortSignal(lastUserMessage.id)
return {
results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults),
source: WebSearchSource.WEBSEARCH
}
} catch (error) {
if (isAbortError(error)) throw error
console.error('Web search failed:', error)
return
}
}
// --- Knowledge Base Search Function ---
const searchKnowledgeBase = async (
extractResults: ExtractResults | undefined
): Promise<KnowledgeReference[] | undefined> => {
if (!hasKnowledgeBase) return
// Knowledge-base search criteria
let searchCriteria: { question: string[]; rewrite: string }
if (knowledgeRecognition === 'off') {
const directContent = getMainTextContent(lastUserMessage)
searchCriteria = { question: [directContent || 'search'], rewrite: directContent }
} else {
// auto mode
if (!extractResults?.knowledge) {
console.warn('searchKnowledgeBase: No valid search criteria in auto mode')
return
}
searchCriteria = extractResults.knowledge
}
if (searchCriteria.question[0] === 'not_needed') return
try {
const tempExtractResults: ExtractResults = {
websearch: undefined,
knowledge: searchCriteria
}
// Attempt to get knowledgeBaseIds from the main text block
// NOTE: This assumes knowledgeBaseIds are ONLY on the main text block
// NOTE: processKnowledgeSearch needs to handle undefined ids gracefully
// const mainTextBlock = mainTextBlocks
// ?.map((blockId) => store.getState().messageBlocks.entities[blockId])
// .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined
return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds)
} catch (error) {
console.error('Knowledge base search failed:', error)
return
}
}
// --- Execute Extraction and Searches ---
let extractResults: ExtractResults | undefined
try {
// Decide from the config whether extraction is needed
if (shouldWebSearch || hasKnowledgeBase) {
extractResults = await extract()
Logger.log('[fetchExternalTool] Extraction results:', extractResults)
}
let webSearchResponseFromSearch: WebSearchResponse | undefined
let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined
// Run the searches in parallel
if (shouldWebSearch || shouldKnowledgeSearch) {
;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([
searchTheWeb(extractResults),
searchKnowledgeBase(extractResults)
])
}
// Store the search results
if (lastUserMessage) {
if (webSearchResponseFromSearch) {
window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch)
}
if (knowledgeReferencesFromSearch) {
window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch)
}
}
// Notify that tool execution completed
if (willUseTools) {
onChunkReceived({
type: ChunkType.EXTERNEL_TOOL_COMPLETE,
external_tool: {
webSearch: webSearchResponseFromSearch,
knowledge: knowledgeReferencesFromSearch
const toolPromises = enabledMCPs.map<Promise<MCPTool[]>>(async (mcpServer) => {
try {
const tools = await window.api.mcp.listTools(mcpServer)
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
} catch (error) {
console.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error)
return []
}
})
const results = await Promise.allSettled(toolPromises)
mcpTools = results
.filter((result): result is PromiseFulfilledResult<MCPTool[]> => result.status === 'fulfilled')
.map((result) => result.value)
.flat()
} catch (toolError) {
console.error('Error fetching MCP tools:', toolError)
}
// Get MCP tools (Fix duplicate declaration)
let mcpTools: MCPTool[] = [] // Initialize as empty array
const allMcpServers = store.getState().mcp.servers || []
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
const assistantMcpServers = assistant.mcpServers || []
const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
if (enabledMCPs && enabledMCPs.length > 0) {
try {
const toolPromises = enabledMCPs.map<Promise<MCPTool[]>>(async (mcpServer) => {
try {
const tools = await window.api.mcp.listTools(mcpServer)
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
} catch (error) {
console.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error)
return []
}
})
const results = await Promise.allSettled(toolPromises)
mcpTools = results
.filter((result): result is PromiseFulfilledResult<MCPTool[]> => result.status === 'fulfilled')
.map((result) => result.value)
.flat()
} catch (toolError) {
console.error('Error fetching MCP tools:', toolError)
}
}
return { mcpTools }
} catch (error) {
if (isAbortError(error)) throw error
console.error('Tool execution failed:', error)
// Send the error state
if (willUseTools) {
onChunkReceived({
type: ChunkType.EXTERNEL_TOOL_COMPLETE,
external_tool: {
webSearch: undefined,
knowledge: undefined
}
})
}
return { mcpTools: [] }
}
return mcpTools
}
export async function fetchChatCompletion({
messages,
assistant,
options,
onChunkReceived
}: {
messages: Message[]
messages: StreamTextParams['messages']
assistant: Assistant
onChunkReceived: (chunk: Chunk) => void
// TODO
// onChunkStatus: (status: 'searching' | 'processing' | 'success' | 'error') => void
}) {
console.log('fetchChatCompletion', messages, assistant)
const provider = getAssistantProvider(assistant)
const AI = new AiProviderNew(provider)
// Make sure that 'Clear Context' works for all scenarios including external tool and normal chat.
messages = filterContextMessages(messages)
const lastUserMessage = findLast(messages, (m) => m.role === 'user')
const lastAnswer = findLast(messages, (m) => m.role === 'assistant')
if (!lastUserMessage) {
console.error('fetchChatCompletion returning early: Missing lastUserMessage or lastAnswer')
return
options: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
// try {
// NOTE: The search results are NOT added to the messages sent to the AI here.
// They will be retrieved and used by the messageThunk later to create CitationBlocks.
const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer)
const model = assistant.model || getDefaultModel()
onChunkReceived: (chunk: Chunk) => void
}) {
const provider = getAssistantProvider(assistant)
const AI = new AiProviderNew(provider, onChunkReceived)
const { maxTokens, contextCount } = getAssistantSettings(assistant)
const mcpTools = await fetchMcpTools(assistant)
const filteredMessages = filterUsefulMessages(messages)
const _messages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // use the maximum context across the previous providers
)
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
assistant.settings?.reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
const enableWebSearch =
(assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar') ||
false
const enableGenerateImage =
isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true)
// Build parameters via the transformParameters module
const { params: aiSdkParams, modelId } = await buildStreamTextParams(messages, assistant, {
mcpTools: mcpTools,
requestOptions: options
})
// --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
if (enableWebSearch) {
onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
}
await AI.completions(
{
callType: 'chat',
messages: _messages,
assistant,
onChunk: onChunkReceived,
mcpTools: mcpTools,
maxTokens,
streamOutput: assistant.settings?.streamOutput || false,
enableReasoning,
enableWebSearch,
enableGenerateImage
},
{
streamOutput: assistant.settings?.streamOutput || false
}
)
await AI.completions(modelId, aiSdkParams, onChunkReceived)
}
interface FetchTranslateProps {

View File

@ -0,0 +1,34 @@
import { StreamTextParams } from '@cherrystudio/ai-core'
import { convertMessagesToSdkMessages } from '@renderer/aiCore/transformParameters'
import { Assistant, Message } from '@renderer/types'
import { isEmpty, takeRight } from 'lodash'
import { getAssistantSettings, getDefaultModel } from './AssistantService'
import {
filterContextMessages,
filterEmptyMessages,
filterUsefulMessages,
filterUserRoleStartMessages
} from './MessagesService'
export class ConversationService {
static async prepareMessagesForLlm(messages: Message[], assistant: Assistant): Promise<StreamTextParams['messages']> {
const { contextCount } = getAssistantSettings(assistant)
// This logic is extracted from the original ApiService.fetchChatCompletion
const contextMessages = filterContextMessages(messages)
const filteredMessages = filterUsefulMessages(contextMessages)
// Take the last `contextCount` messages, plus 2 to allow for a final user/assistant exchange.
const finalMessages = filterUserRoleStartMessages(
filterEmptyMessages(takeRight(filteredMessages, contextCount + 2))
)
return await convertMessagesToSdkMessages(finalMessages, assistant.model || getDefaultModel())
}
static needsWebSearch(assistant: Assistant): boolean {
return !!assistant.webSearchProviderId
}
static needsKnowledgeSearch(assistant: Assistant): boolean {
return !isEmpty(assistant.knowledge_bases)
}
}
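
Typical usage, as seen in OrchestrationService below: prepare trimmed, SDK-shaped messages before handing off to the API layer:

```typescript
const llmMessages = await ConversationService.prepareMessagesForLlm(messages, assistant)
const wantsWebSearch = ConversationService.needsWebSearch(assistant)
const wantsKnowledgeSearch = ConversationService.needsKnowledgeSearch(assistant)
```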

View File

@ -0,0 +1,54 @@
import { Assistant, Message } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { fetchChatCompletion } from './ApiService'
import { ConversationService } from './ConversationService'
/**
* The request object for handling a user message.
*/
export interface OrchestrationRequest {
messages: Message[]
assistant: Assistant
options: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
}
/**
* The OrchestrationService is responsible for orchestrating the different services
* to handle a user's message. It contains the core logic of the application.
*/
export class OrchestrationService {
constructor() {
// In the future, this could be a singleton, but for now, a new instance is fine.
// this.conversationService = new ConversationService()
}
/**
* This is the core method to handle user messages.
* It takes the message context and an events object for callbacks,
* and orchestrates the call to the LLM.
* The logic is moved from `messageThunk.ts`.
* @param request The orchestration request containing messages and assistant info.
* @param events A set of callbacks to report progress and results to the UI layer.
*/
async handleUserMessage(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) {
const { messages, assistant } = request
try {
const llmMessages = await ConversationService.prepareMessagesForLlm(messages, assistant)
await fetchChatCompletion({
messages: llmMessages,
assistant: assistant,
options: request.options,
onChunkReceived
})
} catch (error: any) {
onChunkReceived({ type: ChunkType.ERROR, error })
}
}
}

View File

@ -1,9 +1,9 @@
import db from '@renderer/databases'
import { autoRenameTopic } from '@renderer/hooks/useTopic'
import { fetchChatCompletion } from '@renderer/services/ApiService'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import FileManager from '@renderer/services/FileManager'
import { NotificationService } from '@renderer/services/NotificationService'
import { OrchestrationService } from '@renderer/services/OrchestrateService'
import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
import { estimateMessagesUsage } from '@renderer/services/TokenService'
import store from '@renderer/store'
@ -829,15 +829,26 @@ const fetchAndProcessAssistantResponseImpl = async (
const streamProcessorCallbacks = createStreamProcessor(callbacks)
const startTime = Date.now()
await fetchChatCompletion({
messages: messagesForContext,
assistant: assistant,
onChunkReceived: streamProcessorCallbacks
})
const orchestrationService = new OrchestrationService()
await orchestrationService.handleUserMessage(
{
messages: messagesForContext,
assistant,
options: {
timeout: 30000
}
},
streamProcessorCallbacks
)
} catch (error: any) {
console.error('Error fetching chat completion:', error)
if (assistantMessage) {
callbacks.onError?.(error)
console.error('Error in fetchAndProcessAssistantResponseImpl:', error)
// The main error handling is now delegated to OrchestrationService,
// which calls the `onError` callback. This catch block is for
// any errors that might occur outside of that orchestration flow.
if (assistantMessage && callbacks.onError) {
callbacks.onError(error)
} else {
// Fallback if callbacks are not even defined yet
throw error
}
}

View File

@ -16,7 +16,7 @@
"@renderer/*": ["src/renderer/src/*"],
"@shared/*": ["packages/shared/*"],
"@types": ["src/renderer/src/types/index.ts"],
"@cherry-studio/ai-core": ["packages/aiCore/src/"]
"@cherrystudio/ai-core": ["packages/aiCore/src/"]
}
}
}

View File

@ -960,9 +960,9 @@ __metadata:
languageName: node
linkType: hard
"@cherry-studio/ai-core@workspace:*, @cherry-studio/ai-core@workspace:packages/aiCore":
"@cherrystudio/ai-core@workspace:*, @cherrystudio/ai-core@workspace:packages/aiCore":
version: 0.0.0-use.local
resolution: "@cherry-studio/ai-core@workspace:packages/aiCore"
resolution: "@cherrystudio/ai-core@workspace:packages/aiCore"
dependencies:
"@ai-sdk/amazon-bedrock": "npm:^2.2.10"
"@ai-sdk/anthropic": "npm:^1.2.12"
@ -6392,7 +6392,7 @@ __metadata:
"@agentic/tavily": "npm:^7.3.3"
"@ant-design/v5-patch-for-react-19": "npm:^1.0.3"
"@anthropic-ai/sdk": "npm:^0.41.0"
"@cherry-studio/ai-core": "workspace:*"
"@cherrystudio/ai-core": "workspace:*"
"@cherrystudio/embedjs": "npm:^0.1.31"
"@cherrystudio/embedjs-libsql": "npm:^0.1.31"
"@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"