feat: support both function call and system prompt for MCP tools (#5499)

* feat: support both function call and system prompt for MCP tools
- Add support for using both function calls and the system prompt to implement MCP tool calls (sketched below)
- Refactor tool handling logic to be more flexible and maintainable
- Improve code readability with better variable naming and comments
- Fix potential issues in the tool call implementation
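
A minimal standalone sketch of the selection logic described above, using stand-ins for names that appear later in this diff (SYSTEM_PROMPT_THRESHOLD, enableToolUse, isFunctionCallingModel, buildSystemPrompt); illustrative only, not the repository implementation:

// Sketch: choose between native function calling and prompt-described tools.
type MCPTool = { id: string; name: string; description?: string; inputSchema: Record<string, unknown> }
type Model = { id: string }

const SYSTEM_PROMPT_THRESHOLD = 128 // same value the PR adds to BaseProvider

// stand-in for isFunctionCallingModel(model)
const supportsFunctionCalling = (model: Model): boolean => !/embedding|tts/.test(model.id)

// stand-in for buildSystemPrompt(prompt, tools): describe tools in text so the model
// can emit tool-use blocks that are parsed back out of its reply
const describeToolsInPrompt = (prompt: string, tools: MCPTool[]): string =>
  `${prompt}\n\nAvailable tools:\n${tools.map((t) => `- ${t.name}: ${t.description ?? ''}`).join('\n')}`

function selectToolPath(mcpTools: MCPTool[], model: Model, enableToolUse: boolean, prompt: string) {
  // Path 1: native function calling when the model supports it, the user enabled it,
  // and the tool list is small enough to send as structured definitions
  if (
    mcpTools.length > 0 &&
    mcpTools.length <= SYSTEM_PROMPT_THRESHOLD &&
    enableToolUse &&
    supportsFunctionCalling(model)
  ) {
    return { tools: mcpTools, systemPrompt: prompt } // each provider converts `tools` to its own schema
  }
  // Path 2: fall back to describing the tools inside the system prompt
  return { tools: [] as MCPTool[], systemPrompt: mcpTools.length ? describeToolsInPrompt(prompt, mcpTools) : prompt }
}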

* fix: Add tool_calls in OpenAI streaming logic

* refactor: enhance OpenAICompatibleProvider and BaseOpenAiProvider structure

* feat: add tool call setting to SettingsTab component

* fix: enhance tool call handling in OpenAICompatibleProvider

* fix: enhance content handling in GeminiProvider for non-streaming responses

* refactor: improve tool property filtering logic in OpenAIProvider and mcp-tools utility

* fix: resolve eslint errors

* fix: add history for function call message in GeminiProvider

* refactor: unify MCP tool response handling across providers for consistency
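
An interface-level sketch of that unification (simplified types; the concrete hooks are declared on BaseProvider and implemented by each provider later in this diff):

// Simplified sketch of the per-provider hooks used to unify tool-result handling.
type MCPTool = { id: string; name: string }
type Model = { id: string }
type MCPCallToolResponse = {
  isError?: boolean
  content: Array<{ type: 'text' | 'image' | 'audio' | 'resource'; text?: string; data?: string; mimeType?: string }>
}
// toolUseId marks the prompt-based path, toolCallId the native function-call path
type MCPToolResponse = { id: string; tool: MCPTool; status: string; toolUseId?: string; toolCallId?: string }

interface ProviderToolHooks<TTool, TMessage> {
  // Convert generic MCP tool definitions into the provider's native tool schema
  // (Anthropic ToolUnion, Gemini Tool, OpenAI ChatCompletionTool / Responses.Tool).
  convertMcpTools(mcpTools: MCPTool[]): TTool[]
  // Turn an MCP tool result into a message the provider can feed back to the model,
  // regardless of which path produced the call.
  mcpToolCallResponseToMessage(toolResponse: MCPToolResponse, result: MCPCallToolResponse, model: Model): TMessage | undefined
}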

* refactor: update MCP tools conversion logic in OpenAICompatibleProvider and OpenAIProvider

* refactor: enhance AihubmixProvider and BaseProvider with MCP tool handling methods

* refactor: introduce SYSTEM_PROMPT_THRESHOLD constant in BaseProvider for improved readability

* refactor: rename tool_call to enable_tool_use for clarity and consistency across the application

* refactor: remove unnecessary onChunk call in processStream for cleaner code

* fix: add toolCallId to response structure and enhance content handling in AnthropicProvider
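
The toolCallId mentioned here comes from the response-type split made in @renderer/types later in this diff; a self-contained restatement:

// Restatement of the tool-response union introduced by this change.
type MCPTool = { id: string; name: string }

interface BaseToolResponse {
  id: string // unique id
  tool: MCPTool
  arguments: Record<string, unknown> | undefined
  status: string // e.g. 'pending' | 'invoking' | 'done'
  response?: unknown
}
// Prompt-based path: the id is parsed from a tool-use block in the model's text output.
interface ToolUseResponse extends BaseToolResponse {
  toolUseId: string
}
// Native function-call path: the id comes from the provider; Gemini may omit it.
interface ToolCallResponse extends BaseToolResponse {
  toolCallId?: string
}
type MCPToolResponse = ToolUseResponse | ToolCallResponse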

* fix: return image data to the LLM when using function calls

* fix: add reasoning handling in OpenAICompatibleProvider for improved response processing

---------

Co-authored-by: kanweiwei <kanweiwei@nutstore.net>
Co-authored-by: jay <sevenjay@users.noreply.github.com>
Camol 2025-05-09 20:20:16 +08:00 committed by GitHub
parent c0cb1693da
commit ce8b85020b
19 changed files with 1178 additions and 403 deletions

View File

@ -705,6 +705,7 @@
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.",
"search": "Search models...",
"stream_output": "Stream output",
"enable_tool_use": "Enable Tool Use",
"type": {
"embedding": "Embedding",
"free": "Free",

View File

@ -705,6 +705,7 @@
"rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。",
"search": "モデルを検索...",
"stream_output": "ストリーム出力",
"enable_tool_use": "ツール呼び出し",
"type": {
"embedding": "埋め込み",
"free": "無料",

View File

@ -705,6 +705,7 @@
"rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить.",
"search": "Поиск моделей...",
"stream_output": "Потоковый вывод",
"enable_tool_use": "Вызов инструмента",
"type": {
"embedding": "Встраиваемые",
"free": "Бесплатные",

View File

@ -705,6 +705,7 @@
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"search": "搜索模型...",
"stream_output": "流式输出",
"enable_tool_use": "工具调用",
"type": {
"embedding": "嵌入",
"free": "免费",

View File

@ -705,6 +705,7 @@
"rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加",
"search": "搜尋模型...",
"stream_output": "串流輸出",
"enable_tool_use": "工具調用",
"type": {
"embedding": "嵌入",
"free": "免費",

View File

@ -67,7 +67,7 @@ const MessageTools: FC<Props> = ({ blocks }) => {
const isDone = status === 'done'
const hasError = isDone && response?.isError === true
const result = {
params: tool.inputSchema,
params: toolResponse.arguments,
response: toolResponse.response
}

View File

@ -70,6 +70,7 @@ const SettingsTab: FC<Props> = (props) => {
const [maxTokens, setMaxTokens] = useState(assistant?.settings?.maxTokens ?? 0)
const [fontSizeValue, setFontSizeValue] = useState(fontSize)
const [streamOutput, setStreamOutput] = useState(assistant?.settings?.streamOutput ?? true)
const [enableToolUse, setEnableToolUse] = useState(assistant?.settings?.enableToolUse ?? false)
const { t } = useTranslation()
const dispatch = useAppDispatch()
@ -222,6 +223,18 @@ const SettingsTab: FC<Props> = (props) => {
/>
</SettingRow>
<SettingDivider />
<SettingRow>
<SettingRowTitleSmall>{t('models.enable_tool_use')}</SettingRowTitleSmall>
<Switch
size="small"
checked={enableToolUse}
onChange={(checked) => {
setEnableToolUse(checked)
updateAssistantSettings({ enableToolUse: checked })
}}
/>
</SettingRow>
<SettingDivider />
<Row align="middle" justify="space-between" style={{ marginBottom: 10 }}>
<HStack alignItems="center">
<Label>{t('chat.settings.max_tokens')}</Label>

View File

@ -24,6 +24,7 @@ const AssistantModelSettings: FC<Props> = ({ assistant, updateAssistant, updateA
const [enableMaxTokens, setEnableMaxTokens] = useState(assistant?.settings?.enableMaxTokens ?? false)
const [maxTokens, setMaxTokens] = useState(assistant?.settings?.maxTokens ?? 0)
const [streamOutput, setStreamOutput] = useState(assistant?.settings?.streamOutput ?? true)
const [enableToolUse, setEnableToolUse] = useState(assistant?.settings?.enableToolUse ?? false)
const [defaultModel, setDefaultModel] = useState(assistant?.defaultModel)
const [topP, setTopP] = useState(assistant?.settings?.topP ?? 1)
const [customParameters, setCustomParameters] = useState<AssistantSettingCustomParameters[]>(
@ -377,6 +378,18 @@ const AssistantModelSettings: FC<Props> = ({ assistant, updateAssistant, updateA
/>
</SettingRow>
<Divider style={{ margin: '10px 0' }} />
<SettingRow style={{ minHeight: 30 }}>
<Label>{t('models.enable_tool_use')}</Label>
<Switch
size="small"
checked={enableToolUse}
onChange={(checked) => {
setEnableToolUse(checked)
updateAssistantSettings({ enableToolUse: checked })
}}
/>
</SettingRow>
<Divider style={{ margin: '10px 0' }} />
<SettingRow style={{ minHeight: 30 }}>
<Label>{t('models.custom_parameters')}</Label>
<Button icon={<PlusOutlined />} onClick={onAddCustomParameter}>

View File

@ -1,6 +1,6 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import { getDefaultModel } from '@renderer/services/AssistantService'
import { Assistant, Model, Provider, Suggestion } from '@renderer/types'
import { Assistant, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types'
import { Message } from '@renderer/types/newMessage'
import OpenAI from 'openai'
@ -18,6 +18,7 @@ import OpenAIProvider from './OpenAIProvider'
export default class AihubmixProvider extends BaseProvider {
private providers: Map<string, BaseProvider> = new Map()
private defaultProvider: BaseProvider
private currentProvider: BaseProvider
constructor(provider: Provider) {
super(provider)
@ -30,6 +31,7 @@ export default class AihubmixProvider extends BaseProvider {
// Set the default provider
this.defaultProvider = this.providers.get('default')!
this.currentProvider = this.defaultProvider
}
/**
@ -70,7 +72,8 @@ export default class AihubmixProvider extends BaseProvider {
public async completions(params: CompletionsParams): Promise<void> {
const model = params.assistant.model
return this.getProvider(model!).completions(params)
this.currentProvider = this.getProvider(model!)
return this.currentProvider.completions(params)
}
public async translate(
@ -100,4 +103,12 @@ export default class AihubmixProvider extends BaseProvider {
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.getProvider(model).getEmbeddingDimensions(model)
}
public convertMcpTools<T>(mcpTools: MCPTool[]) {
return this.currentProvider.convertMcpTools(mcpTools) as T[]
}
public mcpToolCallResponseToMessage(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) {
return this.currentProvider.mcpToolCallResponseToMessage(mcpToolResponse, resp, model)
}
}

View File

@ -1,15 +1,19 @@
import Anthropic from '@anthropic-ai/sdk'
import {
Base64ImageSource,
ImageBlockParam,
MessageCreateParamsNonStreaming,
MessageParam,
TextBlockParam,
ToolResultBlockParam,
ToolUnion,
ToolUseBlock,
WebSearchResultBlock,
WebSearchTool20250305,
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { isReasoningModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
import { isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@ -23,16 +27,24 @@ import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
Suggestion,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { mcpToolCallResponseToAnthropicMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
import {
anthropicToolUseToMcpTool,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools,
parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { first, flatten, sum, takeRight } from 'lodash'
@ -199,7 +211,7 @@ export default class AnthropicProvider extends BaseProvider {
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
const userMessagesParams: MessageParam[] = []
@ -215,10 +227,16 @@ export default class AnthropicProvider extends BaseProvider {
const userMessages = flatten(userMessagesParams)
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
// const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined
let systemPrompt = assistant.prompt
if (mcpTools && mcpTools.length > 0) {
const { tools } = this.setupToolsConfig<ToolUnion>({
model,
mcpTools,
enableToolUse
})
if (this.useSystemPromptForTools && mcpTools && mcpTools.length) {
systemPrompt = buildSystemPrompt(systemPrompt, mcpTools)
}
@ -232,8 +250,6 @@ export default class AnthropicProvider extends BaseProvider {
const isEnabledBuiltinWebSearch = assistant.enableWebSearch
const tools: ToolUnion[] = []
if (isEnabledBuiltinWebSearch) {
const webSearchTool = await this.getWebSearchParams(model)
if (webSearchTool) {
@ -244,7 +260,6 @@ export default class AnthropicProvider extends BaseProvider {
const body: MessageCreateParamsNonStreaming = {
model: model.id,
messages: userMessages,
// tools: isEmpty(tools) ? undefined : tools,
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
@ -303,7 +318,7 @@ export default class AnthropicProvider extends BaseProvider {
const processStream = (body: MessageCreateParamsNonStreaming, idx: number) => {
return new Promise<void>((resolve, reject) => {
// Wait for the streaming response from the API
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const toolCalls: ToolUseBlock[] = []
let hasThinkingContent = false
this.sdk.messages
.stream({ ...body, stream: true }, { signal, timeout: 5 * 60 * 1000 })
@ -380,30 +395,70 @@ export default class AnthropicProvider extends BaseProvider {
})
thinking_content += thinking
})
.on('contentBlock', (content) => {
if (content.type === 'tool_use') {
toolCalls.push(content)
}
})
.on('finalMessage', async (message) => {
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
// tool call
if (toolCalls.length > 0) {
const mcpToolResponses = toolCalls
.map((toolCall) => {
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
if (!mcpTool) {
return undefined
}
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: toolCall.input as Record<string, unknown>,
status: 'pending'
} as ToolCallResponse
})
.filter((t) => typeof t !== 'undefined')
toolResults.push(
...(await parseAndCallTools(
mcpToolResponses,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
))
)
}
// tool use
const content = message.content[0]
if (content && content.type === 'text') {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content.text })
const toolResults = await parseAndCallTools(
content.text,
toolResponses,
onChunk,
idx,
mcpToolCallResponseToAnthropicMessage,
mcpTools,
isVisionModel(model)
toolResults.push(
...(await parseAndCallTools(
content.text,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
))
)
if (toolResults.length > 0) {
userMessages.push({
role: message.role,
content: message.content
})
}
toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
const newBody = body
newBody.messages = userMessages
await processStream(newBody, idx + 1)
}
userMessages.push({
role: message.role,
content: message.content
})
if (toolResults.length > 0) {
toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
const newBody = body
newBody.messages = userMessages
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
await processStream(newBody, idx + 1)
}
const time_completion_millsec = new Date().getTime() - start_time_millsec
@ -434,7 +489,7 @@ export default class AnthropicProvider extends BaseProvider {
})
})
}
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
await processStream(body, 0).finally(cleanup)
}
@ -683,4 +738,47 @@ export default class AnthropicProvider extends BaseProvider {
public async getEmbeddingDimensions(): Promise<number> {
return 0
}
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
return mcpToolsToAnthropicTools(mcpTools) as T[]
}
public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: mcpToolResponse.toolCallId!,
content: resp.content
.map((item) => {
if (item.type === 'text') {
return {
type: 'text',
text: item.text || ''
} satisfies TextBlockParam
}
if (item.type === 'image') {
return {
type: 'image',
source: {
data: item.data || '',
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
type: 'base64'
}
} satisfies ImageBlockParam
}
return
})
.filter((n) => typeof n !== 'undefined'),
is_error: resp.isError
} satisfies ToolResultBlockParam
]
}
}
return
}
}

View File

@ -1,9 +1,13 @@
import { isFunctionCallingModel } from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import type {
Assistant,
GenerateImageParams,
KnowledgeReference,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
Suggestion,
@ -22,10 +26,15 @@ import type OpenAI from 'openai'
import type { CompletionsParams } from '.'
export default abstract class BaseProvider {
// Threshold for determining whether to use system prompt for tools
private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128
protected provider: Provider
protected host: string
protected apiKey: string
protected useSystemPromptForTools: boolean = true
constructor(provider: Provider) {
this.provider = provider
this.host = this.getBaseURL()
@ -47,6 +56,12 @@ export default abstract class BaseProvider {
abstract generateImage(params: GenerateImageParams): Promise<string[]>
abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
abstract getEmbeddingDimensions(model: Model): Promise<number>
public abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
public abstract mcpToolCallResponseToMessage(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): any
public getBaseURL(): string {
const host = this.provider.apiHost
@ -229,4 +244,31 @@ export default abstract class BaseProvider {
cleanup
}
}
// Setup tools configuration based on provided parameters
protected setupToolsConfig<T>(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: T[]
} {
const { mcpTools, model, enableToolUse } = params
let tools: T[] = []
// If there are no tools, return an empty array
if (!mcpTools?.length) {
return { tools }
}
// If the number of tools exceeds the threshold, use the system prompt
if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) {
this.useSystemPromptForTools = true
return { tools }
}
// If the model supports function calling and tool usage is enabled
if (isFunctionCallingModel(model) && enableToolUse) {
tools = this.convertMcpTools<T>(mcpTools)
this.useSystemPromptForTools = false
}
return { tools }
}
}

View File

@ -1,6 +1,7 @@
import {
Content,
File,
FunctionCall,
GenerateContentConfig,
GenerateContentResponse,
GoogleGenAI,
@ -11,8 +12,9 @@ import {
PartUnion,
SafetySetting,
ThinkingConfig,
ToolListUnion
Tool
} from '@google/genai'
import { nanoid } from '@reduxjs/toolkit'
import {
findTokenLimit,
isGeminiReasoningModel,
@ -35,17 +37,25 @@ import {
EFFORT_RATIO,
FileType,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
Suggestion,
ToolCallResponse,
Usage,
WebSearchSource
} from '@renderer/types'
import { BlockCompleteChunk, Chunk, ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
import type { Message, Response } from '@renderer/types/newMessage'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { mcpToolCallResponseToGeminiMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
import {
geminiFunctionCallToMcpTool,
mcpToolCallResponseToGeminiMessage,
mcpToolsToGeminiTools,
parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
@ -263,7 +273,7 @@ export default class GeminiProvider extends BaseProvider {
}: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
const userMessages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
@ -280,12 +290,16 @@ export default class GeminiProvider extends BaseProvider {
let systemInstruction = assistant.prompt
if (mcpTools && mcpTools.length > 0) {
const { tools } = this.setupToolsConfig<Tool>({
mcpTools,
model,
enableToolUse
})
if (this.useSystemPromptForTools) {
systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
}
// const tools = mcpToolsToGeminiTools(mcpTools)
const tools: ToolListUnion = []
const toolResponses: MCPToolResponse[] = []
if (assistant.enableWebSearch && isWebSearchModel(model)) {
@ -351,6 +365,224 @@ export default class GeminiProvider extends BaseProvider {
const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true)
const processToolResults = async (toolResults: Awaited<ReturnType<typeof parseAndCallTools>>, idx: number) => {
if (toolResults.length === 0) return
const newChat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history as Content[]
})
const newStream = await newChat.sendMessageStream({
message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
config: {
...generateContentConfig,
abortSignal: abortController.signal
}
})
await processStream(newStream, idx + 1)
}
const processToolCalls = async (toolCalls: FunctionCall[]) => {
const mcpToolResponses: ToolCallResponse[] = toolCalls
.map((toolCall) => {
const mcpTool = geminiFunctionCallToMcpTool(mcpTools, toolCall)
if (!mcpTool) return undefined
const parsedArgs = (() => {
try {
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
} catch {
return toolCall.args
}
})()
return {
id: toolCall.id || nanoid(),
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
})
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
return await parseAndCallTools(
mcpToolResponses,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
}
const processToolUses = async (content: string) => {
return await parseAndCallTools(
content,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
}
const processStream = async (
stream: AsyncGenerator<GenerateContentResponse> | GenerateContentResponse,
idx: number
) => {
history.push(messageContents)
let functionCalls: FunctionCall[] = []
if (stream instanceof GenerateContentResponse) {
let content = ''
const time_completion_millsec = new Date().getTime() - start_time_millsec
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (stream.text?.length) {
toolResults.push(...(await processToolUses(stream.text)))
}
stream.candidates?.forEach((candidate) => {
if (candidate.content) {
history.push(candidate.content)
candidate.content.parts?.forEach((part) => {
if (part.functionCall) {
functionCalls.push(part.functionCall)
}
if (part.text) {
content += part.text
onChunk({ type: ChunkType.TEXT_DELTA, text: part.text })
}
})
}
})
if (content.length) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
if (functionCalls.length) {
toolResults.push(...(await processToolCalls(functionCalls)))
}
if (stream.text?.length) {
toolResults.push(...(await processToolUses(stream.text)))
}
if (toolResults.length) {
await processToolResults(toolResults, idx)
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
text: stream.text,
usage: {
prompt_tokens: stream.usageMetadata?.promptTokenCount || 0,
thoughts_tokens: stream.usageMetadata?.thoughtsTokenCount || 0,
completion_tokens: stream.usageMetadata?.candidatesTokenCount || 0,
total_tokens: stream.usageMetadata?.totalTokenCount || 0
},
metrics: {
completion_tokens: stream.usageMetadata?.candidatesTokenCount,
time_completion_millsec,
time_first_token_millsec: 0
},
webSearch: {
results: stream.candidates?.[0]?.groundingMetadata,
source: 'gemini'
}
} as Response
} as BlockCompleteChunk)
} else {
let content = ''
let final_time_completion_millsec = 0
let lastUsage: Usage | undefined = undefined
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
// --- Calculate Metrics ---
if (time_first_token_millsec == 0 && chunk.text !== undefined) {
// Update based on text arrival
time_first_token_millsec = new Date().getTime() - start_time_millsec
}
// 1. Text Content
if (chunk.text !== undefined) {
content += chunk.text
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
}
// 2. Usage Data
if (chunk.usageMetadata) {
lastUsage = {
prompt_tokens: chunk.usageMetadata.promptTokenCount || 0,
completion_tokens: chunk.usageMetadata.candidatesTokenCount || 0,
total_tokens: chunk.usageMetadata.totalTokenCount || 0
}
final_time_completion_millsec = new Date().getTime() - start_time_millsec
}
// 4. Image Generation
const generateImage = this.processGeminiImageResponse(chunk, onChunk)
if (generateImage?.images?.length) {
onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
}
if (chunk.candidates?.[0]?.finishReason) {
if (chunk.text) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
if (chunk.candidates?.[0]?.groundingMetadata) {
// 3. Grounding/Search Metadata
const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata
onChunk({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: groundingMetadata,
source: WebSearchSource.GEMINI
}
} as LLMWebSearchCompleteChunk)
}
if (chunk.functionCalls) {
chunk.candidates?.forEach((candidate) => {
if (candidate.content) {
history.push(candidate.content)
}
})
functionCalls = functionCalls.concat(chunk.functionCalls)
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
metrics: {
completion_tokens: lastUsage?.completion_tokens,
time_completion_millsec: final_time_completion_millsec,
time_first_token_millsec
},
usage: lastUsage
}
})
}
// --- End Incremental onChunk calls ---
// Call processToolUses AFTER potentially processing text content in this chunk
// This assumes tools might be specified within the text stream
// Note: parseAndCallTools inside should handle its own onChunk for tool responses
let toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (functionCalls.length) {
toolResults = await processToolCalls(functionCalls)
}
if (content.length) {
toolResults = toolResults.concat(await processToolUses(content))
}
if (toolResults.length) {
await processToolResults(toolResults, idx)
}
}
}
}
if (!streamOutput) {
const response = await chat.sendMessage({
message: messageContents as PartUnion,
@ -359,32 +591,10 @@ export default class GeminiProvider extends BaseProvider {
abortSignal: abortController.signal
}
})
const time_completion_millsec = new Date().getTime() - start_time_millsec
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
text: response.text,
usage: {
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
thoughts_tokens: response.usageMetadata?.thoughtsTokenCount || 0,
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
total_tokens: response.usageMetadata?.totalTokenCount || 0
},
metrics: {
completion_tokens: response.usageMetadata?.candidatesTokenCount,
time_completion_millsec,
time_first_token_millsec: 0
},
webSearch: {
results: response.candidates?.[0]?.groundingMetadata,
source: 'gemini'
}
} as Response
} as BlockCompleteChunk)
return
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
return await processStream(response, 0).then(cleanup)
}
// Wait for the streaming response from the API
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const userMessagesStream = await chat.sendMessageStream({
message: messageContents as PartUnion,
@ -394,105 +604,6 @@ export default class GeminiProvider extends BaseProvider {
}
})
const processToolUses = async (content: string, idx: number) => {
const toolResults = await parseAndCallTools(
content,
toolResponses,
onChunk,
idx,
mcpToolCallResponseToGeminiMessage,
mcpTools,
isVisionModel(model)
)
if (toolResults && toolResults.length > 0) {
history.push(messageContents)
const newChat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history as Content[]
})
const newStream = await newChat.sendMessageStream({
message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
config: {
...generateContentConfig,
abortSignal: abortController.signal
}
})
await processStream(newStream, idx + 1)
}
}
const processStream = async (stream: AsyncGenerator<GenerateContentResponse>, idx: number) => {
let content = ''
let final_time_completion_millsec = 0
let lastUsage: Usage | undefined = undefined
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
// --- Calculate Metrics ---
if (time_first_token_millsec == 0 && chunk.text !== undefined) {
// Update based on text arrival
time_first_token_millsec = new Date().getTime() - start_time_millsec
}
// 1. Text Content
if (chunk.text !== undefined) {
content += chunk.text
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
}
// 2. Usage Data
if (chunk.usageMetadata) {
lastUsage = {
prompt_tokens: chunk.usageMetadata.promptTokenCount || 0,
completion_tokens: chunk.usageMetadata.candidatesTokenCount || 0,
total_tokens: chunk.usageMetadata.totalTokenCount || 0
}
final_time_completion_millsec = new Date().getTime() - start_time_millsec
}
// 4. Image Generation
const generateImage = this.processGeminiImageResponse(chunk, onChunk)
if (generateImage?.images?.length) {
onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
}
if (chunk.candidates?.[0]?.finishReason) {
if (chunk.text) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
if (chunk.candidates?.[0]?.groundingMetadata) {
// 3. Grounding/Search Metadata
const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata
onChunk({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: groundingMetadata,
source: WebSearchSource.GEMINI
}
} as LLMWebSearchCompleteChunk)
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
metrics: {
completion_tokens: lastUsage?.completion_tokens,
time_completion_millsec: final_time_completion_millsec,
time_first_token_millsec
},
usage: lastUsage
}
})
}
// --- End Incremental onChunk calls ---
// Call processToolUses AFTER potentially processing text content in this chunk
// This assumes tools might be specified within the text stream
// Note: parseAndCallTools inside should handle its own onChunk for tool responses
await processToolUses(content, idx)
}
}
await processStream(userMessagesStream, 0).finally(cleanup)
const final_time_completion_millsec = new Date().getTime() - start_time_millsec
@ -841,4 +952,32 @@ export default class GeminiProvider extends BaseProvider {
public generateImageByChat(): Promise<void> {
throw new Error('Method not implemented.')
}
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
return mcpToolsToGeminiTools(mcpTools) as T[]
}
public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse) {
const toolCallOut = {
role: 'user',
parts: [
{
functionResponse: {
id: mcpToolResponse.toolCallId,
name: mcpToolResponse.tool.id,
response: {
output: !resp.isError ? resp.content : undefined,
error: resp.isError ? resp.content : undefined
}
}
}
]
} satisfies Content
return toolCallOut
}
return
}
}

View File

@ -31,10 +31,13 @@ import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
Suggestion,
ToolCallResponse,
Usage,
WebSearchSource
} from '@renderer/types'
@ -48,7 +51,12 @@ import {
convertLinksToOpenRouter,
convertLinksToZhipu
} from '@renderer/utils/linkConverter'
import { mcpToolCallResponseToOpenAICompatibleMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
import {
mcpToolCallResponseToOpenAICompatibleMessage,
mcpToolsToOpenAIChatTools,
openAIToolsToMcpTool,
parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { asyncGeneratorToReadableStream, readableStreamAsyncIterable } from '@renderer/utils/stream'
@ -57,18 +65,22 @@ import OpenAI, { AzureOpenAI } from 'openai'
import {
ChatCompletionContentPart,
ChatCompletionCreateParamsNonStreaming,
ChatCompletionMessageParam
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionTool,
ChatCompletionToolMessageParam
} from 'openai/resources'
import { CompletionsParams } from '.'
import OpenAIProvider from './OpenAIProvider'
import { BaseOpenAiProvider } from './OpenAIProvider'
// 1. Define the union type
export type OpenAIStreamChunk =
| { type: 'reasoning' | 'text-delta'; textDelta: string }
| { type: 'tool-calls'; delta: any }
| { type: 'finish'; finishReason: any; usage: any; delta: any; chunk: any }
export default class OpenAICompatibleProvider extends OpenAIProvider {
export default class OpenAICompatibleProvider extends BaseOpenAiProvider {
constructor(provider: Provider) {
super(provider)
@ -313,6 +325,24 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
return {}
}
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
return mcpToolsToOpenAIChatTools(mcpTools) as T[]
}
public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
const toolCallOut: ChatCompletionToolMessageParam = {
role: 'tool',
tool_call_id: mcpToolResponse.toolCallId,
content: JSON.stringify(resp.content)
}
return toolCallOut
}
return
}
/**
* Generate completions for the assistant
* @param messages - The messages
@ -330,7 +360,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId
messages = addImageFileToContents(messages)
const enableReasoning =
@ -344,7 +374,9 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
}
}
if (mcpTools && mcpTools.length > 0) {
const { tools } = this.setupToolsConfig<ChatCompletionTool>({ mcpTools, model, enableToolUse })
if (this.useSystemPromptForTools) {
systemMessage.content = buildSystemPrompt(systemMessage.content || '', mcpTools)
}
@ -379,53 +411,86 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
const toolResponses: MCPToolResponse[] = []
const processToolUses = async (content: string, idx: number) => {
const toolResults = await parseAndCallTools(
const processToolResults = async (toolResults: Awaited<ReturnType<typeof parseAndCallTools>>, idx: number) => {
if (toolResults.length === 0) return
toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam))
console.debug('[tool] reqMessages before processing', model.id, reqMessages)
reqMessages = processReqMessages(model, reqMessages)
console.debug('[tool] reqMessages', model.id, reqMessages)
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const newStream = await this.sdk.chat.completions
// @ts-ignore key is not typed
.create(
{
model: model.id,
messages: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_tokens: maxTokens,
keep_alive: this.keepAliveTime,
stream: isSupportStreamOutput(),
tools: !isEmpty(tools) ? tools : undefined,
...getOpenAIWebSearchParams(assistant, model),
...this.getReasoningEffort(assistant, model),
...this.getProviderSpecificParameters(assistant, model),
...this.getCustomParameters(assistant)
},
{
signal
}
)
await processStream(newStream, idx + 1)
}
const processToolCalls = async (mcpTools, toolCalls: ChatCompletionMessageToolCall[]) => {
const mcpToolResponses = toolCalls
.map((toolCall) => {
const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as ChatCompletionMessageToolCall)
if (!mcpTool) return undefined
const parsedArgs = (() => {
try {
return JSON.parse(toolCall.function.arguments)
} catch {
return toolCall.function.arguments
}
})()
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
})
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
return await parseAndCallTools(
mcpToolResponses,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
}
const processToolUses = async (content: string) => {
return await parseAndCallTools(
content,
toolResponses,
onChunk,
idx,
mcpToolCallResponseToOpenAICompatibleMessage,
mcpTools,
isVisionModel(model)
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
if (toolResults.length > 0) {
reqMessages.push({
role: 'assistant',
content: content
} as ChatCompletionMessageParam)
toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam))
reqMessages = processReqMessages(model, reqMessages)
const newStream = await this.sdk.chat.completions
// @ts-ignore key is not typed
.create(
{
model: model.id,
messages: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_tokens: maxTokens,
keep_alive: this.keepAliveTime,
stream: isSupportStreamOutput(),
// tools: tools,
service_tier: this.getServiceTier(model),
...getOpenAIWebSearchParams(assistant, model),
...this.getReasoningEffort(assistant, model),
...this.getProviderSpecificParameters(assistant, model),
...this.getCustomParameters(assistant)
},
{
signal,
timeout: this.getTimeout(model)
}
)
await processStream(newStream, idx + 1)
}
}
const processStream = async (stream: any, idx: number) => {
const toolCalls: ChatCompletionMessageToolCall[] = []
// Handle non-streaming case (already returns early, no change needed here)
if (!isSupportStreamOutput()) {
const time_completion_millsec = new Date().getTime() - start_time_millsec
@ -439,10 +504,59 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
// Create a synthetic usage object if stream.usage is undefined
const finalUsage = stream.usage
// Separate onChunk calls for text and usage/metrics
if (stream.choices[0].message?.content) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: stream.choices[0].message.content })
let content = ''
stream.choices.forEach((choice) => {
// reasoning
if (choice.message.reasoning) {
onChunk({ type: ChunkType.THINKING_DELTA, text: choice.message.reasoning })
onChunk({
type: ChunkType.THINKING_COMPLETE,
text: choice.message.reasoning,
thinking_millsec: time_completion_millsec
})
}
// text
if (choice.message.content) {
content += choice.message.content
onChunk({ type: ChunkType.TEXT_DELTA, text: choice.message.content })
}
// tool call
if (choice.message.tool_calls && choice.message.tool_calls.length) {
choice.message.tool_calls.forEach((t) => toolCalls.push(t))
}
reqMessages.push({
role: choice.message.role,
content: choice.message.content,
tool_calls: toolCalls.length
? toolCalls.map((toolCall) => ({
id: toolCall.id,
function: {
...toolCall.function,
arguments:
typeof toolCall.function.arguments === 'string'
? toolCall.function.arguments
: JSON.stringify(toolCall.function.arguments)
},
type: 'function'
}))
: undefined
})
})
if (content.length) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (toolCalls.length) {
toolResults.push(...(await processToolCalls(mcpTools, toolCalls)))
}
if (stream.choices[0].message?.content) {
toolResults.push(...(await processToolUses(stream.choices[0].message?.content)))
}
await processToolResults(toolResults, idx)
// Always send usage and metrics data
onChunk({ type: ChunkType.BLOCK_COMPLETE, response: { usage: finalUsage, metrics: finalMetrics } })
return
@ -486,6 +600,9 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
if (delta?.content) {
yield { type: 'text-delta', textDelta: delta.content }
}
if (delta?.tool_calls) {
yield { type: 'tool-calls', delta: delta }
}
const finishReason = chunk.choices[0]?.finish_reason
if (!isEmpty(finishReason)) {
@ -563,6 +680,25 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
onChunk({ type: ChunkType.TEXT_DELTA, text: textDelta })
break
}
case 'tool-calls': {
chunk.delta.tool_calls.forEach((toolCall) => {
const { id, index, type, function: fun } = toolCall
if (id && type === 'function' && fun) {
const { name, arguments: args } = fun
toolCalls.push({
id,
function: {
name: name || '',
arguments: args || ''
},
type: 'function'
})
} else if (fun?.arguments) {
toolCalls[index].function.arguments += fun.arguments
}
})
break
}
case 'finish': {
const finishReason = chunk.finishReason
const usage = chunk.usage
@ -624,7 +760,33 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
} as LLMWebSearchCompleteChunk)
}
}
await processToolUses(content, idx)
reqMessages.push({
role: 'assistant',
content: content,
tool_calls: toolCalls.length
? toolCalls.map((toolCall) => ({
id: toolCall.id,
function: {
...toolCall.function,
arguments:
typeof toolCall.function.arguments === 'string'
? toolCall.function.arguments
: JSON.stringify(toolCall.function.arguments)
},
type: 'function'
}))
: undefined
})
let toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (toolCalls.length) {
toolResults = await processToolCalls(mcpTools, toolCalls)
}
if (content.length) {
toolResults = toolResults.concat(await processToolUses(content))
}
if (toolResults.length) {
await processToolResults(toolResults, idx)
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
@ -657,7 +819,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
max_tokens: maxTokens,
keep_alive: this.keepAliveTime,
stream: isSupportStreamOutput(),
// tools: tools,
tools: !isEmpty(tools) ? tools : undefined,
service_tier: this.getServiceTier(model),
...getOpenAIWebSearchParams(assistant, model),
...this.getReasoningEffort(assistant, model),

View File

@ -21,10 +21,13 @@ import {
Assistant,
FileTypes,
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
Suggestion,
ToolCallResponse,
Usage,
WebSearchSource
} from '@renderer/types'
@ -33,7 +36,12 @@ import { Message } from '@renderer/types/newMessage'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { addImageFileToContents } from '@renderer/utils/formats'
import { convertLinks } from '@renderer/utils/linkConverter'
import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
import {
mcpToolCallResponseToOpenAIMessage,
mcpToolsToOpenAIResponseTools,
openAIToolsToMcpTool,
parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { isEmpty, takeRight } from 'lodash'
@ -45,7 +53,7 @@ import { FileLike, toFile } from 'openai/uploads'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
export default class OpenAIProvider extends BaseProvider {
export abstract class BaseOpenAiProvider extends BaseProvider {
protected sdk: OpenAI
constructor(provider: Provider) {
@ -61,6 +69,14 @@ export default class OpenAIProvider extends BaseProvider {
})
}
abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
abstract mcpToolCallResponseToMessage: (
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
) => OpenAI.Responses.ResponseInputItem | ChatCompletionMessageParam | undefined
/**
* Extract the file content from the message
* @param message - The message
@ -91,16 +107,23 @@ export default class OpenAIProvider extends BaseProvider {
return ''
}
private async getReponseMessageParam(message: Message, model: Model): Promise<OpenAI.Responses.EasyInputMessage> {
private async getReponseMessageParam(message: Message, model: Model): Promise<OpenAI.Responses.ResponseInputItem> {
const isVision = isVisionModel(model)
const content = await this.getMessageContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
return {
role: message.role === 'system' ? 'user' : message.role,
content: content ? [{ type: 'input_text', text: content }] : []
if (message.role === 'assistant') {
return {
role: 'assistant',
content: content
}
} else {
return {
role: message.role === 'system' ? 'user' : message.role,
content: content ? [{ type: 'input_text', text: content }] : []
} as OpenAI.Responses.EasyInputMessage
}
}
@ -285,10 +308,8 @@ export default class OpenAIProvider extends BaseProvider {
}
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
// Fall back to OpenAI-compatible mode
if (isOpenAIWebSearch(model)) {
const systemMessage = { role: 'system', content: assistant.prompt || '' }
@ -387,7 +408,7 @@ export default class OpenAIProvider extends BaseProvider {
})
return
}
const tools: OpenAI.Responses.Tool[] = []
let tools: OpenAI.Responses.Tool[] = []
if (isEnabledWebSearch) {
tools.push({
type: 'web_search_preview'
@ -408,7 +429,15 @@ export default class OpenAIProvider extends BaseProvider {
systemMessage.role = 'developer'
}
if (mcpTools && mcpTools.length > 0) {
const { tools: extraTools } = this.setupToolsConfig<OpenAI.Responses.Tool>({
mcpTools,
model,
enableToolUse
})
tools = tools.concat(extraTools)
if (this.useSystemPromptForTools) {
systemMessageInput.text = buildSystemPrompt(systemMessageInput.text || '', mcpTools)
}
systemMessageContent.push(systemMessageInput)
@ -418,7 +447,7 @@ export default class OpenAIProvider extends BaseProvider {
)
onFilterMessages(_messages)
const userMessage: OpenAI.Responses.EasyInputMessage[] = []
const userMessage: OpenAI.Responses.ResponseInputItem[] = []
for (const message of _messages) {
userMessage.push(await this.getReponseMessageParam(message, model))
}
@ -431,7 +460,7 @@ export default class OpenAIProvider extends BaseProvider {
const { signal } = abortController
// Do not send systemMessage when its content is empty
let reqMessages: OpenAI.Responses.EasyInputMessage[]
let reqMessages: OpenAI.Responses.ResponseInput
if (!systemMessage.content) {
reqMessages = [...userMessage]
} else {
@ -440,48 +469,84 @@ export default class OpenAIProvider extends BaseProvider {
const toolResponses: MCPToolResponse[] = []
const processToolUses = async (content: string, idx: number) => {
const toolResults = await parseAndCallTools(
const processToolResults = async (toolResults: Awaited<ReturnType<typeof parseAndCallTools>>, idx: number) => {
if (toolResults.length === 0) return
toolResults.forEach((ts) => reqMessages.push(ts as OpenAI.Responses.EasyInputMessage))
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const stream = await this.sdk.responses.create(
{
model: model.id,
input: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_output_tokens: maxTokens,
stream: streamOutput,
tools: !isEmpty(tools) ? tools : undefined,
service_tier: this.getServiceTier(model),
...this.getResponseReasoningEffort(assistant, model),
...this.getCustomParameters(assistant)
},
{
signal,
timeout: this.getTimeout(model)
}
)
await processStream(stream, idx + 1)
}
const processToolCalls = async (mcpTools, toolCalls: OpenAI.Responses.ResponseFunctionToolCall[]) => {
const mcpToolResponses = toolCalls
.map((toolCall) => {
const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as OpenAI.Responses.ResponseFunctionToolCall)
if (!mcpTool) return undefined
const parsedArgs = (() => {
try {
return JSON.parse(toolCall.arguments)
} catch {
return toolCall.arguments
}
})()
return {
id: toolCall.call_id,
toolCallId: toolCall.call_id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
})
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
return await parseAndCallTools<OpenAI.Responses.ResponseInputItem | ChatCompletionMessageParam>(
mcpToolResponses,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
}
const processToolUses = async (content: string) => {
return await parseAndCallTools(
content,
toolResponses,
onChunk,
idx,
mcpToolCallResponseToOpenAIMessage,
mcpTools,
isVisionModel(model)
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
if (toolResults.length > 0) {
reqMessages.push({
role: 'assistant',
content: content
})
toolResults.forEach((ts) => reqMessages.push(ts as OpenAI.Responses.EasyInputMessage))
const newStream = await this.sdk.responses.create(
{
model: model.id,
input: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_output_tokens: maxTokens,
stream: true,
service_tier: this.getServiceTier(model),
...this.getResponseReasoningEffort(assistant, model),
...this.getCustomParameters(assistant)
},
{
signal,
timeout: this.getTimeout(model)
}
)
await processStream(newStream, idx + 1)
}
}
const processStream = async (
stream: Stream<OpenAI.Responses.ResponseStreamEvent> | OpenAI.Responses.Response,
idx: number
) => {
const toolCalls: OpenAI.Responses.ResponseFunctionToolCall[] = []
if (!streamOutput) {
const nonStream = stream as OpenAI.Responses.Response
const time_completion_millsec = new Date().getTime() - start_time_millsec
@ -499,11 +564,15 @@ export default class OpenAIProvider extends BaseProvider {
prompt_tokens: nonStream.usage?.input_tokens || 0,
total_tokens
}
let content = ''
for (const output of nonStream.output) {
switch (output.type) {
case 'message':
if (output.content[0].type === 'output_text') {
onChunk({ type: ChunkType.TEXT_DELTA, text: output.content[0].text })
onChunk({ type: ChunkType.TEXT_COMPLETE, text: output.content[0].text })
content += output.content[0].text
if (output.content[0].annotations && output.content[0].annotations.length > 0) {
onChunk({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
@ -522,8 +591,32 @@ export default class OpenAIProvider extends BaseProvider {
thinking_millsec: new Date().getTime() - start_time_millsec
})
break
case 'function_call':
toolCalls.push(output)
}
}
if (content) {
reqMessages.push({
role: 'assistant',
content: content
})
}
if (toolCalls.length) {
toolCalls.forEach((toolCall) => {
reqMessages.push(toolCall)
})
}
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (toolCalls.length) {
toolResults.push(...(await processToolCalls(mcpTools, toolCalls)))
}
if (content.length) {
toolResults.push(...(await processToolUses(content)))
}
await processToolResults(toolResults, idx)
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
@ -534,6 +627,9 @@ export default class OpenAIProvider extends BaseProvider {
return
}
let content = ''
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
let lastUsage: Usage | undefined = undefined
let final_time_completion_millsec_delta = 0
for await (const chunk of stream as Stream<OpenAI.Responses.ResponseStreamEvent>) {
@ -544,6 +640,12 @@ export default class OpenAIProvider extends BaseProvider {
case 'response.created':
time_first_token_millsec = new Date().getTime()
break
case 'response.output_item.added':
if (chunk.item.type === 'function_call') {
outputItems.push(chunk.item)
}
break
case 'response.reasoning_summary_text.delta':
onChunk({
type: ChunkType.THINKING_DELTA,
@ -571,6 +673,21 @@ export default class OpenAIProvider extends BaseProvider {
text: chunk.text
})
break
case 'response.function_call_arguments.done': {
const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
(item) => item.id === chunk.item_id
)
if (outputItem) {
if (outputItem.type === 'function_call') {
toolCalls.push({
...outputItem,
arguments: chunk.arguments
})
}
}
break
}
case 'response.content_part.done':
if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
onChunk({
@ -607,9 +724,31 @@ export default class OpenAIProvider extends BaseProvider {
})
break
}
// --- End of Incremental onChunk calls ---
} // End of for await loop
if (content) {
reqMessages.push({
role: 'assistant',
content: content
})
}
if (toolCalls.length) {
toolCalls.forEach((toolCall) => {
reqMessages.push(toolCall)
})
}
await processToolUses(content, idx)
// Call processToolUses AFTER the loop finishes processing the main stream content
// Note: parseAndCallTools inside processToolUses should handle its own onChunk for tool responses
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (toolCalls.length) {
toolResults.push(...(await processToolCalls(mcpTools, toolCalls)))
}
if (content) {
toolResults.push(...(await processToolUses(content)))
}
await processToolResults(toolResults, idx)
onChunk({
type: ChunkType.BLOCK_COMPLETE,
@ -624,6 +763,7 @@ export default class OpenAIProvider extends BaseProvider {
})
}
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const stream = await this.sdk.responses.create(
{
model: model.id,
@ -1072,3 +1212,31 @@ export default class OpenAIProvider extends BaseProvider {
return data.data[0].embedding.length
}
}
export default class OpenAIProvider extends BaseOpenAiProvider {
constructor(provider: Provider) {
super(provider)
}
public convertMcpTools<T>(mcpTools: MCPTool[]) {
return mcpToolsToOpenAIResponseTools(mcpTools) as T[]
}
public mcpToolCallResponseToMessage = (
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): OpenAI.Responses.ResponseInputItem | undefined => {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
const toolCallOut: OpenAI.Responses.ResponseInputItem = {
type: 'function_call_output',
call_id: mcpToolResponse.toolCallId,
output: JSON.stringify(resp.content)
}
return toolCallOut
}
return
}
}

View File

@ -107,6 +107,7 @@ export const getAssistantSettings = (assistant: Assistant): AssistantSettings =>
enableMaxTokens: assistant?.settings?.enableMaxTokens ?? false,
maxTokens: getAssistantMaxTokens(),
streamOutput: assistant?.settings?.streamOutput ?? true,
enableToolUse: assistant?.settings?.enableToolUse ?? false,
hideMessages: assistant?.settings?.hideMessages ?? false,
defaultModel: assistant?.defaultModel ?? undefined,
customParameters: assistant?.settings?.customParameters ?? []

View File

@ -423,7 +423,17 @@ const fetchAndProcessAssistantResponseImpl = async (
}
},
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
if (toolResponse.status === 'invoking') {
if (lastBlockType === MessageBlockType.UNKNOWN && lastBlockId) {
lastBlockType = MessageBlockType.TOOL
const changes = {
type: MessageBlockType.TOOL,
status: MessageBlockStatus.PROCESSING,
metadata: { rawMcpToolResponse: toolResponse }
}
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
toolCallIdToBlockIdMap.set(toolResponse.id, lastBlockId)
} else if (toolResponse.status === 'invoking') {
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
toolName: toolResponse.tool.name,
status: MessageBlockStatus.PROCESSING,

View File

@ -55,6 +55,7 @@ export type AssistantSettings = {
maxTokens: number | undefined
enableMaxTokens: boolean
streamOutput: boolean
enableToolUse: boolean
hideMessages: boolean
defaultModel?: Model
customParameters?: AssistantSettingCustomParameters[]
@ -570,13 +571,25 @@ export interface MCPConfig {
servers: MCPServer[]
}
export interface MCPToolResponse {
id: string // tool call id, it should be unique
tool: MCPTool // tool info
interface BaseToolResponse {
id: string // unique id
tool: MCPTool
arguments: Record<string, unknown> | undefined
status: string // 'invoking' | 'done'
response?: any
}
export interface ToolUseResponse extends BaseToolResponse {
toolUseId: string
}
export interface ToolCallResponse extends BaseToolResponse {
// gemini tool call id might be undefined
toolCallId?: string
}
export type MCPToolResponse = ToolUseResponse | ToolCallResponse
export interface MCPToolResultContent {
type: 'text' | 'image' | 'audio' | 'resource'
text?: string
@ -586,6 +599,7 @@ export interface MCPToolResultContent {
uri?: string
text?: string
mimeType?: string
blob?: string
}
}

View File

@ -1,18 +1,31 @@
import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { MessageParam } from '@anthropic-ai/sdk/resources'
import { Content, FunctionCall, Part } from '@google/genai'
import {
ContentBlockParam,
MessageParam,
ToolResultBlockParam,
ToolUnion,
ToolUseBlock
} from '@anthropic-ai/sdk/resources'
import { Content, FunctionCall, Part, Tool, Type as GeminiSchemaType } from '@google/genai'
import { isVisionModel } from '@renderer/config/models'
import store from '@renderer/store'
import { addMCPServer } from '@renderer/store/mcp'
import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse, Model, ToolUseResponse } from '@renderer/types'
import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { isArray, isObject, pull, transform } from 'lodash'
import { nanoid } from 'nanoid'
import OpenAI from 'openai'
import { ChatCompletionContentPart, ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources'
import {
ChatCompletionContentPart,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionTool
} from 'openai/resources'
import { CompletionsParams } from '../providers/AiProvider'
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
const EXTRA_SCHEMA_KEYS = ['schema', 'headers']
// const ensureValidSchema = (obj: Record<string, any>) => {
// // Filter out unsupported keys for Gemini
@ -153,77 +166,116 @@ const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
// return processedProperties
// }
// export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
// return mcpTools.map((tool) => ({
// type: 'function',
// name: tool.name,
// function: {
// name: tool.id,
// description: tool.description,
// parameters: {
// type: 'object',
// properties: filterPropertieAttributes(tool)
// }
// }
// }))
// }
export function openAIToolsToMcpTool(
mcpTools: MCPTool[] | undefined,
llmTool: ChatCompletionMessageToolCall
): MCPTool | undefined {
if (!mcpTools) {
return undefined
export function filterProperties(
properties: Record<string, any> | string | number | boolean | Array<Record<string, any> | string | number | boolean>,
supportedKeys: string[]
) {
// If it is an array, recursively process each element
if (isArray(properties)) {
return properties.map((item) => filterProperties(item, supportedKeys))
}
const tool = mcpTools.find(
(mcptool) => mcptool.id === llmTool.function.name || mcptool.name === llmTool.function.name
)
// If it is an object, recursively process each property
if (isObject(properties)) {
return transform(
properties,
(result, value, key) => {
if (key === 'properties') {
result[key] = transform(value, (acc, v, k) => {
acc[k] = filterProperties(v, supportedKeys)
})
if (!tool) {
console.warn('No MCP Tool found for tool call:', llmTool)
return undefined
result['additionalProperties'] = false
result['required'] = pull(Object.keys(value), ...EXTRA_SCHEMA_KEYS)
} else if (key === 'oneOf') {
// OpenAI only supports anyOf
result['anyOf'] = filterProperties(value, supportedKeys)
} else if (supportedKeys.includes(key)) {
result[key] = filterProperties(value, supportedKeys)
if (key === 'type' && value === 'object') {
result['additionalProperties'] = false
}
}
},
{}
)
}
console.log(
`[MCP] OpenAI Tool to MCP Tool: ${tool.serverName} ${tool.name}`,
tool,
'args',
llmTool.function.arguments
)
// use this to parse the arguments and avoid parsing errors
let args: any = {}
try {
args = JSON.parse(llmTool.function.arguments)
} catch (e) {
console.error('Error parsing arguments', e)
}
return {
id: tool.id,
serverId: tool.serverId,
serverName: tool.serverName,
name: tool.name,
description: tool.description,
inputSchema: args
}
// Return other types directly (e.g., string, number, etc.)
return properties
}
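
A worked example of the filtering, using the key set passed for the OpenAI response tools below: oneOf is rewritten to anyOf, object schemas gain additionalProperties: false, and required is derived from the property names (minus the extra 'schema'/'headers' keys).

const raw = {
  type: 'object',
  properties: {
    query: { type: 'string', description: 'Search text' },
    limit: { oneOf: [{ type: 'number' }, { type: 'string' }] }
  }
}

const filtered = filterProperties(raw, ['type', 'description', 'items', 'enum', 'additionalProperties', 'anyOf'])
// Roughly:
// {
//   type: 'object',
//   additionalProperties: false,
//   properties: {
//     query: { type: 'string', description: 'Search text' },
//     limit: { anyOf: [{ type: 'number' }, { type: 'string' }] }
//   },
//   required: ['query', 'limit']
// }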
export async function callMCPTool(tool: MCPTool): Promise<MCPCallToolResponse> {
console.log(`[MCP] Calling Tool: ${tool.serverName} ${tool.name}`, tool)
export function mcpToolsToOpenAIResponseTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
const schemaKeys = ['type', 'description', 'items', 'enum', 'additionalProperties', 'anyOf']
return mcpTools.map(
(tool) =>
({
type: 'function',
name: tool.id,
parameters: {
type: 'object',
properties: filterProperties(tool.inputSchema, schemaKeys).properties,
required: pull(Object.keys(tool.inputSchema.properties), ...EXTRA_SCHEMA_KEYS),
additionalProperties: false
},
strict: true
}) satisfies OpenAI.Responses.Tool
)
}
export function mcpToolsToOpenAIChatTools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
return mcpTools.map(
(tool) =>
({
type: 'function',
function: {
name: tool.id,
description: tool.description,
parameters: {
type: 'object',
properties: tool.inputSchema.properties,
required: tool.inputSchema.required
}
}
}) as ChatCompletionTool
)
}
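
A quick usage sketch with a made-up tool. Note the deliberate difference from mcpToolsToOpenAIResponseTools above: the responses-API variant filters the schema and sets strict: true, while this chat-completions variant passes inputSchema through as-is.

const weatherTool = {
  id: 'tooluse_get_weather',
  name: 'get_weather',
  description: 'Get the current weather for a city',
  inputSchema: { type: 'object', properties: { city: { type: 'string' } }, required: ['city'] }
} as MCPTool

const [chatTool] = mcpToolsToOpenAIChatTools([weatherTool])
// => {
//      type: 'function',
//      function: {
//        name: 'tooluse_get_weather',
//        description: 'Get the current weather for a city',
//        parameters: { type: 'object', properties: { city: { type: 'string' } }, required: ['city'] }
//      }
//    }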
export function openAIToolsToMcpTool(
mcpTools: MCPTool[],
toolCall: OpenAI.Responses.ResponseFunctionToolCall | ChatCompletionMessageToolCall
): MCPTool | undefined {
const tool = mcpTools.find((mcpTool) => {
if ('name' in toolCall) {
return mcpTool.id === toolCall.name || mcpTool.name === toolCall.name
} else {
return mcpTool.id === toolCall.function.name || mcpTool.name === toolCall.function.name
}
})
if (!tool) {
console.warn('No MCP Tool found for tool call:', toolCall)
return undefined
}
return tool
}
export async function callMCPTool(toolResponse: MCPToolResponse): Promise<MCPCallToolResponse> {
console.log(`[MCP] Calling Tool: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, toolResponse.tool)
try {
const server = getMcpServerByTool(tool)
const server = getMcpServerByTool(toolResponse.tool)
if (!server) {
throw new Error(`Server not found: ${tool.serverName}`)
throw new Error(`Server not found: ${toolResponse.tool.serverName}`)
}
const resp = await window.api.mcp.callTool({
server,
name: tool.name,
args: tool.inputSchema
name: toolResponse.tool.name,
args: toolResponse.arguments
})
if (tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
if (toolResponse.tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
if (resp.data) {
const mcpServer: MCPServer = {
id: `f${nanoid()}`,
@ -241,16 +293,16 @@ export async function callMCPTool(tool: MCPTool): Promise<MCPCallToolResponse> {
}
}
console.log(`[MCP] Tool called: ${tool.serverName} ${tool.name}`, resp)
console.log(`[MCP] Tool called: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, resp)
return resp
} catch (e) {
console.error(`[MCP] Error calling Tool: ${tool.serverName} ${tool.name}`, e)
console.error(`[MCP] Error calling Tool: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, e)
return Promise.resolve({
isError: true,
content: [
{
type: 'text',
text: `Error calling tool ${tool.name}: ${e instanceof Error ? e.stack || e.message || 'No error details available' : JSON.stringify(e)}`
text: `Error calling tool ${toolResponse.tool.name}: ${e instanceof Error ? e.stack || e.message || 'No error details available' : JSON.stringify(e)}`
}
]
})
@ -262,7 +314,7 @@ export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array<ToolUnion>
const t: ToolUnion = {
name: tool.id,
description: tool.description,
// @ts-ignore no check
// @ts-ignore ignore type as it is unknown
input_schema: tool.inputSchema
}
return t
@ -275,53 +327,68 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
if (!tool) {
return undefined
}
// @ts-ignore ignore type as it is unknown
tool.inputSchema = toolUse.input
return tool
}
// export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
// if (!mcpTools || mcpTools.length === 0) {
// // No tools available
// return []
// }
// const functions: FunctionDeclaration[] = []
// for (const tool of mcpTools) {
// const properties = filterPropertieAttributes(tool, true)
// const functionDeclaration: FunctionDeclaration = {
// name: tool.id,
// description: tool.description,
// parameters: {
// type: SchemaType.OBJECT,
// properties:
// Object.keys(properties).length > 0
// ? Object.fromEntries(
// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
// )
// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
// } as FunctionDeclarationSchema
// }
// functions.push(functionDeclaration)
// }
// const tool: geminiTool = {
// functionDeclarations: functions
// }
// return [tool]
// }
/**
 * Converts MCP tools into Gemini function declarations, keeping only the schema keys Gemini supports.
 * @param mcpTools MCP tools to convert
 * @returns a single-element Tool array whose functionDeclarations cover every MCP tool
 */
export function mcpToolsToGeminiTools(mcpTools: MCPTool[]): Tool[] {
/**
* @typedef {import('@google/genai').Schema} Schema
*/
const schemaKeys = [
'example',
'pattern',
'default',
'maxLength',
'minLength',
'minProperties',
'maxProperties',
'anyOf',
'description',
'enum',
'format',
'items',
'maxItems',
'maximum',
'minItems',
'minimum',
'nullable',
'properties',
'propertyOrdering',
'required',
'title',
'type'
]
return [
{
functionDeclarations: mcpTools?.map((tool) => {
return {
name: tool.id,
description: tool.description,
parameters: {
type: GeminiSchemaType.OBJECT,
properties: filterProperties(tool.inputSchema, schemaKeys).properties,
required: tool.inputSchema.required
}
}
})
}
]
}
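
And the Gemini counterpart for the same made-up weatherTool from the chat-completions example above (illustrative; only the supported schema keys survive filtering):

const [geminiTool] = mcpToolsToGeminiTools([weatherTool])
// => {
//      functionDeclarations: [{
//        name: 'tooluse_get_weather',
//        description: 'Get the current weather for a city',
//        parameters: {
//          type: GeminiSchemaType.OBJECT,
//          properties: { city: { type: 'string' } },
//          required: ['city']
//        }
//      }]
//    }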
export function geminiFunctionCallToMcpTool(
mcpTools: MCPTool[] | undefined,
fcall: FunctionCall | undefined
toolCall: FunctionCall | undefined
): MCPTool | undefined {
if (!fcall) return undefined
if (!toolCall) return undefined
if (!mcpTools) return undefined
const tool = mcpTools.find((tool) => tool.id === fcall.name)
const tool = mcpTools.find((tool) => tool.id === toolCall.name)
if (!tool) {
return undefined
}
// @ts-ignore schema is not a valid property
tool.inputSchema = fcall.args
return tool
}
@ -368,13 +435,13 @@ export function getMcpServerByTool(tool: MCPTool) {
return servers.find((s) => s.id === tool.serverId)
}
export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolResponse[] {
export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseResponse[] {
if (!content || !mcpTools || mcpTools.length === 0) {
return []
}
const toolUsePattern =
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
const tools: MCPToolResponse[] = []
const tools: ToolUseResponse[] = []
let match
let idx = 0
// Find all tool use blocks
@ -401,10 +468,9 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolRespo
// Add to tools array
tools.push({
id: `${toolName}-${idx++}`, // Unique ID for each tool use
tool: {
...mcpTool,
inputSchema: parsedArgs
},
toolUseId: mcpTool.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
})
@ -414,36 +480,69 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolRespo
return tools
}
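
The system-prompt protocol that this regex matches, and roughly what it yields (assuming mcpTools contains a tool whose id is tooluse_get_weather):

const content = `
<tool_use>
  <name>tooluse_get_weather</name>
  <arguments>{"city": "Tokyo"}</arguments>
</tool_use>`

const uses = parseToolUse(content, mcpTools)
// => [{
//      id: 'tooluse_get_weather-0',
//      toolUseId: '<matching MCPTool id>',
//      tool: <the matching MCPTool>,
//      arguments: { city: 'Tokyo' },
//      status: 'pending'
//    }]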
export async function parseAndCallTools(
content: string,
toolResponses: MCPToolResponse[],
export async function parseAndCallTools<R>(
tools: MCPToolResponse[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
idx: number,
convertToMessage: (
toolCallId: string,
resp: MCPCallToolResponse,
isVisionModel: boolean
) => ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage,
mcpTools?: MCPTool[],
isVisionModel: boolean = false
): Promise<(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage)[]> {
const toolResults: (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage)[] = []
// process tool use
const tools = parseToolUse(content, mcpTools || [])
if (!tools || tools.length === 0) {
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<
(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
>
export async function parseAndCallTools<R>(
content: string,
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<
(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
>
export async function parseAndCallTools<R>(
content: string | MCPToolResponse[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
model: Model,
mcpTools?: MCPTool[]
): Promise<R[]> {
const toolResults: R[] = []
let curToolResponses: MCPToolResponse[] = []
if (Array.isArray(content)) {
curToolResponses = content
} else {
// process tool use
curToolResponses = parseToolUse(content, mcpTools || [])
}
if (!curToolResponses || curToolResponses.length === 0) {
return toolResults
}
for (let i = 0; i < tools.length; i++) {
const tool = tools[i]
upsertMCPToolResponse(toolResponses, { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'invoking' }, onChunk)
for (let i = 0; i < curToolResponses.length; i++) {
const toolResponse = curToolResponses[i]
upsertMCPToolResponse(
allToolResponses,
{
...toolResponse,
status: 'invoking'
},
onChunk
)
}
const toolPromises = tools.map(async (tool, i) => {
const toolPromises = curToolResponses.map(async (toolResponse) => {
const images: string[] = []
const toolCallResponse = await callMCPTool(tool.tool)
const toolCallResponse = await callMCPTool(toolResponse)
upsertMCPToolResponse(
toolResponses,
{ id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'done', response: toolCallResponse },
allToolResponses,
{
...toolResponse,
status: 'done',
response: toolCallResponse
},
onChunk
)
@ -466,15 +565,15 @@ export async function parseAndCallTools(
})
}
return convertToMessage(tool.tool.id, toolCallResponse, isVisionModel)
return convertToMessage(toolResponse, toolCallResponse, model)
})
toolResults.push(...(await Promise.all(toolPromises)))
toolResults.push(...(await Promise.all(toolPromises)).filter((t) => typeof t !== 'undefined'))
return toolResults
}
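
A hedged sketch of how a provider might drive the two overloads, once with streamed text (prompt-based tool use) and once with already-collected tool calls (native tool use), reusing the OpenAI-compatible converter. assistantText, allToolResponses, toolCallResponses, reqMessages, onChunk, model and mcpTools are assumed to exist in the caller.

// Prompt-based: scan the accumulated assistant text for <tool_use> blocks.
const fromPrompt = await parseAndCallTools(
  assistantText,
  allToolResponses,
  onChunk,
  (toolResponse, resp, m) =>
    mcpToolCallResponseToOpenAICompatibleMessage(toolResponse, resp, isVisionModel(m)),
  model,
  mcpTools
)

// Native function calls: pass the already-built ToolCallResponse[] instead of text.
const fromToolCalls = await parseAndCallTools(
  toolCallResponses,
  allToolResponses,
  onChunk,
  (toolResponse, resp, m) =>
    mcpToolCallResponseToOpenAICompatibleMessage(toolResponse, resp, isVisionModel(m)),
  model
)

// Feed the tool results back into the conversation and re-issue the request.
reqMessages.push(...fromPrompt, ...fromToolCalls)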
export function mcpToolCallResponseToOpenAICompatibleMessage(
toolCallId: string,
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
isVisionModel: boolean = false
): ChatCompletionMessageParam {
@ -488,7 +587,7 @@ export function mcpToolCallResponseToOpenAICompatibleMessage(
const content: ChatCompletionContentPart[] = [
{
type: 'text',
text: `Here is the result of tool call ${toolCallId}:`
text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
}
]
@ -541,7 +640,7 @@ export function mcpToolCallResponseToOpenAICompatibleMessage(
}
export function mcpToolCallResponseToOpenAIMessage(
toolCallId: string,
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
isVisionModel: boolean = false
): OpenAI.Responses.EasyInputMessage {
@ -555,7 +654,7 @@ export function mcpToolCallResponseToOpenAIMessage(
const content: OpenAI.Responses.ResponseInputContent[] = [
{
type: 'input_text',
text: `Here is the result of tool call ${toolCallId}:`
text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
}
]
@ -597,9 +696,9 @@ export function mcpToolCallResponseToOpenAIMessage(
}
export function mcpToolCallResponseToAnthropicMessage(
toolCallId: string,
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
isVisionModel: boolean = false
model: Model
): MessageParam {
const message = {
role: 'user'
@ -610,10 +709,10 @@ export function mcpToolCallResponseToAnthropicMessage(
const content: ContentBlockParam[] = [
{
type: 'text',
text: `Here is the result of tool call ${toolCallId}:`
text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
}
]
if (isVisionModel) {
if (isVisionModel(model)) {
for (const item of resp.content) {
switch (item.type) {
case 'text':
@ -665,7 +764,7 @@ export function mcpToolCallResponseToAnthropicMessage(
}
export function mcpToolCallResponseToGeminiMessage(
toolCallId: string,
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
isVisionModel: boolean = false
): Content {
@ -682,7 +781,7 @@ export function mcpToolCallResponseToGeminiMessage(
} else {
const parts: Part[] = [
{
text: `Here is the result of tool call ${toolCallId}:`
text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
}
]
if (isVisionModel) {

View File

@ -147,7 +147,7 @@ ${availableTools}
</tools>`
}
export const buildSystemPrompt = (userSystemPrompt: string, tools: MCPTool[]): string => {
export const buildSystemPrompt = (userSystemPrompt: string, tools?: MCPTool[]): string => {
if (tools && tools.length > 0) {
return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt)
.replace('{{ TOOL_USE_EXAMPLES }}', ToolUseExamples)