feat(toolUsePlugin): enhance tool parsing and extraction functionality

- Updated the `defaultParseToolUse` function to return both parsed results and remaining content, improving usability.
- Introduced a new `TagExtractor` class for flexible tag extraction, supporting various tag formats.
- Modified type definitions to reflect changes in parsing function signatures.
- Enhanced handling of tool call events in the `ToolCallChunkHandler` for better integration with the new parsing logic.
- Added `isBuiltIn` property to the `MCPTool` interface for clearer tool categorization.
This commit is contained in:
lizhixuan 2025-08-19 00:44:30 +08:00
parent c9c0616c91
commit d93a36e5c9
9 changed files with 263 additions and 33 deletions

View File

@ -198,9 +198,9 @@ function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): str
* Cherry Studio
* XML
*/
function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
if (!content || !tools || Object.keys(tools).length === 0) {
return []
return { results: [], content: content }
}
// 支持两种格式:
@ -208,7 +208,6 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
// 2. 只有内部内容(从 TagExtractor 提取出来的)
let contentToProcess = content
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
if (!content.includes('<tool_use>')) {
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
@ -222,6 +221,7 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
// Find all tool use blocks
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
const fullMatch = match[0]
const toolName = match[2].trim()
const toolArgs = match[4].trim()
@ -248,8 +248,9 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
arguments: parsedArgs,
status: 'pending'
})
contentToProcess = contentToProcess.replace(fullMatch, '')
}
return results
return { results, content: contentToProcess }
}
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
@ -262,7 +263,6 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
return params
}
// 直接存储工具信息到 context 上,利用改进的插件引擎
context.mcpTools = params.tools
console.log('tools stored in context', params.tools)
@ -309,7 +309,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
return
}
if (chunk.type === 'finish-step') {
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
// console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
// 从 context 获取工具信息
@ -322,7 +322,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
}
// 解析工具调用
const parsedTools = parseToolUse(textBuffer, tools)
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
// console.log('parsedTools', parsedTools)
@ -332,7 +332,23 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
controller.enqueue(chunk)
return
}
if (chunk.type === 'text-end') {
controller.enqueue({
type: 'text-end',
id: stepId,
providerMetadata: {
text: {
value: parsedContent
}
}
})
return
}
controller.enqueue({
...chunk,
finishReason: 'tool-calls'
})
// console.log('[MCP Prompt Stream] Found valid tool uses:', validToolUses.length)
// 发送 step-start 事件(工具调用步骤开始)
@ -350,8 +366,15 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
if (!tool || typeof tool.execute !== 'function') {
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
}
// 发送 tool-input-start 事件
controller.enqueue({
type: 'tool-input-start',
id: toolUse.id,
toolName: toolUse.toolName
})
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
console.log('toolUse,toolUse', toolUse)
// 发送 tool-call 事件
controller.enqueue({
type: 'tool-call',
@ -423,7 +446,7 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
// 发送最终的 step-finish 事件
controller.enqueue({
type: 'finish-step',
finishReason: 'tool-calls',
finishReason: 'stop',
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata

View File

@ -0,0 +1,196 @@
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
/**
* Returns the index of the start of the searchedText in the text, or null if it
* is not found.
*/
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
// Return null immediately if searchedText is empty.
if (searchedText.length === 0) {
return null
}
// Check if the searchedText exists as a direct substring of text.
const directIndex = text.indexOf(searchedText)
if (directIndex !== -1) {
return directIndex
}
// Otherwise, look for the largest suffix of "text" that matches
// a prefix of "searchedText". We go from the end of text inward.
for (let i = text.length - 1; i >= 0; i--) {
const suffix = text.substring(i)
if (searchedText.startsWith(suffix)) {
return i
}
}
return null
}
export interface TagConfig {
openingTag: string
closingTag: string
separator?: string
}
export interface TagExtractionState {
textBuffer: string
isInsideTag: boolean
isFirstTag: boolean
isFirstText: boolean
afterSwitch: boolean
accumulatedTagContent: string
hasTagContent: boolean
}
export interface TagExtractionResult {
content: string
isTagContent: boolean
complete: boolean
tagContentExtracted?: string
}
/**
*
* <think>...</think>, <tool_use>...</tool_use>
*/
export class TagExtractor {
private config: TagConfig
private state: TagExtractionState
constructor(config: TagConfig) {
this.config = config
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}
/**
*
*/
processText(newText: string): TagExtractionResult[] {
this.state.textBuffer += newText
const results: TagExtractionResult[] = []
// 处理标签提取逻辑
while (true) {
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
if (startIndex == null) {
const content = this.state.textBuffer
if (content.length > 0) {
results.push({
content: this.addPrefix(content),
isTagContent: this.state.isInsideTag,
complete: false
})
if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(content)
this.state.hasTagContent = true
}
}
this.state.textBuffer = ''
break
}
// 处理标签前的内容
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
if (contentBeforeTag.length > 0) {
results.push({
content: this.addPrefix(contentBeforeTag),
isTagContent: this.state.isInsideTag,
complete: false
})
if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
this.state.hasTagContent = true
}
}
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
if (foundFullMatch) {
// 如果找到完整的标签
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
if (this.state.isInsideTag && this.state.hasTagContent) {
results.push({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
})
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
}
this.state.isInsideTag = !this.state.isInsideTag
this.state.afterSwitch = true
if (this.state.isInsideTag) {
this.state.isFirstTag = false
} else {
this.state.isFirstText = false
}
} else {
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
break
}
}
return results
}
/**
*
*/
finalize(): TagExtractionResult | null {
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
const result = {
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
}
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
return result
}
return null
}
private addPrefix(text: string): string {
const needsPrefix =
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
this.state.afterSwitch = false
return prefix + text
}
/**
*
*/
reset(): void {
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}
}

