feat: integrate web search tool and enhance tool handling

- Added `webSearchTool` to facilitate web search functionality within the SDK.
- Updated `AiSdkToChunkAdapter` to utilize `BaseTool` for improved type handling.
- Refactored `transformParameters` to support `webSearchProviderId` for enhanced web search integration.
- Introduced new `BaseTool` type structure to unify tool definitions across the codebase.
- Adjusted imports and type definitions to align with the new tool handling logic.
This commit is contained in:
suyao 2025-07-15 22:47:43 +08:00
parent 0c4e8228af
commit da455997ad
No known key found for this signature in database
12 changed files with 120 additions and 60 deletions

View File

@ -21,6 +21,7 @@ export { createModel, type ModelConfig } from './core/models'
// ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins'
export { createContext, definePlugin, PluginManager } from './core/plugins'
export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
export { PluginEngine } from './core/runtime/pluginEngine'
// ==================== 低级 API ====================

View File

@ -4,7 +4,7 @@
*/
import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { BaseTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
@ -27,7 +27,7 @@ export class AiSdkToChunkAdapter {
toolCallHandler: ToolCallChunkHandler
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[] = []
private mcpTools: BaseTool[] = []
) {
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
}

View File

@ -4,22 +4,15 @@
* API使
*/
import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core/index'
import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core'
import Logger from '@renderer/config/logger'
import { MCPTool, MCPToolResponse } from '@renderer/types'
import { BaseTool, MCPToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
// import type {
// AnthropicSearchOutput,
// WebSearchPluginConfig
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
// 为 Provider 执行的工具创建一个通用类型
// 这避免了污染 MCPTool 的定义,同时提供了 UI 显示所需的基本信息
type GenericProviderTool = {
name: string
description: string
type: 'provider'
}
/**
*
*/
@ -32,12 +25,12 @@ export class ToolCallChunkHandler {
toolName: string
args: any
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
mcpTool: MCPTool | GenericProviderTool
mcpTool: BaseTool
}
>()
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[]
private mcpTools: BaseTool[]
) {}
// /**
@ -62,13 +55,14 @@ export class ToolCallChunkHandler {
return
}
let tool: MCPTool | GenericProviderTool
let tool: BaseTool
// 根据 providerExecuted 标志区分处理逻辑
if (providerExecuted) {
// 如果是 Provider 执行的工具(如 web_search
Logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
tool = {
id: toolCallId,
name: toolName,
description: toolName,
type: 'provider'

View File

@ -17,7 +17,7 @@ import {
type ProviderSettingsMap,
StreamTextParams
} from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'

View File

@ -3,7 +3,7 @@ import {
LanguageModelV2Middleware,
simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import type { MCPTool, Model, Provider } from '@renderer/types'
import type { BaseTool, Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
/**
@ -18,7 +18,7 @@ export interface AiSdkMiddlewareConfig {
// 是否开启提示词工具调用
enableTool?: boolean
enableWebSearch?: boolean
mcpTools?: MCPTool[]
mcpTools?: BaseTool[]
}
/**

View File

@ -0,0 +1,34 @@
import WebSearchService from '@renderer/services/WebSearchService'
import { WebSearchProvider } from '@renderer/types'
import aiSdk from 'ai'
import { AiSdkTool, ToolCallResult } from './types'
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requestId: string): AiSdkTool => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
return {
name: 'web_search',
description: 'Search the web for information',
inputSchema: aiSdk.jsonSchema({
type: 'object',
properties: {
query: { type: 'string', description: 'The query to search for' }
},
required: ['query']
}),
execute: async ({ query }): Promise<ToolCallResult> => {
try {
const response = await webSearchService.processWebsearch(query, requestId)
return {
success: true,
data: response
}
} catch (error) {
return {
success: false,
data: error
}
}
}
}
}

View File

@ -0,0 +1,8 @@
import { Tool } from '@cherrystudio/ai-core'
export type ToolCallResult = {
success: boolean
data: any
}
export type AiSdkTool = Tool<any, ToolCallResult>

View File

@ -39,6 +39,7 @@ import {
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
import { webSearchTool } from './tools/WebSearchTool'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'
import { buildProviderOptions } from './utils/options'
@ -246,6 +247,7 @@ export async function buildStreamTextParams(
mcpTools?: MCPTool[]
enableTools?: boolean
enableWebSearch?: boolean
webSearchProviderId?: string
requestOptions?: {
signal?: AbortSignal
timeout?: number
@ -257,7 +259,7 @@ export async function buildStreamTextParams(
modelId: string
capabilities: { enableReasoning?: boolean; enableWebSearch?: boolean; enableGenerateImage?: boolean }
}> {
const { mcpTools, enableTools } = options
const { mcpTools, enableTools, webSearchProviderId } = options
const model = assistant.model || getDefaultModel()
@ -286,6 +288,13 @@ export async function buildStreamTextParams(
enableToolUse: enableTools
})
if (webSearchProviderId) {
// 生成requestId用于网络搜索工具
const requestId = `request_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`
tools['builtin_web_search'] = webSearchTool(webSearchProviderId, requestId)
}
// 构建真正的 providerOptions
const providerOptions = buildProviderOptions(assistant, model, provider, {
enableReasoning,

View File

@ -1,17 +1,11 @@
import { aiSdk, Tool } from '@cherrystudio/ai-core'
import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
import { isFunctionCallingModel } from '@renderer/config/models'
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
import { callMCPTool } from '@renderer/utils/mcp-tools'
import { JSONSchema7 } from 'json-schema'
type ToolCallResult = {
success: boolean
data: MCPCallToolResponse
}
type AiSdkTool = Tool<any, ToolCallResult>
// Setup tools configuration based on provided parameters
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: Record<string, AiSdkTool>

View File

@ -36,7 +36,21 @@ interface RequestState {
/**
*
*/
class WebSearchService {
export default class WebSearchService {
private static instance: WebSearchService
private webSearchProviderId: WebSearchProvider['id']
private constructor(webSearchProviderId: WebSearchProvider['id']) {
this.webSearchProviderId = webSearchProviderId
}
public static getInstance(webSearchProviderId: WebSearchProvider['id']): WebSearchService {
if (!WebSearchService.instance) {
WebSearchService.instance = new WebSearchService(webSearchProviderId)
}
return WebSearchService.instance
}
/**
*
*/
@ -154,17 +168,16 @@ class WebSearchService {
/**
* 使
* @public
* @param provider
* @param query
* @returns
*/
public async search(
provider: WebSearchProvider,
query: string,
httpOptions?: RequestInit
): Promise<WebSearchProviderResponse> {
public async search(query: string, httpOptions?: RequestInit): Promise<WebSearchProviderResponse> {
const websearch = this.getWebSearchState()
const webSearchEngine = new WebSearchEngineProvider(provider)
const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId)
if (!webSearchProvider) {
throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`)
}
const webSearchEngine = new WebSearchEngineProvider(webSearchProvider)
let formattedQuery = query
// FIXME: 有待商榷,效果一般
@ -186,9 +199,9 @@ class WebSearchService {
* @param provider
* @returns truefalse
*/
public async checkSearch(provider: WebSearchProvider): Promise<{ valid: boolean; error?: any }> {
public async checkSearch(): Promise<{ valid: boolean; error?: any }> {
try {
const response = await this.search(provider, 'test query')
const response = await this.search('test query')
Logger.log('[checkSearch] Search response:', response)
// 优化的判断条件:检查结果是否有效且没有错误
return { valid: response.results !== undefined, error: undefined }
@ -423,11 +436,7 @@ class WebSearchService {
*
* @returns
*/
public async processWebsearch(
webSearchProvider: WebSearchProvider,
extractResults: ExtractResults,
requestId: string
): Promise<WebSearchProviderResponse> {
public async processWebsearch(extractResults: ExtractResults, requestId: string): Promise<WebSearchProviderResponse> {
// 重置状态
await this.setWebSearchStatus(requestId, { phase: 'default' })
@ -449,7 +458,7 @@ class WebSearchService {
return { query: 'summaries', results: contents }
}
const searchPromises = questions.map((q) => this.search(webSearchProvider, q, { signal }))
const searchPromises = questions.map((q) => this.search(q, { signal }))
const searchResults = await Promise.allSettled(searchPromises)
// 统计成功完成的搜索数量
@ -532,5 +541,3 @@ class WebSearchService {
}
}
}
export default new WebSearchService()

View File

@ -4,12 +4,7 @@ import type OpenAI from 'openai'
import type { CSSProperties } from 'react'
import type { Message } from './newMessage'
export type GenericProviderTool = {
name: string
description: string
type: 'provider'
}
import type { BaseTool, MCPTool } from './tool'
export type Assistant = {
id: string
@ -616,15 +611,6 @@ export interface MCPToolInputSchema {
properties: Record<string, object>
}
export interface MCPTool {
id: string
serverId: string
serverName: string
name: string
description?: string
inputSchema: MCPToolInputSchema
}
export interface MCPPromptArguments {
name: string
description?: string
@ -661,7 +647,7 @@ export interface MCPConfig {
interface BaseToolResponse {
id: string // unique id
tool: MCPTool | GenericProviderTool
tool: BaseTool
arguments: Record<string, unknown> | undefined
status: string // 'invoking' | 'done'
response?: any
@ -755,3 +741,4 @@ export type S3Config = {
}
export type { Message } from './newMessage'
export * from './tool'

View File

@ -0,0 +1,26 @@
import type { MCPToolInputSchema } from './index'
export type ToolType = 'builtin' | 'provider' | 'mcp'
export interface BaseTool {
id: string
name: string
description?: string
type: ToolType
}
export interface GenericProviderTool extends BaseTool {
type: 'provider'
}
export interface BuiltinTool extends BaseTool {
inputSchema: MCPToolInputSchema
type: 'builtin'
}
export interface MCPTool extends BaseTool {
serverId: string
serverName: string
inputSchema: MCPToolInputSchema
type: 'mcp'
}