mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: 完成api层,业务逻辑层,编排层的分离
feat: 为插件系统实现中间件 feat: 实现自定义的思考中间件 - Updated package.json and related files to reflect the correct naming convention for the @cherrystudio/ai-core package. - Adjusted import paths in various files to ensure consistency with the new package name. - Enhanced type resolution in tsconfig.web.json to align with the updated package structure.
This commit is contained in:
parent
43d55b7e45
commit
1bccfd3170
@ -73,7 +73,7 @@
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@cherry-studio/ai-core": "workspace:*",
|
||||
"@cherrystudio/ai-core": "workspace:*",
|
||||
"@cherrystudio/embedjs": "^0.1.31",
|
||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||
|
||||
@ -317,7 +317,7 @@ export class AiCoreService {
|
||||
### 5.1 多 Provider 支持
|
||||
|
||||
```typescript
|
||||
import { createAiSdkClient, AiCore } from '@cherry-studio/ai-core'
|
||||
import { createAiSdkClient, AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 检查支持的 providers
|
||||
const providers = AiCore.getSupportedProviders()
|
||||
@ -339,7 +339,7 @@ const xai = await createAiSdkClient('xai', { apiKey: 'xai-key' })
|
||||
// const anthropicClient = new AnthropicApiClient(config)
|
||||
|
||||
// 现在:
|
||||
import { createAiSdkClient } from '@cherry-studio/ai-core'
|
||||
import { createAiSdkClient } from '@cherrystudio/ai-core'
|
||||
|
||||
const createProviderClient = async (provider: CherryProvider) => {
|
||||
return await createAiSdkClient(provider.id, {
|
||||
@ -359,7 +359,7 @@ import {
|
||||
PreRequestMiddleware,
|
||||
StreamProcessingMiddleware,
|
||||
PostResponseMiddleware
|
||||
} from '@cherry-studio/ai-core'
|
||||
} from '@cherrystudio/ai-core'
|
||||
|
||||
// 创建完整的工作流
|
||||
const createEnhancedAiService = async () => {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# @cherry-studio/ai-core
|
||||
# @cherrystudio/ai-core
|
||||
|
||||
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包。
|
||||
|
||||
@ -42,7 +42,7 @@ Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
npm install @cherry-studio/ai-core ai
|
||||
npm install @cherrystudio/ai-core ai
|
||||
```
|
||||
|
||||
还需要安装你要使用的 AI SDK provider:
|
||||
@ -56,7 +56,7 @@ npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { createAiSdkClient } from '@cherry-studio/ai-core'
|
||||
import { createAiSdkClient } from '@cherrystudio/ai-core'
|
||||
|
||||
// 创建 OpenAI 客户端
|
||||
const client = await createAiSdkClient('openai', {
|
||||
@ -79,7 +79,7 @@ const response = await client.generate({
|
||||
### 便捷函数
|
||||
|
||||
```typescript
|
||||
import { createOpenAIClient, streamGeneration } from '@cherry-studio/ai-core'
|
||||
import { createOpenAIClient, streamGeneration } from '@cherrystudio/ai-core'
|
||||
|
||||
// 快速创建 OpenAI 客户端
|
||||
const client = await createOpenAIClient({
|
||||
@ -95,7 +95,7 @@ const result = await streamGeneration('openai', 'gpt-4', [{ role: 'user', conten
|
||||
### 多 Provider 支持
|
||||
|
||||
```typescript
|
||||
import { createAiSdkClient } from '@cherry-studio/ai-core'
|
||||
import { createAiSdkClient } from '@cherrystudio/ai-core'
|
||||
|
||||
// 支持多种 AI providers
|
||||
const openaiClient = await createAiSdkClient('openai', { apiKey: 'openai-key' })
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
{
|
||||
"name": "@cherry-studio/ai-core",
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "src/index.ts",
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
*/
|
||||
|
||||
import type { ImageModelV1 } from '@ai-sdk/provider'
|
||||
import { type LanguageModelV1, wrapLanguageModel } from 'ai'
|
||||
import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai'
|
||||
|
||||
import { aiProviderRegistry } from '../providers/registry'
|
||||
import { type ProviderId, type ProviderSettingsMap } from './types'
|
||||
@ -39,16 +39,23 @@ export class ApiClientFactory {
|
||||
static async createClient<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
options: ProviderSettingsMap[T]
|
||||
options: ProviderSettingsMap[T],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
): Promise<LanguageModelV1>
|
||||
|
||||
static async createClient(
|
||||
providerId: string,
|
||||
modelId: string,
|
||||
options: ProviderSettingsMap['openai-compatible']
|
||||
options: ProviderSettingsMap['openai-compatible'],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
): Promise<LanguageModelV1>
|
||||
|
||||
static async createClient(providerId: string, modelId: string = 'default', options: any): Promise<LanguageModelV1> {
|
||||
static async createClient(
|
||||
providerId: string,
|
||||
modelId: string = 'default',
|
||||
options: any,
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
): Promise<LanguageModelV1> {
|
||||
try {
|
||||
// 对于不在注册表中的 provider,默认使用 openai-compatible
|
||||
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
|
||||
@ -78,10 +85,10 @@ export class ApiClientFactory {
|
||||
let model = provider(modelId)
|
||||
|
||||
// 应用 AI SDK 中间件
|
||||
if (providerConfig.aiSdkMiddlewares) {
|
||||
if (middlewares && middlewares.length > 0) {
|
||||
model = wrapLanguageModel({
|
||||
model: model,
|
||||
middleware: providerConfig.aiSdkMiddlewares
|
||||
middleware: middlewares
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
* ## 使用方式
|
||||
*
|
||||
* ```typescript
|
||||
* import { AiClient } from '@cherry-studio/ai-core'
|
||||
* import { AiClient } from '@cherrystudio/ai-core'
|
||||
*
|
||||
* // 创建客户端(默认带插件系统)
|
||||
* const client = AiClient.create('openai', {
|
||||
@ -19,7 +19,14 @@
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
import {
|
||||
generateObject,
|
||||
generateText,
|
||||
LanguageModelV1Middleware,
|
||||
simulateStreamingMiddleware,
|
||||
streamObject,
|
||||
streamText
|
||||
} from 'ai'
|
||||
|
||||
import { AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { isProviderSupported } from '../providers/registry'
|
||||
@ -34,6 +41,7 @@ import { UniversalAiSdkClient } from './UniversalAiSdkClient'
|
||||
export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
private pluginManager: PluginManager
|
||||
private baseClient: UniversalAiSdkClient<T>
|
||||
private middlewares: LanguageModelV1Middleware[] = []
|
||||
|
||||
constructor(
|
||||
private readonly providerId: T,
|
||||
@ -42,6 +50,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
) {
|
||||
this.pluginManager = new PluginManager(plugins)
|
||||
this.baseClient = UniversalAiSdkClient.create(providerId, options)
|
||||
this.updateMiddlewares()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -49,6 +58,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
this.pluginManager.use(plugin)
|
||||
this.updateMiddlewares()
|
||||
return this
|
||||
}
|
||||
|
||||
@ -57,6 +67,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
*/
|
||||
usePlugins(plugins: AiPlugin[]): this {
|
||||
plugins.forEach((plugin) => this.pluginManager.use(plugin))
|
||||
this.updateMiddlewares()
|
||||
return this
|
||||
}
|
||||
|
||||
@ -65,9 +76,19 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
*/
|
||||
removePlugin(pluginName: string): this {
|
||||
this.pluginManager.remove(pluginName)
|
||||
this.updateMiddlewares()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
 * Rebuilds the cached AI SDK middleware list from the plugin manager.
 *
 * The cached array is swapped in with a single assignment, so readers of
 * `this.middlewares` always see a complete list matching the plugin set at
 * the time of the last call.
 */
private updateMiddlewares(): void {
  this.middlewares = this.pluginManager.collectAiSdkMiddlewares()
}
|
||||
|
||||
/**
|
||||
* 获取插件统计信息
|
||||
*/
|
||||
@ -164,6 +185,21 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Creates the AI SDK model instance for `modelId` with the plugin-provided
 * native middlewares applied. This is the single place where native
 * (`LanguageModelV1Middleware`) middlewares reach the model.
 *
 * @param modelId - Model identifier forwarded to `ApiClientFactory.createClient`.
 * @returns The model produced by `ApiClientFactory.createClient` for this
 *   client's provider and options.
 */
private async getModelWithMiddlewares(modelId: string) {
  // Pass the plugin middlewares through unchanged. `createClient` only wraps
  // the model when the list is non-empty, so an empty list yields the plain
  // provider model.
  //
  // The previous placeholder injected `simulateStreamingMiddleware()` whenever
  // no plugin supplied middleware. That replaced real streaming with a
  // simulated stream in exactly the no-plugin case, while clients with plugins
  // streamed normally.
  // TODO: apply simulateStreamingMiddleware() only for callers that explicitly
  // request non-streaming output.
  return await ApiClientFactory.createClient(this.providerId, modelId, this.options, this.middlewares)
}
|
||||
|
||||
/**
|
||||
* 流式文本生成 - 集成插件系统
|
||||
*/
|
||||
@ -176,8 +212,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
modelId,
|
||||
params,
|
||||
async (finalModelId, transformedParams, streamTransforms) => {
|
||||
// 对于流式调用,需要直接调用 AI SDK 以支持流转换器
|
||||
const model = await ApiClientFactory.createClient(this.providerId, finalModelId, this.options)
|
||||
const model = await this.getModelWithMiddlewares(finalModelId)
|
||||
return await streamText({
|
||||
model,
|
||||
...transformedParams,
|
||||
@ -189,13 +224,15 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
|
||||
/**
|
||||
* 生成文本 - 集成插件系统
|
||||
* 可能不需要了,因为内置模拟非流中间件
|
||||
*/
|
||||
async generateText(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => {
|
||||
return await this.baseClient.generateText(finalModelId, transformedParams)
|
||||
const model = await this.getModelWithMiddlewares(finalModelId)
|
||||
return await generateText({ model, ...transformedParams })
|
||||
})
|
||||
}
|
||||
|
||||
@ -207,7 +244,8 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => {
|
||||
return await this.baseClient.generateObject(finalModelId, transformedParams)
|
||||
const model = await this.getModelWithMiddlewares(finalModelId)
|
||||
return await generateObject({ model, ...transformedParams })
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
*
|
||||
* ### 1. 官方提供商
|
||||
* ```typescript
|
||||
* import { UniversalAiSdkClient } from '@cherry-studio/ai-core'
|
||||
* import { UniversalAiSdkClient } from '@cherrystudio/ai-core'
|
||||
*
|
||||
* // OpenAI
|
||||
* const openai = UniversalAiSdkClient.create('openai', {
|
||||
|
||||
@ -66,6 +66,8 @@ export type {
|
||||
GenerateTextResult,
|
||||
InvalidToolArgumentsError,
|
||||
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
|
||||
LanguageModelV1Middleware,
|
||||
LanguageModelV1StreamPart,
|
||||
// 错误类型
|
||||
NoSuchToolError,
|
||||
StreamTextResult,
|
||||
@ -115,7 +117,7 @@ export { getAllProviders, getProvider, isProviderSupported, registerProvider } f
|
||||
|
||||
// ==================== 包信息 ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherry-studio/ai-core'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
|
||||
// ==================== 便捷 API ====================
|
||||
// 主要的便捷工厂类
|
||||
|
||||
@ -50,7 +50,7 @@ transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamP
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { PluginManager, createContext, definePlugin } from '@cherry-studio/ai-core/middleware'
|
||||
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const pluginManager = new PluginManager()
|
||||
@ -81,7 +81,7 @@ import {
|
||||
LoggingPlugin,
|
||||
ParamsValidationPlugin,
|
||||
createContext
|
||||
} from '@cherry-studio/ai-core/middleware'
|
||||
} from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const manager = new PluginManager([
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
@ -135,6 +135,13 @@ export class PluginManager {
|
||||
>
|
||||
}
|
||||
|
||||
/**
 * Gathers the native AI SDK middlewares contributed by the registered plugins.
 *
 * Plugins are visited in registration order; a plugin without an
 * `aiSdkMiddlewares` value contributes nothing.
 *
 * @returns A new flat array of middlewares.
 */
collectAiSdkMiddlewares(): LanguageModelV1Middleware[] {
  return this.plugins.reduce<LanguageModelV1Middleware[]>(
    (collected, { aiSdkMiddlewares }) => collected.concat(aiSdkMiddlewares || []),
    []
  )
}
|
||||
|
||||
/**
|
||||
* 获取所有插件信息
|
||||
*/
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import type { LanguageModelV1Middleware } from 'ai'
|
||||
|
||||
/**
|
||||
* AI Provider 注册表
|
||||
* 静态类型 + 动态导入模式:所有类型静态导入,所有实现动态导入
|
||||
@ -73,8 +71,6 @@ export interface ProviderConfig {
|
||||
creatorFunctionName: string
|
||||
// 是否支持图片生成
|
||||
supportsImageGeneration?: boolean
|
||||
// AI SDK 原生中间件
|
||||
aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||
*/
|
||||
|
||||
import { TextStreamPart } from '@cherry-studio/ai-core'
|
||||
import { TextStreamPart } from '@cherrystudio/ai-core'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
export interface CherryStudioChunk {
|
||||
@ -78,38 +78,14 @@ export class AiSdkToChunkAdapter {
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: chunk.textDelta || ''
|
||||
})
|
||||
if (final.reasoning_content) {
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: final.reasoning_content || ''
|
||||
})
|
||||
final.reasoning_content = ''
|
||||
}
|
||||
break
|
||||
|
||||
// === 推理相关事件 ===
|
||||
case 'reasoning':
|
||||
final.reasoning_content += chunk.textDelta || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.textDelta || ''
|
||||
})
|
||||
break
|
||||
|
||||
case 'reasoning-signature':
|
||||
// 推理签名,可以映射到思考完成
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: chunk.signature || ''
|
||||
})
|
||||
// 不再需要处理,中间件会发出 THINKING_COMPLETE
|
||||
break
|
||||
|
||||
case 'redacted-reasoning':
|
||||
// 被编辑的推理内容,也映射到思考
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.data || ''
|
||||
})
|
||||
// 不再需要处理
|
||||
break
|
||||
|
||||
// === 工具调用相关事件 ===
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/**
|
||||
* Cherry Studio AI Core - 新版本入口
|
||||
* 集成 @cherry-studio/ai-core 库的渐进式重构方案
|
||||
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
|
||||
*
|
||||
* 融合方案:简化实现,专注于核心功能
|
||||
* 1. 优先使用新AI SDK
|
||||
@ -13,20 +13,20 @@ import {
|
||||
AiCore,
|
||||
createClient,
|
||||
type OpenAICompatibleProviderSettings,
|
||||
type ProviderId
|
||||
} from '@cherry-studio/ai-core'
|
||||
type ProviderId,
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { RequestOptions } from '@renderer/types/sdk'
|
||||
import { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
// 引入适配器
|
||||
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
|
||||
// 引入原有的AiProvider作为fallback
|
||||
import LegacyAiProvider from './index'
|
||||
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
import thinkingTimeMiddleware from './middleware/aisdk/ThinkingTimeMiddleware'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
// 引入参数转换模块
|
||||
import { buildStreamTextParams } from './transformParameters'
|
||||
|
||||
/**
|
||||
* 将现有 Provider 类型映射到 AI SDK 的 Provider ID
|
||||
@ -108,21 +108,30 @@ export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private provider: Provider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
constructor(provider: Provider, onChunk?: (chunk: Chunk) => void) {
|
||||
this.provider = provider
|
||||
this.legacyProvider = new LegacyAiProvider(provider)
|
||||
|
||||
const config = providerToAiSdkConfig(provider)
|
||||
this.modernClient = createClient(config.providerId, config.options)
|
||||
this.modernClient = createClient(
|
||||
config.providerId,
|
||||
config.options,
|
||||
onChunk ? [{ name: 'thinking-time', aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)] }] : undefined
|
||||
)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
public async completions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
): Promise<CompletionsResult> {
|
||||
// const model = params.assistant.model
|
||||
|
||||
// 检查是否应该使用现代化客户端
|
||||
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
|
||||
// try {
|
||||
return await this.modernCompletions(params, options)
|
||||
console.log('completions', modelId, params, onChunk)
|
||||
return await this.modernCompletions(modelId, params, onChunk)
|
||||
// } catch (error) {
|
||||
// console.warn('Modern client failed, falling back to legacy:', error)
|
||||
// fallback到原有实现
|
||||
@ -137,66 +146,33 @@ export default class ModernAiProvider {
|
||||
* 使用现代化AI SDK的completions实现
|
||||
* 使用 AiSdkUtils 工具模块进行参数构建
|
||||
*/
|
||||
private async modernCompletions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
if (!this.modernClient || !params.assistant.model) {
|
||||
throw new Error('Modern client not available')
|
||||
private async modernCompletions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
): Promise<CompletionsResult> {
|
||||
if (!this.modernClient) {
|
||||
throw new Error('Modern AI SDK client not initialized')
|
||||
}
|
||||
|
||||
console.log('Modern completions with params:', params, 'options:', options)
|
||||
|
||||
const model = params.assistant.model
|
||||
const assistant = params.assistant
|
||||
|
||||
// 检查 messages 类型并转换
|
||||
const messages = Array.isArray(params.messages) ? params.messages : []
|
||||
if (typeof params.messages === 'string') {
|
||||
console.warn('Messages is string, using empty array')
|
||||
}
|
||||
|
||||
// 使用 transformParameters 模块构建参数
|
||||
const aiSdkParams = await buildStreamTextParams(messages, assistant, model, {
|
||||
maxTokens: params.maxTokens,
|
||||
mcpTools: params.mcpTools
|
||||
})
|
||||
|
||||
console.log('Built AI SDK params:', aiSdkParams)
|
||||
const chunks: Chunk[] = []
|
||||
|
||||
try {
|
||||
if (params.streamOutput && params.onChunk) {
|
||||
if (onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
const adapter = new AiSdkToChunkAdapter(params.onChunk)
|
||||
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
|
||||
const adapter = new AiSdkToChunkAdapter(onChunk)
|
||||
const streamResult = await this.modernClient.streamText(modelId, params)
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else if (params.streamOutput) {
|
||||
} else {
|
||||
// 流式处理但没有 onChunk 回调
|
||||
const streamResult = await this.modernClient.streamText(model.id, aiSdkParams)
|
||||
const streamResult = await this.modernClient.streamText(modelId, params)
|
||||
const finalText = await streamResult.text
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
// 非流式处理
|
||||
const result = await this.modernClient.generateText(model.id, aiSdkParams)
|
||||
|
||||
const cherryChunk: Chunk = {
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: result.text || ''
|
||||
}
|
||||
chunks.push(cherryChunk)
|
||||
|
||||
if (params.onChunk) {
|
||||
params.onChunk(cherryChunk)
|
||||
}
|
||||
|
||||
return {
|
||||
getText: () => result.text || ''
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Modern AI SDK error:', error)
|
||||
|
||||
@ -0,0 +1,70 @@
|
||||
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
|
||||
import { Chunk, ChunkType, ThinkingCompleteChunk } from '@renderer/types/chunk'
|
||||
|
||||
/**
 * AI SDK middleware that reports a model's reasoning ("thinking") output and
 * how long the reasoning phase lasted.
 *
 * How it works:
 * 1. The provider stream returned by `doStream` is piped through a
 *    `TransformStream`.
 * 2. The clock starts at the first `reasoning` / `redacted-reasoning` part.
 *    Each such part is reported to `onChunkReceived` as a `THINKING_DELTA`
 *    chunk, and its text is accumulated.
 * 3. The first non-reasoning part after that, or the end of the stream, ends
 *    the phase. A `THINKING_COMPLETE` chunk is emitted with the accumulated
 *    text and the elapsed milliseconds (`thinking_millsec`).
 * 4. Every stream part is forwarded downstream unchanged.
 *
 * The measured time is the duration of the reasoning phase, not the time to
 * first token.
 *
 * @param onChunkReceived - Callback that receives the thinking chunks.
 * @returns A middleware implementing `wrapStream` only.
 */
export default function thinkingTimeMiddleware(onChunkReceived: (chunk: Chunk) => void): LanguageModelV1Middleware {
  return {
    wrapStream: async ({ doStream }) => {
      // `undefined` means no reasoning phase is currently open.
      let thinkingStartTime: number | undefined
      let accumulatedThinkingContent = ''

      // Emits THINKING_COMPLETE for the open reasoning phase (if any) and resets the state.
      const finishThinking = (): void => {
        if (thinkingStartTime === undefined) {
          return
        }
        const thinkingCompleteChunk: ThinkingCompleteChunk = {
          type: ChunkType.THINKING_COMPLETE,
          text: accumulatedThinkingContent,
          thinking_millsec: Date.now() - thinkingStartTime
        }
        onChunkReceived(thinkingCompleteChunk)
        thinkingStartTime = undefined
        accumulatedThinkingContent = ''
      }

      const { stream, ...rest } = await doStream()

      const transformStream = new TransformStream<LanguageModelV1StreamPart, LanguageModelV1StreamPart>({
        transform(chunk, controller) {
          if (chunk.type === 'reasoning' || chunk.type === 'redacted-reasoning') {
            if (thinkingStartTime === undefined) {
              thinkingStartTime = Date.now()
            }
            // Only `reasoning` parts carry `textDelta`. Redacted parts have no
            // readable text, so they keep the phase open with an empty delta.
            const textDelta = chunk.type === 'reasoning' ? chunk.textDelta || '' : ''
            accumulatedThinkingContent += textDelta
            onChunkReceived({
              type: ChunkType.THINKING_DELTA,
              text: textDelta
            })
          } else {
            // Any other part marks the end of the reasoning phase.
            finishThinking()
          }
          // Forward every part unchanged.
          controller.enqueue(chunk)
        },
        flush(controller) {
          // The stream may end while still reasoning; report completion then too.
          finishThinking()
          controller.terminate()
        }
      })

      return {
        stream: stream.pipeThrough(transformStream),
        ...rest
      }
    }
  }
}
|
||||
@ -3,8 +3,19 @@
|
||||
* 统一管理从各个 apiClient 提取的参数处理和转换功能
|
||||
*/
|
||||
|
||||
import type { StreamTextParams } from '@cherry-studio/ai-core'
|
||||
import { isNotSupportTemperatureAndTopP, isSupportedFlexServiceTier } from '@renderer/config/models'
|
||||
import type { CoreMessage, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isReasoningModel,
|
||||
isSupportedDisableGenerationModel,
|
||||
isSupportedFlexServiceTier,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
@ -183,19 +194,38 @@ export async function convertMessagesToSdkMessages(
|
||||
* 这是主要的参数构建函数,整合所有转换逻辑
|
||||
*/
|
||||
export async function buildStreamTextParams(
|
||||
messages: Message[],
|
||||
sdkMessages: StreamTextParams['messages'],
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
options: {
|
||||
maxTokens?: number
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
requestOptions?: {
|
||||
signal?: AbortSignal
|
||||
timeout?: number
|
||||
headers?: Record<string, string>
|
||||
}
|
||||
} = {}
|
||||
): Promise<StreamTextParams> {
|
||||
const { maxTokens, mcpTools, enableTools = false } = options
|
||||
): Promise<{ params: StreamTextParams; modelId: string }> {
|
||||
const { mcpTools, enableTools = false } = options
|
||||
|
||||
// 转换消息
|
||||
const sdkMessages = await convertMessagesToSdkMessages(messages, model)
|
||||
const model = assistant.model || getDefaultModel()
|
||||
|
||||
const { maxTokens, reasoning_effort } = getAssistantSettings(assistant)
|
||||
|
||||
const enableReasoning =
|
||||
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
|
||||
reasoning_effort !== undefined) ||
|
||||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
|
||||
|
||||
const enableWebSearch =
|
||||
(assistant.enableWebSearch && isWebSearchModel(model)) ||
|
||||
isOpenRouterBuiltInWebSearchModel(model) ||
|
||||
model.id.includes('sonar') ||
|
||||
false
|
||||
|
||||
const enableGenerateImage =
|
||||
isGenerateImageModel(model) &&
|
||||
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
||||
|
||||
// 构建系统提示
|
||||
let systemPrompt = assistant.prompt || ''
|
||||
@ -210,6 +240,20 @@ export async function buildStreamTextParams(
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
system: systemPrompt || undefined,
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
// 随便填着,后面再改
|
||||
providerOptions: {
|
||||
reasoning: {
|
||||
enabled: enableReasoning
|
||||
},
|
||||
webSearch: {
|
||||
enabled: enableWebSearch
|
||||
},
|
||||
generateImage: {
|
||||
enabled: enableGenerateImage
|
||||
}
|
||||
},
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
@ -219,24 +263,22 @@ export async function buildStreamTextParams(
|
||||
// params.tools = convertMcpToolsToSdkTools(mcpTools)
|
||||
}
|
||||
|
||||
return params
|
||||
return { params, modelId: model.id }
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建非流式的 generateText 参数
|
||||
*/
|
||||
export async function buildGenerateTextParams(
|
||||
messages: Message[],
|
||||
messages: CoreMessage[],
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
options: {
|
||||
maxTokens?: number
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
} = {}
|
||||
): Promise<any> {
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, model, options)
|
||||
return await buildStreamTextParams(messages, assistant, options)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,379 +1,311 @@
|
||||
/**
|
||||
* 职责:提供原子化的、无状态的API调用函数
|
||||
*/
|
||||
|
||||
import { StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
||||
import {
|
||||
isEmbeddingModel,
|
||||
isGenerateImageModel,
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isReasoningModel,
|
||||
isSupportedDisableGenerationModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
isSupportedThinkingTokenModel
|
||||
} from '@renderer/config/models'
|
||||
import {
|
||||
SEARCH_SUMMARY_PROMPT,
|
||||
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
|
||||
SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
} from '@renderer/config/prompts'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import {
|
||||
Assistant,
|
||||
ExternalToolResult,
|
||||
KnowledgeReference,
|
||||
MCPTool,
|
||||
Model,
|
||||
Provider,
|
||||
WebSearchResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { Assistant, MCPTool, Model, Provider } from '@renderer/types'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
import { isAbortError } from '@renderer/utils/error'
|
||||
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { findLast, isEmpty, takeRight } from 'lodash'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import AiProvider from '../aiCore'
|
||||
import AiProviderNew from '../aiCore/index_new'
|
||||
import {
|
||||
getAssistantProvider,
|
||||
getAssistantSettings,
|
||||
getDefaultModel,
|
||||
getProviderByModel,
|
||||
getTopNamingModel,
|
||||
getTranslateModel
|
||||
} from './AssistantService'
|
||||
import { getDefaultAssistant } from './AssistantService'
|
||||
import { processKnowledgeSearch } from './KnowledgeService'
|
||||
import {
|
||||
filterContextMessages,
|
||||
filterEmptyMessages,
|
||||
filterUsefulMessages,
|
||||
filterUserRoleStartMessages
|
||||
} from './MessagesService'
|
||||
import WebSearchService from './WebSearchService'
|
||||
|
||||
// TODO:考虑拆开
|
||||
async function fetchExternalTool(
|
||||
lastUserMessage: Message,
|
||||
assistant: Assistant,
|
||||
onChunkReceived: (chunk: Chunk) => void,
|
||||
lastAnswer?: Message
|
||||
): Promise<ExternalToolResult> {
|
||||
// 可能会有重复?
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
|
||||
// // TODO:考虑拆开
|
||||
// async function fetchExternalTool(
|
||||
// lastUserMessage: Message,
|
||||
// assistant: Assistant,
|
||||
// onChunkReceived: (chunk: Chunk) => void,
|
||||
// lastAnswer?: Message
|
||||
// ) {
|
||||
// // 可能会有重复?
|
||||
// const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
// const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
// const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
// const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
|
||||
|
||||
// 使用外部搜索工具
|
||||
const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
|
||||
const shouldKnowledgeSearch = hasKnowledgeBase
|
||||
// // 使用外部搜索工具
|
||||
// const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
|
||||
// const shouldKnowledgeSearch = hasKnowledgeBase
|
||||
|
||||
// 在工具链开始时发送进度通知
|
||||
const willUseTools = shouldWebSearch || shouldKnowledgeSearch
|
||||
if (willUseTools) {
|
||||
onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
|
||||
}
|
||||
// // 在工具链开始时发送进度通知
|
||||
// const willUseTools = shouldWebSearch || shouldKnowledgeSearch
|
||||
// if (willUseTools) {
|
||||
// onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
|
||||
// }
|
||||
|
||||
// --- Keyword/Question Extraction Function ---
|
||||
const extract = async (): Promise<ExtractResults | undefined> => {
|
||||
if (!lastUserMessage) return undefined
|
||||
// // --- Keyword/Question Extraction Function ---
|
||||
// const extract = async (): Promise<ExtractResults | undefined> => {
|
||||
// if (!lastUserMessage) return undefined
|
||||
|
||||
// 根据配置决定是否需要提取
|
||||
const needWebExtract = shouldWebSearch
|
||||
const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||
// // 根据配置决定是否需要提取
|
||||
// const needWebExtract = shouldWebSearch
|
||||
// const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||
|
||||
if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||
// if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||
|
||||
let prompt: string
|
||||
if (needWebExtract && !needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
} else if (!needWebExtract && needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||
} else {
|
||||
prompt = SEARCH_SUMMARY_PROMPT
|
||||
}
|
||||
// let prompt: string
|
||||
// if (needWebExtract && !needKnowledgeExtract) {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
// } else if (!needWebExtract && needKnowledgeExtract) {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||
// } else {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT
|
||||
// }
|
||||
|
||||
const summaryAssistant = getDefaultAssistant()
|
||||
summaryAssistant.model = assistant.model || getDefaultModel()
|
||||
summaryAssistant.prompt = prompt
|
||||
// const summaryAssistant = getDefaultAssistant()
|
||||
// summaryAssistant.model = assistant.model || getDefaultModel()
|
||||
// summaryAssistant.prompt = prompt
|
||||
|
||||
// try {
|
||||
// const result = await fetchSearchSummary({
|
||||
// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
|
||||
// assistant: summaryAssistant
|
||||
// })
|
||||
|
||||
// if (!result) return getFallbackResult()
|
||||
|
||||
// const extracted = extractInfoFromXML(result.getText())
|
||||
// // 根据需求过滤结果
|
||||
// return {
|
||||
// websearch: needWebExtract ? extracted?.websearch : undefined,
|
||||
// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
|
||||
// }
|
||||
// } catch (e: any) {
|
||||
// console.error('extract error', e)
|
||||
// if (isAbortError(e)) throw e
|
||||
// return getFallbackResult()
|
||||
// }
|
||||
// }
|
||||
|
||||
// const getFallbackResult = (): ExtractResults => {
|
||||
// const fallbackContent = getMainTextContent(lastUserMessage)
|
||||
// return {
|
||||
// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||
// knowledge: shouldKnowledgeSearch
|
||||
// ? {
|
||||
// question: [fallbackContent || 'search'],
|
||||
// rewrite: fallbackContent
|
||||
// }
|
||||
// : undefined
|
||||
// }
|
||||
// }
|
||||
|
||||
// // --- Web Search Function ---
|
||||
// const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise<WebSearchResponse | undefined> => {
|
||||
// if (!shouldWebSearch) return
|
||||
|
||||
// // Add check for extractResults existence early
|
||||
// if (!extractResults?.websearch) {
|
||||
// console.warn('searchTheWeb called without valid extractResults.websearch')
|
||||
// return
|
||||
// }
|
||||
|
||||
// if (extractResults.websearch.question[0] === 'not_needed') return
|
||||
|
||||
// // Add check for assistant.model before using it
|
||||
// if (!assistant.model) {
|
||||
// console.warn('searchTheWeb called without assistant.model')
|
||||
// return undefined
|
||||
// }
|
||||
|
||||
// try {
|
||||
// // Use the consolidated processWebsearch function
|
||||
// WebSearchService.createAbortSignal(lastUserMessage.id)
|
||||
// return {
|
||||
// results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults),
|
||||
// source: WebSearchSource.WEBSEARCH
|
||||
// }
|
||||
// } catch (error) {
|
||||
// if (isAbortError(error)) throw error
|
||||
// console.error('Web search failed:', error)
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
|
||||
// // --- Knowledge Base Search Function ---
|
||||
// const searchKnowledgeBase = async (
|
||||
// extractResults: ExtractResults | undefined
|
||||
// ): Promise<KnowledgeReference[] | undefined> => {
|
||||
// if (!hasKnowledgeBase) return
|
||||
|
||||
// // 知识库搜索条件
|
||||
// let searchCriteria: { question: string[]; rewrite: string }
|
||||
// if (knowledgeRecognition === 'off') {
|
||||
// const directContent = getMainTextContent(lastUserMessage)
|
||||
// searchCriteria = { question: [directContent || 'search'], rewrite: directContent }
|
||||
// } else {
|
||||
// // auto mode
|
||||
// if (!extractResults?.knowledge) {
|
||||
// console.warn('searchKnowledgeBase: No valid search criteria in auto mode')
|
||||
// return
|
||||
// }
|
||||
// searchCriteria = extractResults.knowledge
|
||||
// }
|
||||
|
||||
// if (searchCriteria.question[0] === 'not_needed') return
|
||||
|
||||
// try {
|
||||
// const tempExtractResults: ExtractResults = {
|
||||
// websearch: undefined,
|
||||
// knowledge: searchCriteria
|
||||
// }
|
||||
// // Attempt to get knowledgeBaseIds from the main text block
|
||||
// // NOTE: This assumes knowledgeBaseIds are ONLY on the main text block
|
||||
// // NOTE: processKnowledgeSearch needs to handle undefined ids gracefully
|
||||
// // const mainTextBlock = mainTextBlocks
|
||||
// // ?.map((blockId) => store.getState().messageBlocks.entities[blockId])
|
||||
// // .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined
|
||||
// return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds)
|
||||
// } catch (error) {
|
||||
// console.error('Knowledge base search failed:', error)
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
|
||||
// // --- Execute Extraction and Searches ---
|
||||
// let extractResults: ExtractResults | undefined
|
||||
|
||||
// try {
|
||||
// // 根据配置决定是否需要提取
|
||||
// if (shouldWebSearch || hasKnowledgeBase) {
|
||||
// extractResults = await extract()
|
||||
// Logger.log('[fetchExternalTool] Extraction results:', extractResults)
|
||||
// }
|
||||
|
||||
// let webSearchResponseFromSearch: WebSearchResponse | undefined
|
||||
// let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined
|
||||
|
||||
// // 并行执行搜索
|
||||
// if (shouldWebSearch || shouldKnowledgeSearch) {
|
||||
// ;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([
|
||||
// searchTheWeb(extractResults),
|
||||
// searchKnowledgeBase(extractResults)
|
||||
// ])
|
||||
// }
|
||||
|
||||
// // 存储搜索结果
|
||||
// if (lastUserMessage) {
|
||||
// if (webSearchResponseFromSearch) {
|
||||
// window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch)
|
||||
// }
|
||||
// if (knowledgeReferencesFromSearch) {
|
||||
// window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch)
|
||||
// }
|
||||
// }
|
||||
|
||||
// // 发送工具执行完成通知
|
||||
// if (willUseTools) {
|
||||
// onChunkReceived({
|
||||
// type: ChunkType.EXTERNEL_TOOL_COMPLETE,
|
||||
// external_tool: {
|
||||
// webSearch: webSearchResponseFromSearch,
|
||||
// knowledge: knowledgeReferencesFromSearch
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// } catch (error) {
|
||||
// if (isAbortError(error)) throw error
|
||||
// console.error('Tool execution failed:', error)
|
||||
|
||||
// // 发送错误状态
|
||||
// if (willUseTools) {
|
||||
// onChunkReceived({
|
||||
// type: ChunkType.EXTERNEL_TOOL_COMPLETE,
|
||||
// external_tool: {
|
||||
// webSearch: undefined,
|
||||
// knowledge: undefined
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// return { mcpTools: [] }
|
||||
// }
|
||||
// }
|
||||
|
||||
export async function fetchMcpTools(assistant: Assistant) {
|
||||
// Get MCP tools (Fix duplicate declaration)
|
||||
let mcpTools: MCPTool[] = [] // Initialize as empty array
|
||||
const allMcpServers = store.getState().mcp.servers || []
|
||||
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
|
||||
const assistantMcpServers = assistant.mcpServers || []
|
||||
|
||||
const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
|
||||
|
||||
if (enabledMCPs && enabledMCPs.length > 0) {
|
||||
try {
|
||||
const result = await fetchSearchSummary({
|
||||
messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
|
||||
assistant: summaryAssistant
|
||||
})
|
||||
|
||||
if (!result) return getFallbackResult()
|
||||
|
||||
const extracted = extractInfoFromXML(result.getText())
|
||||
// 根据需求过滤结果
|
||||
return {
|
||||
websearch: needWebExtract ? extracted?.websearch : undefined,
|
||||
knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
|
||||
}
|
||||
} catch (e: any) {
|
||||
console.error('extract error', e)
|
||||
if (isAbortError(e)) throw e
|
||||
return getFallbackResult()
|
||||
}
|
||||
}
|
||||
|
||||
const getFallbackResult = (): ExtractResults => {
|
||||
const fallbackContent = getMainTextContent(lastUserMessage)
|
||||
return {
|
||||
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||
knowledge: shouldKnowledgeSearch
|
||||
? {
|
||||
question: [fallbackContent || 'search'],
|
||||
rewrite: fallbackContent
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
|
||||
// --- Web Search Function ---
|
||||
const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise<WebSearchResponse | undefined> => {
|
||||
if (!shouldWebSearch) return
|
||||
|
||||
// Add check for extractResults existence early
|
||||
if (!extractResults?.websearch) {
|
||||
console.warn('searchTheWeb called without valid extractResults.websearch')
|
||||
return
|
||||
}
|
||||
|
||||
if (extractResults.websearch.question[0] === 'not_needed') return
|
||||
|
||||
// Add check for assistant.model before using it
|
||||
if (!assistant.model) {
|
||||
console.warn('searchTheWeb called without assistant.model')
|
||||
return undefined
|
||||
}
|
||||
|
||||
try {
|
||||
// Use the consolidated processWebsearch function
|
||||
WebSearchService.createAbortSignal(lastUserMessage.id)
|
||||
return {
|
||||
results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults),
|
||||
source: WebSearchSource.WEBSEARCH
|
||||
}
|
||||
} catch (error) {
|
||||
if (isAbortError(error)) throw error
|
||||
console.error('Web search failed:', error)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// --- Knowledge Base Search Function ---
|
||||
const searchKnowledgeBase = async (
|
||||
extractResults: ExtractResults | undefined
|
||||
): Promise<KnowledgeReference[] | undefined> => {
|
||||
if (!hasKnowledgeBase) return
|
||||
|
||||
// 知识库搜索条件
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
if (knowledgeRecognition === 'off') {
|
||||
const directContent = getMainTextContent(lastUserMessage)
|
||||
searchCriteria = { question: [directContent || 'search'], rewrite: directContent }
|
||||
} else {
|
||||
// auto mode
|
||||
if (!extractResults?.knowledge) {
|
||||
console.warn('searchKnowledgeBase: No valid search criteria in auto mode')
|
||||
return
|
||||
}
|
||||
searchCriteria = extractResults.knowledge
|
||||
}
|
||||
|
||||
if (searchCriteria.question[0] === 'not_needed') return
|
||||
|
||||
try {
|
||||
const tempExtractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
// Attempt to get knowledgeBaseIds from the main text block
|
||||
// NOTE: This assumes knowledgeBaseIds are ONLY on the main text block
|
||||
// NOTE: processKnowledgeSearch needs to handle undefined ids gracefully
|
||||
// const mainTextBlock = mainTextBlocks
|
||||
// ?.map((blockId) => store.getState().messageBlocks.entities[blockId])
|
||||
// .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined
|
||||
return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds)
|
||||
} catch (error) {
|
||||
console.error('Knowledge base search failed:', error)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// --- Execute Extraction and Searches ---
|
||||
let extractResults: ExtractResults | undefined
|
||||
|
||||
try {
|
||||
// 根据配置决定是否需要提取
|
||||
if (shouldWebSearch || hasKnowledgeBase) {
|
||||
extractResults = await extract()
|
||||
Logger.log('[fetchExternalTool] Extraction results:', extractResults)
|
||||
}
|
||||
|
||||
let webSearchResponseFromSearch: WebSearchResponse | undefined
|
||||
let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined
|
||||
|
||||
// 并行执行搜索
|
||||
if (shouldWebSearch || shouldKnowledgeSearch) {
|
||||
;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([
|
||||
searchTheWeb(extractResults),
|
||||
searchKnowledgeBase(extractResults)
|
||||
])
|
||||
}
|
||||
|
||||
// 存储搜索结果
|
||||
if (lastUserMessage) {
|
||||
if (webSearchResponseFromSearch) {
|
||||
window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch)
|
||||
}
|
||||
if (knowledgeReferencesFromSearch) {
|
||||
window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch)
|
||||
}
|
||||
}
|
||||
|
||||
// 发送工具执行完成通知
|
||||
if (willUseTools) {
|
||||
onChunkReceived({
|
||||
type: ChunkType.EXTERNEL_TOOL_COMPLETE,
|
||||
external_tool: {
|
||||
webSearch: webSearchResponseFromSearch,
|
||||
knowledge: knowledgeReferencesFromSearch
|
||||
const toolPromises = enabledMCPs.map<Promise<MCPTool[]>>(async (mcpServer) => {
|
||||
try {
|
||||
const tools = await window.api.mcp.listTools(mcpServer)
|
||||
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
|
||||
} catch (error) {
|
||||
console.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error)
|
||||
return []
|
||||
}
|
||||
})
|
||||
const results = await Promise.allSettled(toolPromises)
|
||||
mcpTools = results
|
||||
.filter((result): result is PromiseFulfilledResult<MCPTool[]> => result.status === 'fulfilled')
|
||||
.map((result) => result.value)
|
||||
.flat()
|
||||
} catch (toolError) {
|
||||
console.error('Error fetching MCP tools:', toolError)
|
||||
}
|
||||
|
||||
// Get MCP tools (Fix duplicate declaration)
|
||||
let mcpTools: MCPTool[] = [] // Initialize as empty array
|
||||
const allMcpServers = store.getState().mcp.servers || []
|
||||
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
|
||||
const assistantMcpServers = assistant.mcpServers || []
|
||||
|
||||
const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
|
||||
|
||||
if (enabledMCPs && enabledMCPs.length > 0) {
|
||||
try {
|
||||
const toolPromises = enabledMCPs.map<Promise<MCPTool[]>>(async (mcpServer) => {
|
||||
try {
|
||||
const tools = await window.api.mcp.listTools(mcpServer)
|
||||
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
|
||||
} catch (error) {
|
||||
console.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error)
|
||||
return []
|
||||
}
|
||||
})
|
||||
const results = await Promise.allSettled(toolPromises)
|
||||
mcpTools = results
|
||||
.filter((result): result is PromiseFulfilledResult<MCPTool[]> => result.status === 'fulfilled')
|
||||
.map((result) => result.value)
|
||||
.flat()
|
||||
} catch (toolError) {
|
||||
console.error('Error fetching MCP tools:', toolError)
|
||||
}
|
||||
}
|
||||
|
||||
return { mcpTools }
|
||||
} catch (error) {
|
||||
if (isAbortError(error)) throw error
|
||||
console.error('Tool execution failed:', error)
|
||||
|
||||
// 发送错误状态
|
||||
if (willUseTools) {
|
||||
onChunkReceived({
|
||||
type: ChunkType.EXTERNEL_TOOL_COMPLETE,
|
||||
external_tool: {
|
||||
webSearch: undefined,
|
||||
knowledge: undefined
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return { mcpTools: [] }
|
||||
}
|
||||
|
||||
return mcpTools
|
||||
}
|
||||
|
||||
export async function fetchChatCompletion({
|
||||
messages,
|
||||
assistant,
|
||||
options,
|
||||
onChunkReceived
|
||||
}: {
|
||||
messages: Message[]
|
||||
messages: StreamTextParams['messages']
|
||||
assistant: Assistant
|
||||
onChunkReceived: (chunk: Chunk) => void
|
||||
// TODO
|
||||
// onChunkStatus: (status: 'searching' | 'processing' | 'success' | 'error') => void
|
||||
}) {
|
||||
console.log('fetchChatCompletion', messages, assistant)
|
||||
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const AI = new AiProviderNew(provider)
|
||||
|
||||
// Make sure that 'Clear Context' works for all scenarios including external tool and normal chat.
|
||||
messages = filterContextMessages(messages)
|
||||
|
||||
const lastUserMessage = findLast(messages, (m) => m.role === 'user')
|
||||
const lastAnswer = findLast(messages, (m) => m.role === 'assistant')
|
||||
if (!lastUserMessage) {
|
||||
console.error('fetchChatCompletion returning early: Missing lastUserMessage or lastAnswer')
|
||||
return
|
||||
options: {
|
||||
signal?: AbortSignal
|
||||
timeout?: number
|
||||
headers?: Record<string, string>
|
||||
}
|
||||
// try {
|
||||
// NOTE: The search results are NOT added to the messages sent to the AI here.
|
||||
// They will be retrieved and used by the messageThunk later to create CitationBlocks.
|
||||
const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer)
|
||||
const model = assistant.model || getDefaultModel()
|
||||
onChunkReceived: (chunk: Chunk) => void
|
||||
}) {
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const AI = new AiProviderNew(provider, onChunkReceived)
|
||||
|
||||
const { maxTokens, contextCount } = getAssistantSettings(assistant)
|
||||
const mcpTools = await fetchMcpTools(assistant)
|
||||
|
||||
const filteredMessages = filterUsefulMessages(messages)
|
||||
|
||||
const _messages = filterUserRoleStartMessages(
|
||||
filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // 取原来几个provider的最大值
|
||||
)
|
||||
|
||||
const enableReasoning =
|
||||
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
|
||||
assistant.settings?.reasoning_effort !== undefined) ||
|
||||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
|
||||
|
||||
const enableWebSearch =
|
||||
(assistant.enableWebSearch && isWebSearchModel(model)) ||
|
||||
isOpenRouterBuiltInWebSearchModel(model) ||
|
||||
model.id.includes('sonar') ||
|
||||
false
|
||||
|
||||
const enableGenerateImage =
|
||||
isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true)
|
||||
// 使用 transformParameters 模块构建参数
|
||||
const { params: aiSdkParams, modelId } = await buildStreamTextParams(messages, assistant, {
|
||||
mcpTools: mcpTools,
|
||||
requestOptions: options
|
||||
})
|
||||
|
||||
// --- Call AI Completions ---
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
if (enableWebSearch) {
|
||||
onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
|
||||
}
|
||||
await AI.completions(
|
||||
{
|
||||
callType: 'chat',
|
||||
messages: _messages,
|
||||
assistant,
|
||||
onChunk: onChunkReceived,
|
||||
mcpTools: mcpTools,
|
||||
maxTokens,
|
||||
streamOutput: assistant.settings?.streamOutput || false,
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
},
|
||||
{
|
||||
streamOutput: assistant.settings?.streamOutput || false
|
||||
}
|
||||
)
|
||||
await AI.completions(modelId, aiSdkParams, onChunkReceived)
|
||||
}
|
||||
|
||||
interface FetchTranslateProps {
|
||||
|
||||
34
src/renderer/src/services/ConversationService.ts
Normal file
34
src/renderer/src/services/ConversationService.ts
Normal file
@ -0,0 +1,34 @@
|
||||
import { StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { convertMessagesToSdkMessages } from '@renderer/aiCore/transformParameters'
|
||||
import { Assistant, Message } from '@renderer/types'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import { getAssistantSettings, getDefaultModel } from './AssistantService'
|
||||
import {
|
||||
filterContextMessages,
|
||||
filterEmptyMessages,
|
||||
filterUsefulMessages,
|
||||
filterUserRoleStartMessages
|
||||
} from './MessagesService'
|
||||
|
||||
/**
 * Stateless helpers that turn stored conversation state into what the LLM
 * request layer needs, and answer simple "is this feature configured?" questions
 * about an assistant.
 */
export class ConversationService {
  /**
   * Builds the message list that is sent to the model.
   *
   * Per the original author's note, this pipeline was extracted from
   * `ApiService.fetchChatCompletion`.
   *
   * @param messages Full message history for the conversation.
   * @param assistant Assistant whose settings (context size, model) drive the filtering.
   * @returns The messages converted by `convertMessagesToSdkMessages`.
   */
  static async prepareMessagesForLlm(messages: Message[], assistant: Assistant): Promise<StreamTextParams['messages']> {
    const { contextCount: contextWindow } = getAssistantSettings(assistant)

    // Apply the context filter first, then the "useful messages" filter.
    const usefulMessages = filterUsefulMessages(filterContextMessages(messages))

    // Keep the trailing `contextCount + 2` entries; the author's note says the
    // extra two leave room for a final user/assistant exchange.
    const recentMessages = takeRight(usefulMessages, contextWindow + 2)
    const nonEmptyMessages = filterEmptyMessages(recentMessages)
    const llmReadyMessages = filterUserRoleStartMessages(nonEmptyMessages)

    // Fall back to the default model when the assistant has none configured.
    const targetModel = assistant.model || getDefaultModel()
    return await convertMessagesToSdkMessages(llmReadyMessages, targetModel)
  }

  /**
   * Reports whether the assistant has a web search provider id set.
   *
   * @param assistant Assistant to inspect.
   * @returns `true` when `assistant.webSearchProviderId` is truthy.
   */
  static needsWebSearch(assistant: Assistant): boolean {
    return Boolean(assistant.webSearchProviderId)
  }

  /**
   * Reports whether the assistant has at least one knowledge base attached.
   *
   * @param assistant Assistant to inspect.
   * @returns `true` when `assistant.knowledge_bases` is non-empty.
   */
  static needsKnowledgeSearch(assistant: Assistant): boolean {
    const hasNoKnowledgeBases = isEmpty(assistant.knowledge_bases)
    return !hasNoKnowledgeBases
  }
}
|
||||
54
src/renderer/src/services/OrchestrateService.ts
Normal file
54
src/renderer/src/services/OrchestrateService.ts
Normal file
@ -0,0 +1,54 @@
|
||||
import { Assistant, Message } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { fetchChatCompletion } from './ApiService'
|
||||
import { ConversationService } from './ConversationService'
|
||||
|
||||
/**
 * Everything needed to handle one user message: the conversation history,
 * the assistant that should answer, and per-request transport options.
 */
export interface OrchestrationRequest {
  /** Conversation messages; the orchestrator filters and converts them before the LLM call. */
  messages: Message[]
  /** Assistant that produces the reply. */
  assistant: Assistant
  /**
   * Request-level options, forwarded unchanged to `fetchChatCompletion`.
   * NOTE(review): `timeout` is presumably in milliseconds — confirm against the consumer.
   */
  options: { signal?: AbortSignal; timeout?: number; headers?: Record<string, string> }
}
|
||||
|
||||
/**
 * Coordinates the services involved in answering a user's message:
 * message preparation (`ConversationService`) followed by the LLM call
 * (`fetchChatCompletion`).
 *
 * The class holds no state; per the original author's note it may become a
 * singleton later, but a fresh instance per call is fine for now.
 */
export class OrchestrationService {
  /**
   * Handles one user message end to end.
   *
   * Prepares the conversation for the model, then starts the chat completion.
   * Any error thrown by either step is not rethrown; it is reported to the
   * caller as a `ChunkType.ERROR` chunk through `onChunkReceived`.
   * Per the original author's note, this logic was moved from `messageThunk.ts`.
   *
   * @param request Messages, assistant and request options for this turn.
   * @param onChunkReceived Callback that receives streamed chunks, including the error chunk.
   */
  async handleUserMessage(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) {
    const history = request.messages
    const assistant = request.assistant

    try {
      const preparedMessages = await ConversationService.prepareMessagesForLlm(history, assistant)

      await fetchChatCompletion({
        messages: preparedMessages,
        assistant,
        options: request.options,
        onChunkReceived
      })
    } catch (error: any) {
      // Surface the failure through the chunk stream instead of rejecting.
      onChunkReceived({ type: ChunkType.ERROR, error })
    }
  }
}
|
||||
@ -1,9 +1,9 @@
|
||||
import db from '@renderer/databases'
|
||||
import { autoRenameTopic } from '@renderer/hooks/useTopic'
|
||||
import { fetchChatCompletion } from '@renderer/services/ApiService'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { NotificationService } from '@renderer/services/NotificationService'
|
||||
import { OrchestrationService } from '@renderer/services/OrchestrateService'
|
||||
import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
|
||||
import { estimateMessagesUsage } from '@renderer/services/TokenService'
|
||||
import store from '@renderer/store'
|
||||
@ -829,15 +829,26 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
const streamProcessorCallbacks = createStreamProcessor(callbacks)
|
||||
|
||||
const startTime = Date.now()
|
||||
await fetchChatCompletion({
|
||||
messages: messagesForContext,
|
||||
assistant: assistant,
|
||||
onChunkReceived: streamProcessorCallbacks
|
||||
})
|
||||
const orchestrationService = new OrchestrationService()
|
||||
await orchestrationService.handleUserMessage(
|
||||
{
|
||||
messages: messagesForContext,
|
||||
assistant,
|
||||
options: {
|
||||
timeout: 30000
|
||||
}
|
||||
},
|
||||
streamProcessorCallbacks
|
||||
)
|
||||
} catch (error: any) {
|
||||
console.error('Error fetching chat completion:', error)
|
||||
if (assistantMessage) {
|
||||
callbacks.onError?.(error)
|
||||
console.error('Error in fetchAndProcessAssistantResponseImpl:', error)
|
||||
// The main error handling is now delegated to OrchestrationService,
|
||||
// which calls the `onError` callback. This catch block is for
|
||||
// any errors that might occur outside of that orchestration flow.
|
||||
if (assistantMessage && callbacks.onError) {
|
||||
callbacks.onError(error)
|
||||
} else {
|
||||
// Fallback if callbacks are not even defined yet
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
"@renderer/*": ["src/renderer/src/*"],
|
||||
"@shared/*": ["packages/shared/*"],
|
||||
"@types": ["src/renderer/src/types/index.ts"],
|
||||
"@cherry-studio/ai-core": ["packages/aiCore/src/"]
|
||||
"@cherrystudio/ai-core": ["packages/aiCore/src/"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -960,9 +960,9 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@cherry-studio/ai-core@workspace:*, @cherry-studio/ai-core@workspace:packages/aiCore":
|
||||
"@cherrystudio/ai-core@workspace:*, @cherrystudio/ai-core@workspace:packages/aiCore":
|
||||
version: 0.0.0-use.local
|
||||
resolution: "@cherry-studio/ai-core@workspace:packages/aiCore"
|
||||
resolution: "@cherrystudio/ai-core@workspace:packages/aiCore"
|
||||
dependencies:
|
||||
"@ai-sdk/amazon-bedrock": "npm:^2.2.10"
|
||||
"@ai-sdk/anthropic": "npm:^1.2.12"
|
||||
@ -6392,7 +6392,7 @@ __metadata:
|
||||
"@agentic/tavily": "npm:^7.3.3"
|
||||
"@ant-design/v5-patch-for-react-19": "npm:^1.0.3"
|
||||
"@anthropic-ai/sdk": "npm:^0.41.0"
|
||||
"@cherry-studio/ai-core": "workspace:*"
|
||||
"@cherrystudio/ai-core": "workspace:*"
|
||||
"@cherrystudio/embedjs": "npm:^0.1.31"
|
||||
"@cherrystudio/embedjs-libsql": "npm:^0.1.31"
|
||||
"@cherrystudio/embedjs-loader-csv": "npm:^0.1.31"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user