mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat(toolUsePlugin): enhance tool parsing and extraction functionality
- Updated the `defaultParseToolUse` function to return both parsed results and remaining content, improving usability. - Introduced a new `TagExtractor` class for flexible tag extraction, supporting various tag formats. - Modified type definitions to reflect changes in parsing function signatures. - Enhanced handling of tool call events in the `ToolCallChunkHandler` for better integration with the new parsing logic. - Added `isBuiltIn` property to the `MCPTool` interface for clearer tool categorization.
This commit is contained in:
parent
c9c0616c91
commit
d93a36e5c9
@ -198,9 +198,9 @@ function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): str
|
||||
* 默认工具解析函数(提取自 Cherry Studio)
|
||||
* 解析 XML 格式的工具调用
|
||||
*/
|
||||
function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
|
||||
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
|
||||
if (!content || !tools || Object.keys(tools).length === 0) {
|
||||
return []
|
||||
return { results: [], content: content }
|
||||
}
|
||||
|
||||
// 支持两种格式:
|
||||
@ -208,7 +208,6 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
|
||||
// 2. 只有内部内容(从 TagExtractor 提取出来的)
|
||||
|
||||
let contentToProcess = content
|
||||
|
||||
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
|
||||
if (!content.includes('<tool_use>')) {
|
||||
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
|
||||
@ -222,6 +221,7 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
|
||||
|
||||
// Find all tool use blocks
|
||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||
const fullMatch = match[0]
|
||||
const toolName = match[2].trim()
|
||||
const toolArgs = match[4].trim()
|
||||
|
||||
@ -248,8 +248,9 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
})
|
||||
contentToProcess = contentToProcess.replace(fullMatch, '')
|
||||
}
|
||||
return results
|
||||
return { results, content: contentToProcess }
|
||||
}
|
||||
|
||||
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
@ -262,7 +263,6 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
return params
|
||||
}
|
||||
|
||||
// 直接存储工具信息到 context 上,利用改进的插件引擎
|
||||
context.mcpTools = params.tools
|
||||
console.log('tools stored in context', params.tools)
|
||||
|
||||
@ -309,7 +309,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'finish-step') {
|
||||
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
|
||||
// console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
|
||||
|
||||
// 从 context 获取工具信息
|
||||
@ -322,7 +322,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
}
|
||||
|
||||
// 解析工具调用
|
||||
const parsedTools = parseToolUse(textBuffer, tools)
|
||||
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
// console.log('parsedTools', parsedTools)
|
||||
|
||||
@ -332,7 +332,23 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
if (chunk.type === 'text-end') {
|
||||
controller.enqueue({
|
||||
type: 'text-end',
|
||||
id: stepId,
|
||||
providerMetadata: {
|
||||
text: {
|
||||
value: parsedContent
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
finishReason: 'tool-calls'
|
||||
})
|
||||
// console.log('[MCP Prompt Stream] Found valid tool uses:', validToolUses.length)
|
||||
|
||||
// 发送 step-start 事件(工具调用步骤开始)
|
||||
@ -350,8 +366,15 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
if (!tool || typeof tool.execute !== 'function') {
|
||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
}
|
||||
// 发送 tool-input-start 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-input-start',
|
||||
id: toolUse.id,
|
||||
toolName: toolUse.toolName
|
||||
})
|
||||
|
||||
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||
console.log('toolUse,toolUse', toolUse)
|
||||
// 发送 tool-call 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-call',
|
||||
@ -423,7 +446,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
// 发送最终的 step-finish 事件
|
||||
controller.enqueue({
|
||||
type: 'finish-step',
|
||||
finishReason: 'tool-calls',
|
||||
finishReason: 'stop',
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
|
||||
@ -0,0 +1,196 @@
|
||||
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
|
||||
|
||||
/**
|
||||
* Returns the index of the start of the searchedText in the text, or null if it
|
||||
* is not found.
|
||||
*/
|
||||
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
|
||||
// Return null immediately if searchedText is empty.
|
||||
if (searchedText.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if the searchedText exists as a direct substring of text.
|
||||
const directIndex = text.indexOf(searchedText)
|
||||
if (directIndex !== -1) {
|
||||
return directIndex
|
||||
}
|
||||
|
||||
// Otherwise, look for the largest suffix of "text" that matches
|
||||
// a prefix of "searchedText". We go from the end of text inward.
|
||||
for (let i = text.length - 1; i >= 0; i--) {
|
||||
const suffix = text.substring(i)
|
||||
if (searchedText.startsWith(suffix)) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export interface TagConfig {
|
||||
openingTag: string
|
||||
closingTag: string
|
||||
separator?: string
|
||||
}
|
||||
|
||||
export interface TagExtractionState {
|
||||
textBuffer: string
|
||||
isInsideTag: boolean
|
||||
isFirstTag: boolean
|
||||
isFirstText: boolean
|
||||
afterSwitch: boolean
|
||||
accumulatedTagContent: string
|
||||
hasTagContent: boolean
|
||||
}
|
||||
|
||||
export interface TagExtractionResult {
|
||||
content: string
|
||||
isTagContent: boolean
|
||||
complete: boolean
|
||||
tagContentExtracted?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用标签提取处理器
|
||||
* 可以处理各种形式的标签对,如 <think>...</think>, <tool_use>...</tool_use> 等
|
||||
*/
|
||||
export class TagExtractor {
|
||||
private config: TagConfig
|
||||
private state: TagExtractionState
|
||||
|
||||
constructor(config: TagConfig) {
|
||||
this.config = config
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本块,返回处理结果
|
||||
*/
|
||||
processText(newText: string): TagExtractionResult[] {
|
||||
this.state.textBuffer += newText
|
||||
const results: TagExtractionResult[] = []
|
||||
|
||||
// 处理标签提取逻辑
|
||||
while (true) {
|
||||
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
|
||||
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
|
||||
|
||||
if (startIndex == null) {
|
||||
const content = this.state.textBuffer
|
||||
if (content.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(content),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(content)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
this.state.textBuffer = ''
|
||||
break
|
||||
}
|
||||
|
||||
// 处理标签前的内容
|
||||
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
|
||||
if (contentBeforeTag.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(contentBeforeTag),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
|
||||
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
|
||||
|
||||
if (foundFullMatch) {
|
||||
// 如果找到完整的标签
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
|
||||
|
||||
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
|
||||
if (this.state.isInsideTag && this.state.hasTagContent) {
|
||||
results.push({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
})
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
}
|
||||
|
||||
this.state.isInsideTag = !this.state.isInsideTag
|
||||
this.state.afterSwitch = true
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.isFirstTag = false
|
||||
} else {
|
||||
this.state.isFirstText = false
|
||||
}
|
||||
} else {
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
/**
|
||||
* 完成处理,返回任何剩余的标签内容
|
||||
*/
|
||||
finalize(): TagExtractionResult | null {
|
||||
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
|
||||
const result = {
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
}
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
return result
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private addPrefix(text: string): string {
|
||||
const needsPrefix =
|
||||
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
|
||||
|
||||
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
|
||||
this.state.afterSwitch = false
|
||||
return prefix + text
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置状态
|
||||
*/
|
||||
reset(): void {
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -21,7 +21,7 @@ export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
|
||||
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
|
||||
}
|
||||
|
||||
|
||||
@ -77,12 +77,15 @@ export type {
|
||||
// 生成相关类型
|
||||
GenerateTextResult,
|
||||
ImagePart,
|
||||
InferToolInput,
|
||||
InferToolOutput,
|
||||
InvalidToolInputError,
|
||||
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
|
||||
// 消息相关类型
|
||||
ModelMessage,
|
||||
// 错误类型
|
||||
NoSuchToolError,
|
||||
ProviderMetadata,
|
||||
StreamTextResult,
|
||||
SystemModelMessage,
|
||||
TextPart,
|
||||
|
||||
@ -101,7 +101,7 @@ export class AiSdkToChunkAdapter {
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: final.text || ''
|
||||
text: (chunk.providerMetadata?.text?.value as string) || final.text || ''
|
||||
})
|
||||
final.text = ''
|
||||
break
|
||||
|
||||
@ -4,11 +4,10 @@
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
|
||||
import { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { BaseTool, MCPToolResponse, ToolCallResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { type ProviderMetadata } from 'ai'
|
||||
// import type {
|
||||
// AnthropicSearchOutput,
|
||||
// WebSearchPluginConfig
|
||||
@ -67,7 +66,7 @@ export class ToolCallChunkHandler {
|
||||
switch (chunk.type) {
|
||||
case 'tool-input-start': {
|
||||
// 能拿到说明是mcpTool
|
||||
if (this.activeToolCalls.get(chunk.id)) return
|
||||
// if (this.activeToolCalls.get(chunk.id)) return
|
||||
|
||||
const tool: BaseTool = {
|
||||
id: chunk.id,
|
||||
@ -81,6 +80,17 @@ export class ToolCallChunkHandler {
|
||||
args: '',
|
||||
tool
|
||||
})
|
||||
const toolResponse: ToolCallResponse = {
|
||||
id: chunk.id,
|
||||
tool: tool,
|
||||
arguments: {},
|
||||
status: 'pending',
|
||||
toolCallId: chunk.id
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'tool-input-delta': {
|
||||
@ -99,18 +109,18 @@ export class ToolCallChunkHandler {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
const toolResponse: ToolCallResponse = {
|
||||
id: toolCall.toolCallId,
|
||||
tool: toolCall.tool,
|
||||
arguments: toolCall.args,
|
||||
status: 'pending',
|
||||
toolCallId: toolCall.toolCallId
|
||||
}
|
||||
logger.debug('toolResponse', toolResponse)
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
// const toolResponse: ToolCallResponse = {
|
||||
// id: toolCall.toolCallId,
|
||||
// tool: toolCall.tool,
|
||||
// arguments: toolCall.args,
|
||||
// status: 'pending',
|
||||
// toolCallId: toolCall.toolCallId
|
||||
// }
|
||||
// logger.debug('toolResponse', toolResponse)
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [toolResponse]
|
||||
// })
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -227,7 +237,7 @@ export class ToolCallChunkHandler {
|
||||
|
||||
// 创建工具调用结果的 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
id: toolCallInfo.toolCallId,
|
||||
tool: toolCallInfo.tool,
|
||||
arguments: input,
|
||||
status: 'done',
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import { aiSdk, InferToolInput, InferToolOutput } from '@cherrystudio/ai-core'
|
||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
|
||||
import { ExtractResults } from '@renderer/utils/extract'
|
||||
import { InferToolInput, InferToolOutput, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
const { tool } = aiSdk
|
||||
|
||||
/**
|
||||
* 使用预提取关键词的网络搜索工具
|
||||
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
||||
@ -46,16 +48,13 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
console.log(`🔍 AI enhanced search with: ${additionalContext}`)
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
console.log(`➕ Added additional context: ${cleanContext}`)
|
||||
}
|
||||
}
|
||||
|
||||
const searchResults: WebSearchProviderResponse[] = []
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return {
|
||||
@ -75,12 +74,9 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
console.log('extractResults', extractResults)
|
||||
console.log('webSearchProvider', webSearchProvider)
|
||||
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
searchResults.push(response)
|
||||
} catch (error) {
|
||||
console.error(`Web search failed for query "${finalQueries}":`, error)
|
||||
return {
|
||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
searchResults: [],
|
||||
|
||||
@ -8,6 +8,7 @@ export const thinkTool: MCPTool = {
|
||||
description:
|
||||
'Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed. For example, if you explore the repo and discover the source of a bug, call this tool to brainstorm several unique ways of fixing the bug, and assess which change(s) are likely to be simplest and most effective. Alternatively, if you receive some test results, call this tool to brainstorm ways to fix the failing tests.',
|
||||
isBuiltIn: true,
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
title: 'Think Tool Input',
|
||||
|
||||
@ -44,4 +44,5 @@ export interface MCPTool extends BaseTool {
|
||||
inputSchema: MCPToolInputSchema
|
||||
outputSchema?: z.infer<typeof MCPToolOutputSchema>
|
||||
type: 'mcp'
|
||||
isBuiltIn?: boolean
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user