View File

@ -21,7 +21,7 @@ export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
// 自定义系统提示符构建函数(可选,有默认实现)
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
// 自定义工具解析函数(可选,有默认实现)
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
}

View File

@ -77,12 +77,15 @@ export type {
// 生成相关类型
GenerateTextResult,
ImagePart,
InferToolInput,
InferToolOutput,
InvalidToolInputError,
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
// 消息相关类型
ModelMessage,
// 错误类型
NoSuchToolError,
ProviderMetadata,
StreamTextResult,
SystemModelMessage,
TextPart,

View File

@ -101,7 +101,7 @@ export class AiSdkToChunkAdapter {
case 'text-end':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || ''
text: (chunk.providerMetadata?.text?.value as string) || final.text || ''
})
final.text = ''
break

View File

@ -4,11 +4,10 @@
* API使
*/
import { ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
import { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { BaseTool, MCPToolResponse, ToolCallResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { type ProviderMetadata } from 'ai'
// import type {
// AnthropicSearchOutput,
// WebSearchPluginConfig
@ -67,7 +66,7 @@ export class ToolCallChunkHandler {
switch (chunk.type) {
case 'tool-input-start': {
// 能拿到说明是mcpTool
if (this.activeToolCalls.get(chunk.id)) return
// if (this.activeToolCalls.get(chunk.id)) return
const tool: BaseTool = {
id: chunk.id,
@ -81,6 +80,17 @@ export class ToolCallChunkHandler {
args: '',
tool
})
const toolResponse: ToolCallResponse = {
id: chunk.id,
tool: tool,
arguments: {},
status: 'pending',
toolCallId: chunk.id
}
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse]
})
break
}
case 'tool-input-delta': {
@ -99,18 +109,18 @@ export class ToolCallChunkHandler {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
const toolResponse: ToolCallResponse = {
id: toolCall.toolCallId,
tool: toolCall.tool,
arguments: toolCall.args,
status: 'pending',
toolCallId: toolCall.toolCallId
}
logger.debug('toolResponse', toolResponse)
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse]
})
// const toolResponse: ToolCallResponse = {
// id: toolCall.toolCallId,
// tool: toolCall.tool,
// arguments: toolCall.args,
// status: 'pending',
// toolCallId: toolCall.toolCallId
// }
// logger.debug('toolResponse', toolResponse)
// this.onChunk({
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [toolResponse]
// })
break
}
}
@ -227,7 +237,7 @@ export class ToolCallChunkHandler {
// 创建工具调用结果的 MCPToolResponse 格式
const toolResponse: MCPToolResponse = {
id: toolCallId,
id: toolCallInfo.toolCallId,
tool: toolCallInfo.tool,
arguments: input,
status: 'done',

View File

@ -1,10 +1,12 @@
import { aiSdk, InferToolInput, InferToolOutput } from '@cherrystudio/ai-core'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import WebSearchService from '@renderer/services/WebSearchService'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
import { InferToolInput, InferToolOutput, tool } from 'ai'
import { z } from 'zod'
const { tool } = aiSdk
/**
* 使
* 使
@ -46,16 +48,13 @@ Call this tool to execute the search. You can optionally provide additional cont
if (additionalContext?.trim()) {
// 如果大模型提供了额外上下文,使用更具体的描述
console.log(`🔍 AI enhanced search with: ${additionalContext}`)
const cleanContext = additionalContext.trim()
if (cleanContext) {
finalQueries = [cleanContext]
console.log(` Added additional context: ${cleanContext}`)
}
}
const searchResults: WebSearchProviderResponse[] = []
// 检查是否需要搜索
if (finalQueries[0] === 'not_needed') {
return {
@ -75,12 +74,9 @@ Call this tool to execute the search. You can optionally provide additional cont
links: extractedKeywords.links
}
}
console.log('extractResults', extractResults)
console.log('webSearchProvider', webSearchProvider)
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
searchResults.push(response)
} catch (error) {
console.error(`Web search failed for query "${finalQueries}":`, error)
return {
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
searchResults: [],

View File

@ -8,6 +8,7 @@ export const thinkTool: MCPTool = {
description:
'Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed. For example, if you explore the repo and discover the source of a bug, call this tool to brainstorm several unique ways of fixing the bug, and assess which change(s) are likely to be simplest and most effective. Alternatively, if you receive some test results, call this tool to brainstorm ways to fix the failing tests.',
isBuiltIn: true,
type: 'mcp',
inputSchema: {
type: 'object',
title: 'Think Tool Input',

View File

@ -44,4 +44,5 @@ export interface MCPTool extends BaseTool {
inputSchema: MCPToolInputSchema
outputSchema?: z.infer<typeof MCPToolOutputSchema>
type: 'mcp'
isBuiltIn?: boolean
}