From 2f58b3360edec155f3720a14d9fc73c0bb8d2069 Mon Sep 17 00:00:00 2001 From: suyao Date: Sat, 21 Jun 2025 16:48:16 +0800 Subject: [PATCH] feat: enhance provider options and examples for AI SDK - Introduced new utility functions for creating and merging provider options, improving type safety and usability. - Added comprehensive examples for OpenAI, Anthropic, Google, and generic provider options to demonstrate usage. - Refactored existing code to streamline provider configuration and enhance clarity in the options management. - Updated the PluginEnabledAiClient to simplify the handling of model parameters and improve overall functionality. --- .../src/clients/PluginEnabledAiClient.ts | 2 +- packages/aiCore/src/index.ts | 15 +- packages/aiCore/src/options/examples.ts | 87 ++++ packages/aiCore/src/options/factory.ts | 57 +++ packages/aiCore/src/options/index.ts | 2 + packages/aiCore/src/options/types.ts | 28 ++ .../src/aiCore/transformParameters.ts | 52 +- src/renderer/src/aiCore/utils/reasoning.ts | 466 ++++++++++++++++++ src/renderer/src/store/thunk/messageThunk.ts | 72 +-- src/renderer/src/utils/error.ts | 8 +- 10 files changed, 709 insertions(+), 80 deletions(-) create mode 100644 packages/aiCore/src/options/examples.ts create mode 100644 packages/aiCore/src/options/factory.ts create mode 100644 packages/aiCore/src/options/index.ts create mode 100644 packages/aiCore/src/options/types.ts create mode 100644 src/renderer/src/aiCore/utils/reasoning.ts diff --git a/packages/aiCore/src/clients/PluginEnabledAiClient.ts b/packages/aiCore/src/clients/PluginEnabledAiClient.ts index 4e34d188cc..fdca6116d9 100644 --- a/packages/aiCore/src/clients/PluginEnabledAiClient.ts +++ b/packages/aiCore/src/clients/PluginEnabledAiClient.ts @@ -225,7 +225,7 @@ export class PluginEnabledAiClient { ) } else { // 外部 registry 方式:直接使用用户提供的 model - return await streamText(modelIdOrParams) + return streamText(modelIdOrParams) } } diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index c512e27bf3..6dcaa596dd 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -112,15 +112,26 @@ export type { ZhipuProviderSettings } from './clients/types' +// ==================== 选项 ==================== +export { + createAnthropicOptions, + createGoogleOptions, + createOpenAIOptions, + type ExtractProviderOptions, + mergeProviderOptions, + type ProviderOptionsMap, + type TypedProviderOptions +} from './options' + // ==================== 工具函数 ==================== export { createClient as createApiClient, getClientInfo, getSupportedProviders } from './clients/ApiClientFactory' export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './providers/registry' // ==================== Provider 配置工厂 ==================== export { - BaseProviderConfig, + type BaseProviderConfig, createProviderConfig, - ProviderConfigBuilder, + type ProviderConfigBuilder, providerConfigBuilder, ProviderConfigFactory } from './providers/factory' diff --git a/packages/aiCore/src/options/examples.ts b/packages/aiCore/src/options/examples.ts new file mode 100644 index 0000000000..9078437d9c --- /dev/null +++ b/packages/aiCore/src/options/examples.ts @@ -0,0 +1,87 @@ +import { streamText } from 'ai' + +import { + createAnthropicOptions, + createGenericProviderOptions, + createGoogleOptions, + createOpenAIOptions, + mergeProviderOptions +} from './factory' + +// 示例1: 使用已知供应商的严格类型约束 +export function exampleOpenAIWithOptions() { + const openaiOptions = createOpenAIOptions({ + reasoningEffort: 'medium' + }) + + // 这里会有类型检查,确保选项符合OpenAI的设置 + return streamText({ + model: {} as any, // 实际使用时替换为真实模型 + prompt: 'Hello', + providerOptions: openaiOptions + }) +} + +// 示例2: 使用Anthropic供应商选项 +export function exampleAnthropicWithOptions() { + const anthropicOptions = createAnthropicOptions({ + thinking: { + type: 'enabled', + budgetTokens: 1000 + } + }) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: anthropicOptions + }) +} + +// 示例3: 使用Google供应商选项 +export function exampleGoogleWithOptions() { + const googleOptions = createGoogleOptions({ + thinkingConfig: { + includeThoughts: true, + thinkingBudget: 1000 + } + }) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: googleOptions + }) +} + +// 示例4: 使用未知供应商(通用类型) +export function exampleUnknownProviderWithOptions() { + const customProviderOptions = createGenericProviderOptions('custom-provider', { + temperature: 0.7, + customSetting: 'value', + anotherOption: true + }) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: customProviderOptions + }) +} + +// 示例5: 合并多个供应商选项 +export function exampleMergedOptions() { + const openaiOptions = createOpenAIOptions({}) + + const customOptions = createGenericProviderOptions('custom', { + customParam: 'value' + }) + + const mergedOptions = mergeProviderOptions(openaiOptions, customOptions) + + return streamText({ + model: {} as any, + prompt: 'Hello', + providerOptions: mergedOptions + }) +} diff --git a/packages/aiCore/src/options/factory.ts b/packages/aiCore/src/options/factory.ts new file mode 100644 index 0000000000..46107e65f9 --- /dev/null +++ b/packages/aiCore/src/options/factory.ts @@ -0,0 +1,57 @@ +import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types' + +/** + * 创建特定供应商的选项 + * @param provider 供应商名称 + * @param options 供应商特定的选项 + * @returns 格式化的provider options + */ +export function createProviderOptions( + provider: T, + options: ExtractProviderOptions +): Record> { + return { [provider]: options } as Record> +} + +/** + * 创建任意供应商的选项(包括未知供应商) + * @param provider 供应商名称 + * @param options 供应商选项 + * @returns 格式化的provider options + */ +export function createGenericProviderOptions( + provider: T, + options: Record +): Record> { + return { [provider]: options } as Record> +} + +/** + * 合并多个供应商的options + * @param optionsMap 包含多个供应商选项的对象 + * @returns 合并后的TypedProviderOptions + */ +export function mergeProviderOptions(...optionsMap: Partial[]): TypedProviderOptions { + return Object.assign({}, ...optionsMap) +} + +/** + * 创建OpenAI供应商选项的便捷函数 + */ +export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) { + return createProviderOptions('openai', options) +} + +/** + * 创建Anthropic供应商选项的便捷函数 + */ +export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) { + return createProviderOptions('anthropic', options) +} + +/** + * 创建Google供应商选项的便捷函数 + */ +export function createGoogleOptions(options: ExtractProviderOptions<'google'>) { + return createProviderOptions('google', options) +} diff --git a/packages/aiCore/src/options/index.ts b/packages/aiCore/src/options/index.ts new file mode 100644 index 0000000000..97a7b59914 --- /dev/null +++ b/packages/aiCore/src/options/index.ts @@ -0,0 +1,2 @@ +export * from './factory' +export * from './types' diff --git a/packages/aiCore/src/options/types.ts b/packages/aiCore/src/options/types.ts new file mode 100644 index 0000000000..b9b31af980 --- /dev/null +++ b/packages/aiCore/src/options/types.ts @@ -0,0 +1,28 @@ +import { AnthropicProviderOptions } from '@ai-sdk/anthropic' +import { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' +import { OpenAIResponsesProviderOptions } from '@ai-sdk/openai' +import { LanguageModelV1ProviderMetadata } from '@ai-sdk/provider' + +export type ProviderOptions = LanguageModelV1ProviderMetadata[T] + +/** + * 供应商选项类型,如果map中没有,说明没有约束 + */ +export type ProviderOptionsMap = { + openai: OpenAIResponsesProviderOptions + anthropic: AnthropicProviderOptions + google: GoogleGenerativeAIProviderOptions +} + +// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型 +export type ExtractProviderOptions = ProviderOptionsMap[T] + +/** + * 类型安全的ProviderOptions + * 对于已知供应商使用严格类型,对于未知供应商允许任意Record + */ +export type TypedProviderOptions = { + [K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K] +} & { + [K in string]?: Record +} & LanguageModelV1ProviderMetadata diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index dd9aed2111..cc80638980 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -22,6 +22,8 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u import { buildSystemPrompt } from '@renderer/utils/prompt' import { defaultTimeout } from '@shared/config/constant' +import { buildProviderOptions } from './utils/reasoning' + /** * 获取温度参数 */ @@ -233,6 +235,13 @@ export async function buildStreamTextParams( systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant) } + // 构建真正的 providerOptions + const providerOptions = buildProviderOptions(assistant, model, { + enableReasoning, + enableWebSearch, + enableGenerateImage + }) + // 构建基础参数 const params: StreamTextParams = { messages: sdkMessages, @@ -242,19 +251,7 @@ export async function buildStreamTextParams( system: systemPrompt || undefined, abortSignal: options.requestOptions?.signal, headers: options.requestOptions?.headers, - // 随便填着,后面再改 - providerOptions: { - reasoning: { - enabled: enableReasoning - }, - webSearch: { - enabled: enableWebSearch - }, - generateImage: { - enabled: enableGenerateImage - } - }, - ...getCustomParameters(assistant) + providerOptions } // 添加工具(如果启用且有工具) @@ -280,32 +277,3 @@ export async function buildGenerateTextParams( // 复用流式参数的构建逻辑 return await buildStreamTextParams(messages, assistant, options) } - -/** - * 获取自定义参数 - * 从 assistant 设置中提取自定义参数 - */ -export function getCustomParameters(assistant: Assistant): Record { - return ( - assistant?.settings?.customParameters?.reduce((acc, param) => { - if (!param.name?.trim()) { - return acc - } - if (param.type === 'json') { - const value = param.value as string - if (value === 'undefined') { - return { ...acc, [param.name]: undefined } - } - try { - return { ...acc, [param.name]: JSON.parse(value) } - } catch { - return { ...acc, [param.name]: value } - } - } - return { - ...acc, - [param.name]: param.value - } - }, {}) || {} - ) -} diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts new file mode 100644 index 0000000000..789688a027 --- /dev/null +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -0,0 +1,466 @@ +import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import { + findTokenLimit, + GEMINI_FLASH_MODEL_REGEX, + isDoubaoThinkingAutoModel, + isReasoningModel, + isSupportedReasoningEffortGrokModel, + isSupportedReasoningEffortModel, + isSupportedReasoningEffortOpenAIModel, + isSupportedThinkingTokenClaudeModel, + isSupportedThinkingTokenDoubaoModel, + isSupportedThinkingTokenGeminiModel, + isSupportedThinkingTokenModel, + isSupportedThinkingTokenQwenModel +} from '@renderer/config/models' +import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' +import { Assistant, EFFORT_RATIO, Model } from '@renderer/types' +import { ReasoningEffortOptionalParams } from '@renderer/types/sdk' + +import { getAiSdkProviderId } from '../provider/factory' + +export function getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams { + const provider = getProviderByModel(model) + if (provider.id === 'groq') { + return {} + } + + if (!isReasoningModel(model)) { + return {} + } + const reasoningEffort = assistant?.settings?.reasoning_effort + + // Doubao 思考模式支持 + if (isSupportedThinkingTokenDoubaoModel(model)) { + // reasoningEffort 为空,默认开启 enabled + if (!reasoningEffort) { + return { thinking: { type: 'disabled' } } + } + if (reasoningEffort === 'high') { + return { thinking: { type: 'enabled' } } + } + if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) { + return { thinking: { type: 'auto' } } + } + // 其他情况不带 thinking 字段 + return {} + } + + if (!reasoningEffort) { + if (isSupportedThinkingTokenQwenModel(model)) { + return { enable_thinking: false } + } + + if (isSupportedThinkingTokenClaudeModel(model)) { + return {} + } + + if (isSupportedThinkingTokenGeminiModel(model)) { + // openrouter没有提供一个不推理的选项,先隐藏 + if (provider.id === 'openrouter') { + return { reasoning: { max_tokens: 0, exclude: true } } + } + if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { + return { reasoning_effort: 'none' } + } + return {} + } + + if (isSupportedThinkingTokenDoubaoModel(model)) { + return { thinking: { type: 'disabled' } } + } + + return {} + } + const effortRatio = EFFORT_RATIO[reasoningEffort] + const budgetTokens = Math.floor( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min! + ) + + // OpenRouter models + if (model.provider === 'openrouter') { + if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) { + return { + reasoning: { + effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort + } + } + } + } + + // Qwen models + if (isSupportedThinkingTokenQwenModel(model)) { + return { + enable_thinking: true, + thinking_budget: budgetTokens + } + } + + // Grok models + if (isSupportedReasoningEffortGrokModel(model)) { + return { + reasoning_effort: reasoningEffort + } + } + + // OpenAI models + if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) { + return { + reasoning_effort: reasoningEffort + } + } + + // Claude models + if (isSupportedThinkingTokenClaudeModel(model)) { + const maxTokens = assistant.settings?.maxTokens + return { + thinking: { + type: 'enabled', + budget_tokens: Math.floor( + Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)) + ) + } + } + } + + // Doubao models + if (isSupportedThinkingTokenDoubaoModel(model)) { + if (assistant.settings?.reasoning_effort === 'high') { + return { + thinking: { + type: 'enabled' + } + } + } + } + + // Default case: no special thinking settings + return {} +} + +/** + * 构建 AI SDK 的 providerOptions + * 按 provider 类型分离,保持类型安全 + * 返回格式:{ 'providerId': providerOptions } + */ +export function buildProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const provider = getProviderByModel(model) + const providerId = getAiSdkProviderId(provider) + + // 构建 provider 特定的选项 + let providerSpecificOptions: Record = {} + + // 根据 provider 类型分离构建逻辑 + switch (provider.type) { + case 'openai': + case 'azure-openai': + providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities) + break + + case 'anthropic': + providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) + break + + case 'gemini': + case 'vertexai': + providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) + break + + default: + // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities) + break + } + + // 合并自定义参数到 provider 特定的选项中 + const customParameters = getCustomParameters(assistant) + Object.assign(providerSpecificOptions, customParameters) + + // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } + return { + [providerId]: providerSpecificOptions + } +} + +/** + * 构建 OpenAI 特定的 providerOptions + */ +function buildOpenAIProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities + const providerOptions: Record = {} + + // OpenAI 推理参数 + if (enableReasoning) { + const reasoningParams = getOpenAIReasoningParams(assistant, model) + Object.assign(providerOptions, reasoningParams) + } + + // Web 搜索和图像生成暂时使用通用格式 + if (enableWebSearch) { + providerOptions.webSearch = { enabled: true } + } + + if (enableGenerateImage) { + providerOptions.generateImage = { enabled: true } + } + + return providerOptions +} + +/** + * 构建 Anthropic 特定的 providerOptions + */ +function buildAnthropicProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities + const providerOptions: Record = {} + + // Anthropic 推理参数 + if (enableReasoning) { + const reasoningParams = getAnthropicReasoningParams(assistant, model) + Object.assign(providerOptions, reasoningParams) + } + + if (enableWebSearch) { + providerOptions.webSearch = { enabled: true } + } + + if (enableGenerateImage) { + providerOptions.generateImage = { enabled: true } + } + + return providerOptions +} + +/** + * 构建 Gemini 特定的 providerOptions + */ +function buildGeminiProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities + const providerOptions: Record = {} + + // Gemini 推理参数 + if (enableReasoning) { + const reasoningParams = getGeminiReasoningParams(assistant, model) + Object.assign(providerOptions, reasoningParams) + } + + if (enableWebSearch) { + providerOptions.webSearch = { enabled: true } + } + + if (enableGenerateImage) { + providerOptions.generateImage = { enabled: true } + } + + return providerOptions +} + +/** + * 构建通用的 providerOptions(用于其他 provider) + */ +function buildGenericProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities + const providerOptions: Record = {} + + // 使用原有的通用推理逻辑 + if (enableReasoning) { + const reasoningParams = getReasoningEffort(assistant, model) + Object.assign(providerOptions, reasoningParams) + } + + if (enableWebSearch) { + providerOptions.webSearch = { enabled: true } + } + + if (enableGenerateImage) { + providerOptions.generateImage = { enabled: true } + } + + return providerOptions +} + +/** + * 获取 OpenAI 推理参数 + * 从 OpenAIResponseAPIClient 和 OpenAIAPIClient 中提取的逻辑 + */ +function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record { + if (!isReasoningModel(model)) { + return {} + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (!reasoningEffort) { + return {} + } + + // OpenAI 推理参数 + if (isSupportedReasoningEffortOpenAIModel(model)) { + return { + reasoning_effort: reasoningEffort + } + } + + return {} +} + +/** + * 获取 Anthropic 推理参数 + * 从 AnthropicAPIClient 中提取的逻辑 + */ +function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record { + if (!isReasoningModel(model)) { + return {} + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + + if (reasoningEffort === undefined) { + return { + thinking: { + type: 'disabled' + } + } + } + + // Claude 推理参数 + if (isSupportedThinkingTokenClaudeModel(model)) { + const { maxTokens } = getAssistantSettings(assistant) + const effortRatio = EFFORT_RATIO[reasoningEffort] + + const budgetTokens = Math.max( + 1024, + Math.floor( + Math.min( + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + + findTokenLimit(model.id)?.min!, + (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio + ) + ) + ) + + return { + thinking: { + type: 'enabled', + budget_tokens: budgetTokens + } + } + } + + return {} +} + +/** + * 获取 Gemini 推理参数 + * 从 GeminiAPIClient 中提取的逻辑 + */ +function getGeminiReasoningParams(assistant: Assistant, model: Model): Record { + if (!isReasoningModel(model)) { + return {} + } + + const reasoningEffort = assistant?.settings?.reasoning_effort + + // Gemini 推理参数 + if (isSupportedThinkingTokenGeminiModel(model)) { + if (reasoningEffort === undefined) { + return { + thinkingConfig: { + includeThoughts: false, + ...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {}) + } + } + } + + const effortRatio = EFFORT_RATIO[reasoningEffort] + + if (effortRatio > 1) { + return { + thinkingConfig: { + includeThoughts: true + } + } + } + + const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 } + const budget = Math.floor((max - min) * effortRatio + min) + + return { + thinkingConfig: { + ...(budget > 0 ? { thinkingBudget: budget } : {}), + includeThoughts: true + } + } + } + + return {} +} + +/** + * 获取自定义参数 + * 从 assistant 设置中提取自定义参数 + */ +function getCustomParameters(assistant: Assistant): Record { + return ( + assistant?.settings?.customParameters?.reduce((acc, param) => { + if (!param.name?.trim()) { + return acc + } + if (param.type === 'json') { + const value = param.value as string + if (value === 'undefined') { + return { ...acc, [param.name]: undefined } + } + try { + return { ...acc, [param.name]: JSON.parse(value) } + } catch { + return { ...acc, [param.name]: value } + } + } + return { + ...acc, + [param.name]: param.value + } + }, {}) || {} + ) +} diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index fe6e1bf628..4443cb1633 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -34,7 +34,7 @@ import { resetAssistantMessage } from '@renderer/utils/messageUtils/create' import { getMainTextContent } from '@renderer/utils/messageUtils/find' -import { getTopicQueue } from '@renderer/utils/queue' +import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue' import { isOnHomePage } from '@renderer/utils/window' import { t } from 'i18next' import { isEmpty, throttle } from 'lodash' @@ -44,10 +44,10 @@ import type { AppDispatch, RootState } from '../index' import { removeManyBlocks, updateOneBlock, upsertManyBlocks, upsertOneBlock } from '../messageBlock' import { newMessagesActions, selectMessagesForTopic } from '../newMessage' -// const handleChangeLoadingOfTopic = async (topicId: string) => { -// await waitForTopicQueue(topicId) -// store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) -// } +const handleChangeLoadingOfTopic = async (topicId: string) => { + await waitForTopicQueue(topicId) + store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) +} // TODO: 后续可以将db操作移到Listener Middleware中 export const saveMessageAndBlocksToDB = async (message: Message, blocks: MessageBlock[], messageIndex: number = -1) => { try { @@ -704,12 +704,12 @@ const fetchAndProcessAssistantResponseImpl = async ( } const serializableError = { - name: error.name, - message: pauseErrorLanguagePlaceholder || error.message || formatErrorMessage(error), - originalMessage: error.message, - stack: error.stack, - status: error.status || error.code, - requestId: error.request_id + name: String(error.name || ''), + message: pauseErrorLanguagePlaceholder || String(error.message || '') || formatErrorMessage(error), + originalMessage: String(error.message || ''), + stack: String(error.stack || ''), + status: error.status || error.code || undefined, + requestId: error.request_id || undefined } if (!isOnHomePage()) { await notificationService.send({ @@ -723,7 +723,14 @@ const fetchAndProcessAssistantResponseImpl = async ( }) } const possibleBlockId = - mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId + mainTextBlockId || + thinkingBlockId || + toolBlockId || + imageBlockId || + citationBlockId || + lastBlockId || + initialPlaceholderBlockId + if (possibleBlockId) { // 更改上一个block的状态为ERROR const changes: Partial = { @@ -743,6 +750,9 @@ const fetchAndProcessAssistantResponseImpl = async ( saveUpdatesToDB(assistantMsgId, topicId, messageErrorUpdate, []) + // 立即设置 loading 为 false,因为当前任务已经出错 + dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) + EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { id: assistantMsgId, topicId, @@ -843,14 +853,14 @@ const fetchAndProcessAssistantResponseImpl = async ( ) } catch (error: any) { console.error('Error in fetchAndProcessAssistantResponseImpl:', error) - // The main error handling is now delegated to OrchestrationService, - // which calls the `onError` callback. This catch block is for - // any errors that might occur outside of that orchestration flow. - if (assistantMessage && callbacks.onError) { - callbacks.onError(error) - } else { - // Fallback if callbacks are not even defined yet - throw error + // 统一错误处理:确保 loading 状态被正确设置,避免队列任务卡住 + try { + await callbacks.onError?.(error) + } catch (callbackError) { + console.error('Error in onError callback:', callbackError) + } finally { + // 确保无论如何都设置 loading 为 false(onError 回调中已设置,这里是保险) + dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) } } } @@ -895,10 +905,9 @@ export const sendMessage = } } catch (error) { console.error('Error in sendMessage thunk:', error) + } finally { + handleChangeLoadingOfTopic(topicId) } - // finally { - // handleChangeLoadingOfTopic(topicId) - // } } /** @@ -1132,10 +1141,9 @@ export const resendMessageThunk = } } catch (error) { console.error(`[resendMessageThunk] Error resending user message ${userMessageToResend.id}:`, error) + } finally { + handleChangeLoadingOfTopic(topicId) } - // finally { - // handleChangeLoadingOfTopic(topicId) - // } } /** @@ -1243,11 +1251,10 @@ export const regenerateAssistantResponseThunk = `[regenerateAssistantResponseThunk] Error regenerating response for assistant message ${assistantMessageToRegenerate.id}:`, error ) - // dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) + dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) + } finally { + handleChangeLoadingOfTopic(topicId) } - // finally { - // handleChangeLoadingOfTopic(topicId) - // } } // --- Thunk to initiate translation and create the initial block --- @@ -1413,10 +1420,9 @@ export const appendAssistantResponseThunk = console.error(`[appendAssistantResponseThunk] Error appending assistant response:`, error) // Optionally dispatch an error action or notification // Resetting loading state should be handled by the underlying fetchAndProcessAssistantResponseImpl + } finally { + handleChangeLoadingOfTopic(topicId) } - // finally { - // handleChangeLoadingOfTopic(topicId) - // } } /** diff --git a/src/renderer/src/utils/error.ts b/src/renderer/src/utils/error.ts index 4150b200f2..9c217e23ad 100644 --- a/src/renderer/src/utils/error.ts +++ b/src/renderer/src/utils/error.ts @@ -71,8 +71,11 @@ export function getErrorMessage(error: any): string { } export const isAbortError = (error: any): boolean => { + // Convert message to string for consistent checking + const errorMessage = String(error?.message || '') + // 检查错误消息 - if (error?.message === 'Request was aborted.') { + if (errorMessage === 'Request was aborted.') { return true } @@ -85,7 +88,8 @@ export const isAbortError = (error: any): boolean => { if ( error && typeof error === 'object' && - (error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason')) + errorMessage && + (errorMessage === 'Request was aborted.' || errorMessage.includes('signal is aborted without reason')) ) { return true }