diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index a98ca3414a..efc9ba992f 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -156,7 +156,7 @@ export default class ModernAiProvider {
     config: ModernAiProviderConfig
   ): Promise {
     // ai-gateway is not an image/generation endpoint, so skip the legacy path for now
-    if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
+    if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds['ai-gateway']) {
       // Use the legacy implementation for image generation (supports image editing and other advanced features)
       if (!config.uiMessages) {
         throw new Error('uiMessages is required for image generation endpoint')
diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
index 0b89f55b16..b314ddd737 100644
--- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
@@ -1,6 +1,6 @@
 import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
 import { loggerService } from '@logger'
-import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
+import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
 import type { MCPTool } from '@renderer/types'
 import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
@@ -9,11 +9,13 @@ import type { LanguageModelMiddleware } from 'ai'
 import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
 import { isEmpty } from 'lodash'
 
+import { getAiSdkProviderId } from '../provider/factory'
 import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
 import { noThinkMiddleware } from './noThinkMiddleware'
 import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
 import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
 import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
+import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
 import { toolChoiceMiddleware } from './toolChoiceMiddleware'
 
 const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@@ -257,6 +259,15 @@ function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: Ai
       middleware: openrouterGenerateImageMiddleware()
     })
   }
+
+  if (isGemini3Model(config.model)) {
+    const aiSdkId = getAiSdkProviderId(config.provider)
+    builder.add({
+      name: 'skip-gemini3-thought-signature',
+      middleware: skipGeminiThoughtSignatureMiddleware(aiSdkId)
+    })
+    logger.debug('Added skip Gemini3 thought signature middleware')
+  }
 }
 
 /**
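For context on how a middleware registered by this builder takes effect, here is a minimal sketch (not Cherry Studio code; the model id and provider are placeholders) of how the AI SDK composes a `LanguageModelMiddleware` around a model via `wrapLanguageModel`:

```ts
import { wrapLanguageModel, type LanguageModelMiddleware } from 'ai'
import { google } from '@ai-sdk/google'

// A pass-through middleware with the same shape as skipGeminiThoughtSignatureMiddleware:
// transformParams runs before each request and may rewrite the prompt.
const exampleMiddleware: LanguageModelMiddleware = {
  middlewareVersion: 'v2',
  transformParams: async ({ params }) => {
    // ...inspect or rewrite params.prompt here...
    return params
  }
}

// The builder ultimately produces a wrapped model along these lines:
const model = wrapLanguageModel({
  model: google('gemini-3-pro-preview'), // placeholder model id
  middleware: [exampleMiddleware]
})
```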
diff --git a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts b/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts
new file mode 100644
index 0000000000..da318ea60d
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts
@@ -0,0 +1,36 @@
+import type { LanguageModelMiddleware } from 'ai'
+
+/**
+ * Skip Gemini Thought Signature Middleware
+ * Due to the complexity of multi-model client requests (the user can switch to another model mid-conversation),
+ * we skip validation of all Gemini 3 thought signatures via this middleware.
+ *
+ * @param aiSdkId AI SDK Provider ID
+ * @returns LanguageModelMiddleware
+ */
+export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
+  const MAGIC_STRING = 'skip_thought_signature_validator'
+  return {
+    middlewareVersion: 'v2',
+
+    transformParams: async ({ params }) => {
+      const transformedParams = { ...params }
+      // Process messages in prompt
+      if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
+        transformedParams.prompt = transformedParams.prompt.map((message) => {
+          if (typeof message.content !== 'string') {
+            for (const part of message.content) {
+              const googleOptions = part?.providerOptions?.[aiSdkId]
+              if (googleOptions?.thoughtSignature) {
+                googleOptions.thoughtSignature = MAGIC_STRING
+              }
+            }
+          }
+          return message
+        })
+      }
+
+      return transformedParams
+    }
+  }
+}
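A quick illustration of what the transform does to an outgoing prompt (a hypothetical snippet; the `'google'` key stands in for whatever `getAiSdkProviderId` resolves to, and the cast papers over the SDK's full call signature):

```ts
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'

async function demo() {
  const middleware = skipGeminiThoughtSignatureMiddleware('google')

  // An assistant turn carrying a thought signature produced by an earlier Gemini 3 call.
  const params: any = {
    prompt: [
      {
        role: 'assistant',
        content: [
          {
            type: 'text',
            text: 'Previous answer',
            providerOptions: { google: { thoughtSignature: 'opaque-signature-from-gemini' } }
          }
        ]
      }
    ]
  }

  const transformed: any = await middleware.transformParams!({ type: 'generate', params } as any)
  // The signature is replaced with the magic skip value, so Google's validator
  // no longer rejects a history that was partly produced by another model.
  console.log(transformed.prompt[0].content[0].providerOptions.google.thoughtSignature)
  // -> 'skip_thought_signature_validator'
}
```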
diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts
index 2e7ae522cc..2433192cd0 100644
--- a/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts
+++ b/src/renderer/src/aiCore/prepareParams/__tests__/message-converter.test.ts
@@ -180,6 +180,10 @@ describe('messageConverter', () => {
       const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
 
       expect(result).toEqual([
+        {
+          role: 'user',
+          content: [{ type: 'text', text: 'Start editing' }]
+        },
         {
           role: 'assistant',
           content: [{ type: 'text', text: 'Here is the current preview' }]
@@ -217,6 +221,7 @@ describe('messageConverter', () => {
       expect(result).toEqual([
         { role: 'system', content: 'fileid://reference' },
+        { role: 'user', content: [{ type: 'text', text: 'Use this document as inspiration' }] },
         {
           role: 'assistant',
           content: [{ type: 'text', text: 'Generated previews ready' }]
diff --git a/src/renderer/src/aiCore/prepareParams/messageConverter.ts b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
index 72f387d9a4..b0c432ef85 100644
--- a/src/renderer/src/aiCore/prepareParams/messageConverter.ts
+++ b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
@@ -194,20 +194,20 @@ async function convertMessageToAssistantModelMessage(
  * This function processes messages and transforms them into the format required by the SDK.
  * It handles special cases for vision models and image enhancement models.
  *
- * @param messages - Array of messages to convert. Must contain at least 2 messages when using image enhancement models.
+ * @param messages - Array of messages to convert. Must contain at least 3 messages when using image enhancement models for special handling.
  * @param model - The model configuration that determines conversion behavior
  *
  * @returns A promise that resolves to an array of SDK-compatible model messages
  *
  * @remarks
- * For image enhancement models with 2+ messages:
- * - Expects the second-to-last message (index length-2) to be an assistant message containing image blocks
- * - Expects the last message (index length-1) to be a user message
- * - Extracts images from the assistant message and appends them to the user message content
- * - Returns only the last two processed messages [assistantSdkMessage, userSdkMessage]
+ * For image enhancement models with 3+ messages:
+ * - Examines the last 2 messages to find an assistant message containing image blocks
+ * - If found, extracts images from the assistant message and appends them to the last user message content
+ * - Returns all converted messages (not just the last two) with the images merged into the user message
+ * - Typical pattern: [system?, assistant(image), user] -> [system?, assistant, user(image)]
  *
  * For other models:
- * - Returns all converted messages in order
+ * - Returns all converted messages in order without special image handling
 *
 * The function automatically detects vision model capabilities and adjusts conversion accordingly.
 */
@@ -220,29 +220,25 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M
     sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
   }
   // Special handling for image enhancement models
-  // Only keep the last two messages and merge images into the user message
-  // [system?, user, assistant, user]
+  // Only merge images into the user message
+  // [system?, assistant(image), user] -> [system?, assistant, user(image)]
   if (isImageEnhancementModel(model) && messages.length >= 3) {
     const needUpdatedMessages = messages.slice(-2)
-    const needUpdatedSdkMessages = sdkMessages.slice(-2)
-    const assistantMessage = needUpdatedMessages.filter((m) => m.role === 'assistant')[0]
-    const assistantSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'assistant')[0]
-    const userSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'user')[0]
-    const systemSdkMessages = sdkMessages.filter((m) => m.role === 'system')
-    const imageBlocks = findImageBlocks(assistantMessage)
-    const imageParts = await convertImageBlockToImagePart(imageBlocks)
-    const parts: Array = []
-    if (typeof userSdkMessage.content === 'string') {
-      parts.push({ type: 'text', text: userSdkMessage.content })
-      parts.push(...imageParts)
-      userSdkMessage.content = parts
-    } else {
-      userSdkMessage.content.push(...imageParts)
+    const assistantMessage = needUpdatedMessages.find((m) => m.role === 'assistant')
+    const userSdkMessage = sdkMessages[sdkMessages.length - 1]
+
+    if (assistantMessage && userSdkMessage?.role === 'user') {
+      const imageBlocks = findImageBlocks(assistantMessage)
+      const imageParts = await convertImageBlockToImagePart(imageBlocks)
+
+      if (imageParts.length > 0) {
+        if (typeof userSdkMessage.content === 'string') {
+          userSdkMessage.content = [{ type: 'text', text: userSdkMessage.content }, ...imageParts]
+        } else if (Array.isArray(userSdkMessage.content)) {
+          userSdkMessage.content.push(...imageParts)
+        }
+      }
     }
-    if (systemSdkMessages.length > 0) {
-      return [systemSdkMessages[0], assistantSdkMessage, userSdkMessage]
-    }
-    return [assistantSdkMessage, userSdkMessage]
   }
 
   return sdkMessages
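To make the new remapping concrete, here is a hypothetical before/after (shapes only, not taken from the test file; the image bytes are a placeholder):

```ts
// Input UI history for an image enhancement model:
const input = [
  { role: 'user', content: 'Start editing' },
  { role: 'assistant', content: 'Here is the current preview' /* + image blocks */ },
  { role: 'user', content: 'Make the sky bluer' }
]

// Output after convertMessagesToSdkMessages: every message is kept, and the
// assistant's images are re-attached to the final user turn, since image-editing
// APIs expect the source image alongside the instruction.
const output = [
  { role: 'user', content: [{ type: 'text', text: 'Start editing' }] },
  { role: 'assistant', content: [{ type: 'text', text: 'Here is the current preview' }] },
  {
    role: 'user',
    content: [
      { type: 'text', text: 'Make the sky bluer' },
      { type: 'image', image: new Uint8Array(/* bytes from the assistant's image block */) }
    ]
  }
]
```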
diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts
index f1fc61aacd..4fb6f07e1f 100644
--- a/src/renderer/src/aiCore/utils/options.ts
+++ b/src/renderer/src/aiCore/utils/options.ts
@@ -91,9 +91,21 @@ function getServiceTier(model: Model, provider: T): OpenAISe
   }
 }
 
-function getVerbosity(): OpenAIVerbosity {
+function getVerbosity(model: Model): OpenAIVerbosity {
+  if (!isSupportVerbosityModel(model) || !isSupportVerbosityProvider(getProviderById(model.provider)!)) {
+    return undefined
+  }
   const openAI = getStoreSetting('openAI')
-  return openAI.verbosity
+
+  const userVerbosity = openAI.verbosity
+
+  if (userVerbosity) {
+    const supportedVerbosity = getModelSupportedVerbosity(model)
+    // Use user's verbosity if supported, otherwise use the first supported option
+    const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
+    return verbosity
+  }
+  return undefined
 }
 
 /**
@@ -148,7 +160,7 @@ export function buildProviderOptions(
   // Build provider-specific options
   let providerSpecificOptions: Record = {}
   const serviceTier = getServiceTier(model, actualProvider)
-  const textVerbosity = getVerbosity()
+  const textVerbosity = getVerbosity(model)
   // Split the build logic by provider type
   const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
   if (success) {
@@ -163,7 +175,8 @@ export function buildProviderOptions(
         assistant,
         model,
         capabilities,
-        serviceTier
+        serviceTier,
+        textVerbosity
       )
       providerSpecificOptions = options
     }
@@ -196,7 +209,8 @@ export function buildProviderOptions(
         model,
         capabilities,
         actualProvider,
-        serviceTier
+        serviceTier,
+        textVerbosity
       )
       break
     default:
@@ -255,7 +269,7 @@ export function buildProviderOptions(
     }[rawProviderId] || rawProviderId
 
   if (rawProviderKey === 'cherryin') {
-    rawProviderKey = { gemini: 'google' }[actualProvider.type] || actualProvider.type
+    rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
  }
 
   // Return the format the AI Core SDK expects: { 'providerId': providerOptions }, plus the extracted standard parameters
@@ -278,7 +292,8 @@ function buildOpenAIProviderOptions(
     enableWebSearch: boolean
     enableGenerateImage: boolean
   },
-  serviceTier: OpenAIServiceTier
+  serviceTier: OpenAIServiceTier,
+  textVerbosity?: OpenAIVerbosity
 ): OpenAIResponsesProviderOptions {
   const { enableReasoning } = capabilities
   let providerOptions: OpenAIResponsesProviderOptions = {}
@@ -314,7 +329,8 @@ function buildOpenAIProviderOptions(
 
   providerOptions = {
     ...providerOptions,
-    serviceTier
+    serviceTier,
+    textVerbosity
   }
 
   return providerOptions
@@ -413,11 +429,13 @@ function buildCherryInProviderOptions(
     enableGenerateImage: boolean
   },
   actualProvider: Provider,
-  serviceTier: OpenAIServiceTier
+  serviceTier: OpenAIServiceTier,
+  textVerbosity: OpenAIVerbosity
 ): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
   switch (actualProvider.type) {
     case 'openai':
-      return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
+    case 'openai-response':
+      return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
 
     case 'anthropic':
       return buildAnthropicProviderOptions(assistant, model, capabilities)
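The verbosity fallback above reads more clearly as a standalone sketch (a re-expression of the diff's logic with hypothetical names; `getModelSupportedVerbosity` and the settings store are replaced by plain parameters):

```ts
type Verbosity = 'low' | 'medium' | 'high' | undefined

// Pick the user's preferred verbosity when the model supports it,
// otherwise fall back to the first value the model does support.
function resolveVerbosity(userVerbosity: Verbosity, supported: Array<'low' | 'medium' | 'high'>): Verbosity {
  if (!userVerbosity) return undefined
  return supported.includes(userVerbosity) ? userVerbosity : supported[0]
}

resolveVerbosity('high', ['low', 'medium'])   // -> 'low' (user's choice unsupported, fall back)
resolveVerbosity('medium', ['low', 'medium']) // -> 'medium'
resolveVerbosity(undefined, ['low'])          // -> undefined (omit the parameter entirely)
```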
diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts
index 8e4112c8d0..6223f8ecc8 100644
--- a/src/renderer/src/aiCore/utils/reasoning.ts
+++ b/src/renderer/src/aiCore/utils/reasoning.ts
@@ -12,7 +12,7 @@ import {
   isDeepSeekHybridInferenceModel,
   isDoubaoSeedAfter251015,
   isDoubaoThinkingAutoModel,
-  isGemini3Model,
+  isGemini3ThinkingTokenModel,
   isGPT51SeriesModel,
   isGrok4FastReasoningModel,
   isGrokReasoningModel,
@@ -36,7 +36,7 @@ import {
 } from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
 import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
-import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
+import type { Assistant, Model } from '@renderer/types'
 import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
 import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
 import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
@@ -281,7 +281,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
   // gemini series, openai compatible api
   if (isSupportedThinkingTokenGeminiModel(model)) {
     // https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
-    if (isGemini3Model(model)) {
+    if (isGemini3ThinkingTokenModel(model)) {
       return {
         reasoning_effort: reasoningEffort
       }
@@ -465,20 +465,20 @@ export function getAnthropicReasoningParams(
   return {}
 }
 
-type GoogelThinkingLevel = NonNullable['thinkingLevel']
+// type GoogleThinkingLevel = NonNullable['thinkingLevel']
 
-function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel {
-  switch (reasoningEffort) {
-    case 'low':
-      return 'low'
-    case 'medium':
-      return 'medium'
-    case 'high':
-      return 'high'
-    default:
-      return 'medium'
-  }
-}
+// function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel {
+//   switch (reasoningEffort) {
+//     case 'low':
+//       return 'low'
+//     case 'medium':
+//       return 'medium'
+//     case 'high':
+//       return 'high'
+//     default:
+//       return 'medium'
+//   }
+// }
 
 /**
  * Get Gemini reasoning parameters
@@ -507,14 +507,15 @@ export function getGeminiReasoningParams(
     }
   }
 
+  // TODO: many relay/proxy providers do not support this yet
   // https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
-  if (isGemini3Model(model)) {
-    return {
-      thinkingConfig: {
-        thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
-      }
-    }
-  }
+  // if (isGemini3ThinkingTokenModel(model)) {
+  //   return {
+  //     thinkingConfig: {
+  //       thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
+  //     }
+  //   }
+  // }
 
   const effortRatio = EFFORT_RATIO[reasoningEffort]
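The net effect of these branches, sketched as the request parameters that result (the ratio and budget numbers are illustrative placeholders, not the exact values in EFFORT_RATIO or findTokenLimit):

```ts
// Gemini 3 text models via the OpenAI-compatible endpoint: the effort passes straight through.
const openAiCompatParams = { reasoning_effort: 'high' }

// Gemini 2.x models (and, until relays support thinkingLevel, Gemini 3 as well):
// the effort is converted into a thinking-token budget instead.
const EFFORT_RATIO = { low: 0.05, medium: 0.5, high: 0.8 } as const // illustrative values
const maxBudget = 24576 // illustrative per-model token limit
const geminiParams = {
  thinkingConfig: {
    thinkingBudget: Math.floor(maxBudget * EFFORT_RATIO.medium),
    includeThoughts: true
  }
}
```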
diff --git a/src/renderer/src/config/models/__tests__/reasoning.test.ts b/src/renderer/src/config/models/__tests__/reasoning.test.ts
index 0f2b6dfa77..cfbf50e8c1 100644
--- a/src/renderer/src/config/models/__tests__/reasoning.test.ts
+++ b/src/renderer/src/config/models/__tests__/reasoning.test.ts
@@ -33,6 +33,7 @@ import {
   MODEL_SUPPORTED_OPTIONS,
   MODEL_SUPPORTED_REASONING_EFFORT
 } from '../reasoning'
+import { isGemini3ThinkingTokenModel } from '../utils'
 import { isTextToImageModel } from '../vision'
 
 vi.mock('@renderer/store', () => ({
@@ -955,7 +956,7 @@ describe('Gemini Models', () => {
           provider: '',
           group: ''
         })
-      ).toBe(true)
+      ).toBe(false)
       expect(
         isSupportedThinkingTokenGeminiModel({
           id: 'gemini-3.0-flash-image-preview',
           name: '',
           provider: '',
           group: ''
         })
-      ).toBe(true)
+      ).toBe(false)
       expect(
         isSupportedThinkingTokenGeminiModel({
           id: 'gemini-3.5-pro-image-preview',
           name: '',
           provider: '',
           group: ''
         })
-      ).toBe(true)
+      ).toBe(false)
     })
 
     it('should return false for gemini-2.x image models', () => {
@@ -1163,7 +1164,7 @@ describe('Gemini Models', () => {
           provider: '',
           group: ''
         })
-      ).toBe(true)
+      ).toBe(false)
       expect(
         isGeminiReasoningModel({
           id: 'gemini-3.5-flash-image-preview',
           name: '',
           provider: '',
           group: ''
         })
-      ).toBe(true)
+      ).toBe(false)
     })
 
     it('should return false for older gemini models without thinking', () => {
@@ -1230,3 +1231,153 @@ describe('findTokenLimit', () => {
     expect(findTokenLimit('unknown-model')).toBeUndefined()
   })
 })
+
+describe('isGemini3ThinkingTokenModel', () => {
+  it('should return true for Gemini 3 non-image models', () => {
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3-flash',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3-pro',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3-pro-preview',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'google/gemini-3-flash',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3.0-flash',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3.5-pro-preview',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+  })
+
+  it('should return false for Gemini 3 image models', () => {
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3-flash-image',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3-pro-image-preview',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3.0-flash-image-preview',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-3.5-pro-image-preview',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+  })
+
+  it('should return false for non-Gemini 3 models', () => {
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-2.5-flash',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gemini-1.5-pro',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'gpt-4',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'claude-3-opus',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+  })
+
+  it('should handle case insensitivity', () => {
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'Gemini-3-Flash',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'GEMINI-3-PRO',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(true)
+    expect(
+      isGemini3ThinkingTokenModel({
+        id: 'Gemini-3-Pro-Image',
+        name: '',
+        provider: '',
+        group: ''
+      })
+    ).toBe(false)
+  })
+})
diff --git a/src/renderer/src/config/models/reasoning.ts b/src/renderer/src/config/models/reasoning.ts
index 42b862ed5a..ef54794dcb 100644
--- a/src/renderer/src/config/models/reasoning.ts
+++ b/src/renderer/src/config/models/reasoning.ts
@@ -16,7 +16,7 @@ import {
   isOpenAIReasoningModel,
   isSupportedReasoningEffortOpenAIModel
 } from './openai'
-import { GEMINI_FLASH_MODEL_REGEX, isGemini3Model } from './utils'
+import { GEMINI_FLASH_MODEL_REGEX, isGemini3ThinkingTokenModel } from './utils'
 import { isTextToImageModel } from './vision'
 
 // Reasoning models
@@ -115,7 +115,7 @@ const _getThinkModelType = (model: Model): ThinkingModelType => {
     } else {
       thinkingModelType = 'gemini_pro'
     }
-    if (isGemini3Model(model)) {
+    if (isGemini3ThinkingTokenModel(model)) {
       thinkingModelType = 'gemini3'
    }
   } else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok'
@@ -271,14 +271,6 @@ export const GEMINI_THINKING_MODEL_REGEX =
 export const isSupportedThinkingTokenGeminiModel = (model: Model): boolean => {
   const modelId = getLowerBaseModelName(model.id, '/')
   if (GEMINI_THINKING_MODEL_REGEX.test(modelId)) {
-    // gemini-3.x image models support thinking mode
-    if (isGemini3Model(model)) {
-      if (modelId.includes('tts')) {
-        return false
-      }
-      return true
-    }
-    // gemini-2.x image/tts models do not support it
     if (modelId.includes('image') || modelId.includes('tts')) {
       return false
     }
diff --git a/src/renderer/src/config/models/tooluse.ts b/src/renderer/src/config/models/tooluse.ts
index 7b3b09d2c1..7f90df5f7b 100644
--- a/src/renderer/src/config/models/tooluse.ts
+++ b/src/renderer/src/config/models/tooluse.ts
@@ -43,7 +43,8 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [
   'gpt-5-chat(?:-[\\w-]+)?',
   'glm-4\\.5v',
   'gemini-2.5-flash-image(?:-[\\w-]+)?',
-  'gemini-2.0-flash-preview-image-generation'
+  'gemini-2.0-flash-preview-image-generation',
+  'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?'
 ]
 
 export const FUNCTION_CALLING_REGEX = new RegExp(
diff --git a/src/renderer/src/config/models/utils.ts b/src/renderer/src/config/models/utils.ts
index 61b994a916..1d5c9a6443 100644
--- a/src/renderer/src/config/models/utils.ts
+++ b/src/renderer/src/config/models/utils.ts
@@ -164,3 +164,8 @@ export const isGemini3Model = (model: Model) => {
   const modelId = getLowerBaseModelName(model.id)
   return modelId.includes('gemini-3')
 }
+
+export const isGemini3ThinkingTokenModel = (model: Model) => {
+  const modelId = getLowerBaseModelName(model.id)
+  return isGemini3Model(model) && !modelId.includes('image')
+}
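A quick usage sketch of the new helper, mirroring the test expectations above (model objects trimmed to the fields the check reads):

```ts
import { isGemini3Model, isGemini3ThinkingTokenModel } from '@renderer/config/models'

const base = { name: '', provider: '', group: '' }

isGemini3ThinkingTokenModel({ ...base, id: 'gemini-3-pro-preview' })       // true: Gemini 3, not an image model
isGemini3ThinkingTokenModel({ ...base, id: 'gemini-3-pro-image-preview' }) // false: image variants are excluded
isGemini3ThinkingTokenModel({ ...base, id: 'gemini-2.5-flash' })           // false: not Gemini 3 at all
isGemini3Model({ ...base, id: 'gemini-3-pro-image-preview' })              // true: still matched by the broader check
```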
diff --git a/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx b/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx
index 21300d8fd9..5728887af8 100644
--- a/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx
+++ b/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx
@@ -4,6 +4,7 @@ import { BingLogo, BochaLogo, ExaLogo, SearXNGLogo, TavilyLogo, ZhipuLogo } from
 import type { QuickPanelListItem } from '@renderer/components/QuickPanel'
 import { QuickPanelReservedSymbol } from '@renderer/components/QuickPanel'
 import {
+  isFunctionCallingModel,
   isGeminiModel,
   isGPT5SeriesReasoningModel,
   isOpenAIWebSearchModel,
@@ -18,6 +19,7 @@ import WebSearchService from '@renderer/services/WebSearchService'
 import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
 import { hasObjectKey } from '@renderer/utils'
 import { isToolUseModeFunction } from '@renderer/utils/assistant'
+import { isPromptToolUse } from '@renderer/utils/mcp-tools'
 import { isGeminiWebSearchProvider } from '@renderer/utils/provider'
 import { Globe } from 'lucide-react'
 import { useCallback, useEffect, useMemo } from 'react'
@@ -126,20 +128,25 @@ export const useWebSearchPanelController = (assistantId: string, quickPanelContr
   const providerItems = useMemo(() => {
     const isWebSearchModelEnabled = assistant.model && isWebSearchModel(assistant.model)
-    const items: QuickPanelListItem[] = providers
-      .map((p) => ({
-        label: p.name,
-        description: WebSearchService.isWebSearchEnabled(p.id)
-          ? hasObjectKey(p, 'apiKey')
-            ? t('settings.tool.websearch.apikey')
-            : t('settings.tool.websearch.free')
-          : t('chat.input.web_search.enable_content'),
-        icon: ,
-        isSelected: p.id === assistant?.webSearchProviderId,
-        disabled: !WebSearchService.isWebSearchEnabled(p.id),
-        action: () => updateQuickPanelItem(p.id)
-      }))
-      .filter((item) => !item.disabled)
+    const items: QuickPanelListItem[] = []
+    if (isFunctionCallingModel(assistant.model) || isPromptToolUse(assistant)) {
+      items.push(
+        ...providers
+          .map((p) => ({
+            label: p.name,
+            description: WebSearchService.isWebSearchEnabled(p.id)
+              ? hasObjectKey(p, 'apiKey')
+                ? t('settings.tool.websearch.apikey')
+                : t('settings.tool.websearch.free')
+              : t('chat.input.web_search.enable_content'),
+            icon: ,
+            isSelected: p.id === assistant?.webSearchProviderId,
+            disabled: !WebSearchService.isWebSearchEnabled(p.id),
+            action: () => updateQuickPanelItem(p.id)
+          }))
+          .filter((item) => !item.disabled)
+      )
+    }
 
     if (isWebSearchModelEnabled) {
       items.unshift({
@@ -155,15 +162,7 @@ export const useWebSearchPanelController = (assistantId: string, quickPanelContr
     }
 
     return items
-  }, [
-    assistant.enableWebSearch,
-    assistant.model,
-    assistant?.webSearchProviderId,
-    providers,
-    t,
-    updateQuickPanelItem,
-    updateToModelBuiltinWebSearch
-  ])
+  }, [assistant, providers, t, updateQuickPanelItem, updateToModelBuiltinWebSearch])
 
   const openQuickPanel = useCallback(() => {
     quickPanelController.open({
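The gating introduced here, reduced to a hypothetical standalone predicate (names follow the imports above; the extraction into a helper is illustrative, not part of the patch). Provider-based web search is presumably driven through tool calls, so the quick panel only lists providers when the current model can call tools natively or the assistant falls back to prompt-based tool use:

```ts
import { isFunctionCallingModel } from '@renderer/config/models'
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
import type { Assistant } from '@renderer/types'

// Show web-search provider entries only when tool calls can actually be issued:
// either the model supports native function calling, or the assistant is
// configured for prompt-based tool use.
function canShowWebSearchProviders(assistant: Assistant): boolean {
  return isFunctionCallingModel(assistant.model) || isPromptToolUse(assistant)
}
```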