mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: integrate web search tool and enhance tool handling
- Added `webSearchTool` to facilitate web search functionality within the SDK. - Updated `AiSdkToChunkAdapter` to utilize `BaseTool` for improved type handling. - Refactored `transformParameters` to support `webSearchProviderId` for enhanced web search integration. - Introduced new `BaseTool` type structure to unify tool definitions across the codebase. - Adjusted imports and type definitions to align with the new tool handling logic.
This commit is contained in:
parent
0c4e8228af
commit
da455997ad
@ -21,6 +21,7 @@ export { createModel, type ModelConfig } from './core/models'
|
||||
// ==================== 插件系统 ====================
|
||||
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins'
|
||||
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== 低级 API ====================
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
*/
|
||||
|
||||
import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
|
||||
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { BaseTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
|
||||
@ -27,7 +27,7 @@ export class AiSdkToChunkAdapter {
|
||||
toolCallHandler: ToolCallChunkHandler
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[] = []
|
||||
private mcpTools: BaseTool[] = []
|
||||
) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||
}
|
||||
|
||||
@ -4,22 +4,15 @@
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core/index'
|
||||
import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { BaseTool, MCPToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
// import type {
|
||||
// AnthropicSearchOutput,
|
||||
// WebSearchPluginConfig
|
||||
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||
|
||||
// 为 Provider 执行的工具创建一个通用类型
|
||||
// 这避免了污染 MCPTool 的定义,同时提供了 UI 显示所需的基本信息
|
||||
type GenericProviderTool = {
|
||||
name: string
|
||||
description: string
|
||||
type: 'provider'
|
||||
}
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
@ -32,12 +25,12 @@ export class ToolCallChunkHandler {
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
mcpTool: MCPTool | GenericProviderTool
|
||||
mcpTool: BaseTool
|
||||
}
|
||||
>()
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[]
|
||||
private mcpTools: BaseTool[]
|
||||
) {}
|
||||
|
||||
// /**
|
||||
@ -62,13 +55,14 @@ export class ToolCallChunkHandler {
|
||||
return
|
||||
}
|
||||
|
||||
let tool: MCPTool | GenericProviderTool
|
||||
let tool: BaseTool
|
||||
|
||||
// 根据 providerExecuted 标志区分处理逻辑
|
||||
if (providerExecuted) {
|
||||
// 如果是 Provider 执行的工具(如 web_search)
|
||||
Logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
|
||||
@ -17,7 +17,7 @@ import {
|
||||
type ProviderSettingsMap,
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
|
||||
@ -3,7 +3,7 @@ import {
|
||||
LanguageModelV2Middleware,
|
||||
simulateStreamingMiddleware
|
||||
} from '@cherrystudio/ai-core'
|
||||
import type { MCPTool, Model, Provider } from '@renderer/types'
|
||||
import type { BaseTool, Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
@ -18,7 +18,7 @@ export interface AiSdkMiddlewareConfig {
|
||||
// 是否开启提示词工具调用
|
||||
enableTool?: boolean
|
||||
enableWebSearch?: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
mcpTools?: BaseTool[]
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
34
src/renderer/src/aiCore/tools/WebSearchTool.ts
Normal file
34
src/renderer/src/aiCore/tools/WebSearchTool.ts
Normal file
@ -0,0 +1,34 @@
|
||||
import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import { WebSearchProvider } from '@renderer/types'
|
||||
import aiSdk from 'ai'
|
||||
|
||||
import { AiSdkTool, ToolCallResult } from './types'
|
||||
|
||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requestId: string): AiSdkTool => {
|
||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
return {
|
||||
name: 'web_search',
|
||||
description: 'Search the web for information',
|
||||
inputSchema: aiSdk.jsonSchema({
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string', description: 'The query to search for' }
|
||||
},
|
||||
required: ['query']
|
||||
}),
|
||||
execute: async ({ query }): Promise<ToolCallResult> => {
|
||||
try {
|
||||
const response = await webSearchService.processWebsearch(query, requestId)
|
||||
return {
|
||||
success: true,
|
||||
data: response
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
data: error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
8
src/renderer/src/aiCore/tools/types.ts
Normal file
8
src/renderer/src/aiCore/tools/types.ts
Normal file
@ -0,0 +1,8 @@
|
||||
import { Tool } from '@cherrystudio/ai-core'
|
||||
|
||||
export type ToolCallResult = {
|
||||
success: boolean
|
||||
data: any
|
||||
}
|
||||
|
||||
export type AiSdkTool = Tool<any, ToolCallResult>
|
||||
@ -39,6 +39,7 @@ import {
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
|
||||
import { webSearchTool } from './tools/WebSearchTool'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
import { buildProviderOptions } from './utils/options'
|
||||
@ -246,6 +247,7 @@ export async function buildStreamTextParams(
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
enableWebSearch?: boolean
|
||||
webSearchProviderId?: string
|
||||
requestOptions?: {
|
||||
signal?: AbortSignal
|
||||
timeout?: number
|
||||
@ -257,7 +259,7 @@ export async function buildStreamTextParams(
|
||||
modelId: string
|
||||
capabilities: { enableReasoning?: boolean; enableWebSearch?: boolean; enableGenerateImage?: boolean }
|
||||
}> {
|
||||
const { mcpTools, enableTools } = options
|
||||
const { mcpTools, enableTools, webSearchProviderId } = options
|
||||
|
||||
const model = assistant.model || getDefaultModel()
|
||||
|
||||
@ -286,6 +288,13 @@ export async function buildStreamTextParams(
|
||||
enableToolUse: enableTools
|
||||
})
|
||||
|
||||
if (webSearchProviderId) {
|
||||
// 生成requestId用于网络搜索工具
|
||||
const requestId = `request_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`
|
||||
|
||||
tools['builtin_web_search'] = webSearchTool(webSearchProviderId, requestId)
|
||||
}
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, provider, {
|
||||
enableReasoning,
|
||||
|
||||
@ -1,17 +1,11 @@
|
||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||
import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
|
||||
import { isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { JSONSchema7 } from 'json-schema'
|
||||
|
||||
type ToolCallResult = {
|
||||
success: boolean
|
||||
data: MCPCallToolResponse
|
||||
}
|
||||
|
||||
type AiSdkTool = Tool<any, ToolCallResult>
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: Record<string, AiSdkTool>
|
||||
|
||||
@ -36,7 +36,21 @@ interface RequestState {
|
||||
/**
|
||||
* 提供网络搜索相关功能的服务类
|
||||
*/
|
||||
class WebSearchService {
|
||||
export default class WebSearchService {
|
||||
private static instance: WebSearchService
|
||||
private webSearchProviderId: WebSearchProvider['id']
|
||||
|
||||
private constructor(webSearchProviderId: WebSearchProvider['id']) {
|
||||
this.webSearchProviderId = webSearchProviderId
|
||||
}
|
||||
|
||||
public static getInstance(webSearchProviderId: WebSearchProvider['id']): WebSearchService {
|
||||
if (!WebSearchService.instance) {
|
||||
WebSearchService.instance = new WebSearchService(webSearchProviderId)
|
||||
}
|
||||
return WebSearchService.instance
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否暂停
|
||||
*/
|
||||
@ -154,17 +168,16 @@ class WebSearchService {
|
||||
/**
|
||||
* 使用指定的提供商执行网络搜索
|
||||
* @public
|
||||
* @param provider 搜索提供商
|
||||
* @param query 搜索查询
|
||||
* @returns 搜索响应
|
||||
*/
|
||||
public async search(
|
||||
provider: WebSearchProvider,
|
||||
query: string,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
public async search(query: string, httpOptions?: RequestInit): Promise<WebSearchProviderResponse> {
|
||||
const websearch = this.getWebSearchState()
|
||||
const webSearchEngine = new WebSearchEngineProvider(provider)
|
||||
const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId)
|
||||
if (!webSearchProvider) {
|
||||
throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`)
|
||||
}
|
||||
const webSearchEngine = new WebSearchEngineProvider(webSearchProvider)
|
||||
|
||||
let formattedQuery = query
|
||||
// FIXME: 有待商榷,效果一般
|
||||
@ -186,9 +199,9 @@ class WebSearchService {
|
||||
* @param provider 要检查的搜索提供商
|
||||
* @returns 如果提供商可用返回true,否则返回false
|
||||
*/
|
||||
public async checkSearch(provider: WebSearchProvider): Promise<{ valid: boolean; error?: any }> {
|
||||
public async checkSearch(): Promise<{ valid: boolean; error?: any }> {
|
||||
try {
|
||||
const response = await this.search(provider, 'test query')
|
||||
const response = await this.search('test query')
|
||||
Logger.log('[checkSearch] Search response:', response)
|
||||
// 优化的判断条件:检查结果是否有效且没有错误
|
||||
return { valid: response.results !== undefined, error: undefined }
|
||||
@ -423,11 +436,7 @@ class WebSearchService {
|
||||
*
|
||||
* @returns 包含搜索结果的响应对象
|
||||
*/
|
||||
public async processWebsearch(
|
||||
webSearchProvider: WebSearchProvider,
|
||||
extractResults: ExtractResults,
|
||||
requestId: string
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
public async processWebsearch(extractResults: ExtractResults, requestId: string): Promise<WebSearchProviderResponse> {
|
||||
// 重置状态
|
||||
await this.setWebSearchStatus(requestId, { phase: 'default' })
|
||||
|
||||
@ -449,7 +458,7 @@ class WebSearchService {
|
||||
return { query: 'summaries', results: contents }
|
||||
}
|
||||
|
||||
const searchPromises = questions.map((q) => this.search(webSearchProvider, q, { signal }))
|
||||
const searchPromises = questions.map((q) => this.search(q, { signal }))
|
||||
const searchResults = await Promise.allSettled(searchPromises)
|
||||
|
||||
// 统计成功完成的搜索数量
|
||||
@ -532,5 +541,3 @@ class WebSearchService {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default new WebSearchService()
|
||||
|
||||
@ -4,12 +4,7 @@ import type OpenAI from 'openai'
|
||||
import type { CSSProperties } from 'react'
|
||||
|
||||
import type { Message } from './newMessage'
|
||||
|
||||
export type GenericProviderTool = {
|
||||
name: string
|
||||
description: string
|
||||
type: 'provider'
|
||||
}
|
||||
import type { BaseTool, MCPTool } from './tool'
|
||||
|
||||
export type Assistant = {
|
||||
id: string
|
||||
@ -616,15 +611,6 @@ export interface MCPToolInputSchema {
|
||||
properties: Record<string, object>
|
||||
}
|
||||
|
||||
export interface MCPTool {
|
||||
id: string
|
||||
serverId: string
|
||||
serverName: string
|
||||
name: string
|
||||
description?: string
|
||||
inputSchema: MCPToolInputSchema
|
||||
}
|
||||
|
||||
export interface MCPPromptArguments {
|
||||
name: string
|
||||
description?: string
|
||||
@ -661,7 +647,7 @@ export interface MCPConfig {
|
||||
|
||||
interface BaseToolResponse {
|
||||
id: string // unique id
|
||||
tool: MCPTool | GenericProviderTool
|
||||
tool: BaseTool
|
||||
arguments: Record<string, unknown> | undefined
|
||||
status: string // 'invoking' | 'done'
|
||||
response?: any
|
||||
@ -755,3 +741,4 @@ export type S3Config = {
|
||||
}
|
||||
|
||||
export type { Message } from './newMessage'
|
||||
export * from './tool'
|
||||
|
||||
26
src/renderer/src/types/tool.ts
Normal file
26
src/renderer/src/types/tool.ts
Normal file
@ -0,0 +1,26 @@
|
||||
import type { MCPToolInputSchema } from './index'
|
||||
|
||||
export type ToolType = 'builtin' | 'provider' | 'mcp'
|
||||
|
||||
export interface BaseTool {
|
||||
id: string
|
||||
name: string
|
||||
description?: string
|
||||
type: ToolType
|
||||
}
|
||||
|
||||
export interface GenericProviderTool extends BaseTool {
|
||||
type: 'provider'
|
||||
}
|
||||
|
||||
export interface BuiltinTool extends BaseTool {
|
||||
inputSchema: MCPToolInputSchema
|
||||
type: 'builtin'
|
||||
}
|
||||
|
||||
export interface MCPTool extends BaseTool {
|
||||
serverId: string
|
||||
serverName: string
|
||||
inputSchema: MCPToolInputSchema
|
||||
type: 'mcp'
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user