mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: enhance provider options and examples for AI SDK
- Introduced new utility functions for creating and merging provider options, improving type safety and usability. - Added comprehensive examples for OpenAI, Anthropic, Google, and generic provider options to demonstrate usage. - Refactored existing code to streamline provider configuration and enhance clarity in the options management. - Updated the PluginEnabledAiClient to simplify the handling of model parameters and improve overall functionality.
This commit is contained in:
parent
f934b479b2
commit
2f58b3360e
@ -225,7 +225,7 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
|
||||
)
|
||||
} else {
|
||||
// 外部 registry 方式:直接使用用户提供的 model
|
||||
return await streamText(modelIdOrParams)
|
||||
return streamText(modelIdOrParams)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -112,15 +112,26 @@ export type {
|
||||
ZhipuProviderSettings
|
||||
} from './clients/types'
|
||||
|
||||
// ==================== 选项 ====================
|
||||
export {
|
||||
createAnthropicOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
type ExtractProviderOptions,
|
||||
mergeProviderOptions,
|
||||
type ProviderOptionsMap,
|
||||
type TypedProviderOptions
|
||||
} from './options'
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
export { createClient as createApiClient, getClientInfo, getSupportedProviders } from './clients/ApiClientFactory'
|
||||
export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './providers/registry'
|
||||
|
||||
// ==================== Provider 配置工厂 ====================
|
||||
export {
|
||||
BaseProviderConfig,
|
||||
type BaseProviderConfig,
|
||||
createProviderConfig,
|
||||
ProviderConfigBuilder,
|
||||
type ProviderConfigBuilder,
|
||||
providerConfigBuilder,
|
||||
ProviderConfigFactory
|
||||
} from './providers/factory'
|
||||
|
||||
87
packages/aiCore/src/options/examples.ts
Normal file
87
packages/aiCore/src/options/examples.ts
Normal file
@ -0,0 +1,87 @@
|
||||
import { streamText } from 'ai'
|
||||
|
||||
import {
|
||||
createAnthropicOptions,
|
||||
createGenericProviderOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
mergeProviderOptions
|
||||
} from './factory'
|
||||
|
||||
// 示例1: 使用已知供应商的严格类型约束
|
||||
export function exampleOpenAIWithOptions() {
|
||||
const openaiOptions = createOpenAIOptions({
|
||||
reasoningEffort: 'medium'
|
||||
})
|
||||
|
||||
// 这里会有类型检查,确保选项符合OpenAI的设置
|
||||
return streamText({
|
||||
model: {} as any, // 实际使用时替换为真实模型
|
||||
prompt: 'Hello',
|
||||
providerOptions: openaiOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例2: 使用Anthropic供应商选项
|
||||
export function exampleAnthropicWithOptions() {
|
||||
const anthropicOptions = createAnthropicOptions({
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budgetTokens: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: anthropicOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例3: 使用Google供应商选项
|
||||
export function exampleGoogleWithOptions() {
|
||||
const googleOptions = createGoogleOptions({
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: googleOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例4: 使用未知供应商(通用类型)
|
||||
export function exampleUnknownProviderWithOptions() {
|
||||
const customProviderOptions = createGenericProviderOptions('custom-provider', {
|
||||
temperature: 0.7,
|
||||
customSetting: 'value',
|
||||
anotherOption: true
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: customProviderOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例5: 合并多个供应商选项
|
||||
export function exampleMergedOptions() {
|
||||
const openaiOptions = createOpenAIOptions({})
|
||||
|
||||
const customOptions = createGenericProviderOptions('custom', {
|
||||
customParam: 'value'
|
||||
})
|
||||
|
||||
const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: mergedOptions
|
||||
})
|
||||
}
|
||||
57
packages/aiCore/src/options/factory.ts
Normal file
57
packages/aiCore/src/options/factory.ts
Normal file
@ -0,0 +1,57 @@
|
||||
import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
|
||||
|
||||
/**
|
||||
* 创建特定供应商的选项
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商特定的选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
|
||||
provider: T,
|
||||
options: ExtractProviderOptions<T>
|
||||
): Record<T, ExtractProviderOptions<T>> {
|
||||
return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建任意供应商的选项(包括未知供应商)
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createGenericProviderOptions<T extends string>(
|
||||
provider: T,
|
||||
options: Record<string, any>
|
||||
): Record<T, Record<string, any>> {
|
||||
return { [provider]: options } as Record<T, Record<string, any>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 合并多个供应商的options
|
||||
* @param optionsMap 包含多个供应商选项的对象
|
||||
* @returns 合并后的TypedProviderOptions
|
||||
*/
|
||||
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
|
||||
return Object.assign({}, ...optionsMap)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI供应商选项的便捷函数
|
||||
*/
|
||||
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
|
||||
return createProviderOptions('openai', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Anthropic供应商选项的便捷函数
|
||||
*/
|
||||
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
|
||||
return createProviderOptions('anthropic', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Google供应商选项的便捷函数
|
||||
*/
|
||||
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
|
||||
return createProviderOptions('google', options)
|
||||
}
|
||||
2
packages/aiCore/src/options/index.ts
Normal file
2
packages/aiCore/src/options/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export * from './factory'
|
||||
export * from './types'
|
||||
28
packages/aiCore/src/options/types.ts
Normal file
28
packages/aiCore/src/options/types.ts
Normal file
@ -0,0 +1,28 @@
|
||||
import { AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import { LanguageModelV1ProviderMetadata } from '@ai-sdk/provider'
|
||||
|
||||
export type ProviderOptions<T extends keyof LanguageModelV1ProviderMetadata> = LanguageModelV1ProviderMetadata[T]
|
||||
|
||||
/**
|
||||
* 供应商选项类型,如果map中没有,说明没有约束
|
||||
*/
|
||||
export type ProviderOptionsMap = {
|
||||
openai: OpenAIResponsesProviderOptions
|
||||
anthropic: AnthropicProviderOptions
|
||||
google: GoogleGenerativeAIProviderOptions
|
||||
}
|
||||
|
||||
// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型
|
||||
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
|
||||
|
||||
/**
|
||||
* 类型安全的ProviderOptions
|
||||
* 对于已知供应商使用严格类型,对于未知供应商允许任意Record<string, JSONValue>
|
||||
*/
|
||||
export type TypedProviderOptions = {
|
||||
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
|
||||
} & {
|
||||
[K in string]?: Record<string, any>
|
||||
} & LanguageModelV1ProviderMetadata
|
||||
@ -22,6 +22,8 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
|
||||
import { buildProviderOptions } from './utils/reasoning'
|
||||
|
||||
/**
|
||||
* 获取温度参数
|
||||
*/
|
||||
@ -233,6 +235,13 @@ export async function buildStreamTextParams(
|
||||
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
|
||||
}
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, {
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
})
|
||||
|
||||
// 构建基础参数
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
@ -242,19 +251,7 @@ export async function buildStreamTextParams(
|
||||
system: systemPrompt || undefined,
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
// 随便填着,后面再改
|
||||
providerOptions: {
|
||||
reasoning: {
|
||||
enabled: enableReasoning
|
||||
},
|
||||
webSearch: {
|
||||
enabled: enableWebSearch
|
||||
},
|
||||
generateImage: {
|
||||
enabled: enableGenerateImage
|
||||
}
|
||||
},
|
||||
...getCustomParameters(assistant)
|
||||
providerOptions
|
||||
}
|
||||
|
||||
// 添加工具(如果启用且有工具)
|
||||
@ -280,32 +277,3 @@ export async function buildGenerateTextParams(
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取自定义参数
|
||||
* 从 assistant 设置中提取自定义参数
|
||||
*/
|
||||
export function getCustomParameters(assistant: Assistant): Record<string, any> {
|
||||
return (
|
||||
assistant?.settings?.customParameters?.reduce((acc, param) => {
|
||||
if (!param.name?.trim()) {
|
||||
return acc
|
||||
}
|
||||
if (param.type === 'json') {
|
||||
const value = param.value as string
|
||||
if (value === 'undefined') {
|
||||
return { ...acc, [param.name]: undefined }
|
||||
}
|
||||
try {
|
||||
return { ...acc, [param.name]: JSON.parse(value) }
|
||||
} catch {
|
||||
return { ...acc, [param.name]: value }
|
||||
}
|
||||
}
|
||||
return {
|
||||
...acc,
|
||||
[param.name]: param.value
|
||||
}
|
||||
}, {}) || {}
|
||||
)
|
||||
}
|
||||
|
||||
466
src/renderer/src/aiCore/utils/reasoning.ts
Normal file
466
src/renderer/src/aiCore/utils/reasoning.ts
Normal file
@ -0,0 +1,466 @@
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
isDoubaoThinkingAutoModel,
|
||||
isReasoningModel,
|
||||
isSupportedReasoningEffortGrokModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedReasoningEffortOpenAIModel,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportedThinkingTokenDoubaoModel,
|
||||
isSupportedThinkingTokenGeminiModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isSupportedThinkingTokenQwenModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { Assistant, EFFORT_RATIO, Model } from '@renderer/types'
|
||||
import { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
|
||||
export function getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||
const provider = getProviderByModel(model)
|
||||
if (provider.id === 'groq') {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
// Doubao 思考模式支持
|
||||
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||
// reasoningEffort 为空,默认开启 enabled
|
||||
if (!reasoningEffort) {
|
||||
return { thinking: { type: 'disabled' } }
|
||||
}
|
||||
if (reasoningEffort === 'high') {
|
||||
return { thinking: { type: 'enabled' } }
|
||||
}
|
||||
if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
|
||||
return { thinking: { type: 'auto' } }
|
||||
}
|
||||
// 其他情况不带 thinking 字段
|
||||
return {}
|
||||
}
|
||||
|
||||
if (!reasoningEffort) {
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
return { enable_thinking: false }
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
// openrouter没有提供一个不推理的选项,先隐藏
|
||||
if (provider.id === 'openrouter') {
|
||||
return { reasoning: { max_tokens: 0, exclude: true } }
|
||||
}
|
||||
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||
return { reasoning_effort: 'none' }
|
||||
}
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||
return { thinking: { type: 'disabled' } }
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
const budgetTokens = Math.floor(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
|
||||
)
|
||||
|
||||
// OpenRouter models
|
||||
if (model.provider === 'openrouter') {
|
||||
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Qwen models
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
return {
|
||||
enable_thinking: true,
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Grok models
|
||||
if (isSupportedReasoningEffortGrokModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI models
|
||||
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
// Claude models
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const maxTokens = assistant.settings?.maxTokens
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budget_tokens: Math.floor(
|
||||
Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Doubao models
|
||||
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||
if (assistant.settings?.reasoning_effort === 'high') {
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default case: no special thinking settings
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
* 按 provider 类型分离,保持类型安全
|
||||
* 返回格式:{ 'providerId': providerOptions }
|
||||
*/
|
||||
export function buildProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const provider = getProviderByModel(model)
|
||||
const providerId = getAiSdkProviderId(provider)
|
||||
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
// 根据 provider 类型分离构建逻辑
|
||||
switch (provider.type) {
|
||||
case 'openai':
|
||||
case 'azure-openai':
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'anthropic':
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'gemini':
|
||||
case 'vertexai':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
}
|
||||
|
||||
// 合并自定义参数到 provider 特定的选项中
|
||||
const customParameters = getCustomParameters(assistant)
|
||||
Object.assign(providerSpecificOptions, customParameters)
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||
return {
|
||||
[providerId]: providerSpecificOptions
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 OpenAI 特定的 providerOptions
|
||||
*/
|
||||
function buildOpenAIProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// OpenAI 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
// Web 搜索和图像生成暂时使用通用格式
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Anthropic 特定的 providerOptions
|
||||
*/
|
||||
function buildAnthropicProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// Anthropic 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getAnthropicReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Gemini 特定的 providerOptions
|
||||
*/
|
||||
function buildGeminiProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// Gemini 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getGeminiReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建通用的 providerOptions(用于其他 provider)
|
||||
*/
|
||||
function buildGenericProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// 使用原有的通用推理逻辑
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getReasoningEffort(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 OpenAI 推理参数
|
||||
* 从 OpenAIResponseAPIClient 和 OpenAIAPIClient 中提取的逻辑
|
||||
*/
|
||||
function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (!reasoningEffort) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// OpenAI 推理参数
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Anthropic 推理参数
|
||||
* 从 AnthropicAPIClient 中提取的逻辑
|
||||
*/
|
||||
function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
thinking: {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude 推理参数
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budget_tokens: budgetTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Gemini 推理参数
|
||||
* 从 GeminiAPIClient 中提取的逻辑
|
||||
*/
|
||||
function getGeminiReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
// Gemini 推理参数
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts: false,
|
||||
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
if (effortRatio > 1) {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
|
||||
const budget = Math.floor((max - min) * effortRatio + min)
|
||||
|
||||
return {
|
||||
thinkingConfig: {
|
||||
...(budget > 0 ? { thinkingBudget: budget } : {}),
|
||||
includeThoughts: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取自定义参数
|
||||
* 从 assistant 设置中提取自定义参数
|
||||
*/
|
||||
function getCustomParameters(assistant: Assistant): Record<string, any> {
|
||||
return (
|
||||
assistant?.settings?.customParameters?.reduce((acc, param) => {
|
||||
if (!param.name?.trim()) {
|
||||
return acc
|
||||
}
|
||||
if (param.type === 'json') {
|
||||
const value = param.value as string
|
||||
if (value === 'undefined') {
|
||||
return { ...acc, [param.name]: undefined }
|
||||
}
|
||||
try {
|
||||
return { ...acc, [param.name]: JSON.parse(value) }
|
||||
} catch {
|
||||
return { ...acc, [param.name]: value }
|
||||
}
|
||||
}
|
||||
return {
|
||||
...acc,
|
||||
[param.name]: param.value
|
||||
}
|
||||
}, {}) || {}
|
||||
)
|
||||
}
|
||||
@ -34,7 +34,7 @@ import {
|
||||
resetAssistantMessage
|
||||
} from '@renderer/utils/messageUtils/create'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { getTopicQueue } from '@renderer/utils/queue'
|
||||
import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
|
||||
import { isOnHomePage } from '@renderer/utils/window'
|
||||
import { t } from 'i18next'
|
||||
import { isEmpty, throttle } from 'lodash'
|
||||
@ -44,10 +44,10 @@ import type { AppDispatch, RootState } from '../index'
|
||||
import { removeManyBlocks, updateOneBlock, upsertManyBlocks, upsertOneBlock } from '../messageBlock'
|
||||
import { newMessagesActions, selectMessagesForTopic } from '../newMessage'
|
||||
|
||||
// const handleChangeLoadingOfTopic = async (topicId: string) => {
|
||||
// await waitForTopicQueue(topicId)
|
||||
// store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
// }
|
||||
const handleChangeLoadingOfTopic = async (topicId: string) => {
|
||||
await waitForTopicQueue(topicId)
|
||||
store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
}
|
||||
// TODO: 后续可以将db操作移到Listener Middleware中
|
||||
export const saveMessageAndBlocksToDB = async (message: Message, blocks: MessageBlock[], messageIndex: number = -1) => {
|
||||
try {
|
||||
@ -704,12 +704,12 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
}
|
||||
|
||||
const serializableError = {
|
||||
name: error.name,
|
||||
message: pauseErrorLanguagePlaceholder || error.message || formatErrorMessage(error),
|
||||
originalMessage: error.message,
|
||||
stack: error.stack,
|
||||
status: error.status || error.code,
|
||||
requestId: error.request_id
|
||||
name: String(error.name || ''),
|
||||
message: pauseErrorLanguagePlaceholder || String(error.message || '') || formatErrorMessage(error),
|
||||
originalMessage: String(error.message || ''),
|
||||
stack: String(error.stack || ''),
|
||||
status: error.status || error.code || undefined,
|
||||
requestId: error.request_id || undefined
|
||||
}
|
||||
if (!isOnHomePage()) {
|
||||
await notificationService.send({
|
||||
@ -723,7 +723,14 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
})
|
||||
}
|
||||
const possibleBlockId =
|
||||
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
|
||||
mainTextBlockId ||
|
||||
thinkingBlockId ||
|
||||
toolBlockId ||
|
||||
imageBlockId ||
|
||||
citationBlockId ||
|
||||
lastBlockId ||
|
||||
initialPlaceholderBlockId
|
||||
|
||||
if (possibleBlockId) {
|
||||
// 更改上一个block的状态为ERROR
|
||||
const changes: Partial<MessageBlock> = {
|
||||
@ -743,6 +750,9 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
|
||||
saveUpdatesToDB(assistantMsgId, topicId, messageErrorUpdate, [])
|
||||
|
||||
// 立即设置 loading 为 false,因为当前任务已经出错
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
|
||||
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, {
|
||||
id: assistantMsgId,
|
||||
topicId,
|
||||
@ -843,14 +853,14 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
)
|
||||
} catch (error: any) {
|
||||
console.error('Error in fetchAndProcessAssistantResponseImpl:', error)
|
||||
// The main error handling is now delegated to OrchestrationService,
|
||||
// which calls the `onError` callback. This catch block is for
|
||||
// any errors that might occur outside of that orchestration flow.
|
||||
if (assistantMessage && callbacks.onError) {
|
||||
callbacks.onError(error)
|
||||
} else {
|
||||
// Fallback if callbacks are not even defined yet
|
||||
throw error
|
||||
// 统一错误处理:确保 loading 状态被正确设置,避免队列任务卡住
|
||||
try {
|
||||
await callbacks.onError?.(error)
|
||||
} catch (callbackError) {
|
||||
console.error('Error in onError callback:', callbackError)
|
||||
} finally {
|
||||
// 确保无论如何都设置 loading 为 false(onError 回调中已设置,这里是保险)
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -895,10 +905,9 @@ export const sendMessage =
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error in sendMessage thunk:', error)
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1132,10 +1141,9 @@ export const resendMessageThunk =
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[resendMessageThunk] Error resending user message ${userMessageToResend.id}:`, error)
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1243,11 +1251,10 @@ export const regenerateAssistantResponseThunk =
|
||||
`[regenerateAssistantResponseThunk] Error regenerating response for assistant message ${assistantMessageToRegenerate.id}:`,
|
||||
error
|
||||
)
|
||||
// dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
// --- Thunk to initiate translation and create the initial block ---
|
||||
@ -1413,10 +1420,9 @@ export const appendAssistantResponseThunk =
|
||||
console.error(`[appendAssistantResponseThunk] Error appending assistant response:`, error)
|
||||
// Optionally dispatch an error action or notification
|
||||
// Resetting loading state should be handled by the underlying fetchAndProcessAssistantResponseImpl
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -71,8 +71,11 @@ export function getErrorMessage(error: any): string {
|
||||
}
|
||||
|
||||
export const isAbortError = (error: any): boolean => {
|
||||
// Convert message to string for consistent checking
|
||||
const errorMessage = String(error?.message || '')
|
||||
|
||||
// 检查错误消息
|
||||
if (error?.message === 'Request was aborted.') {
|
||||
if (errorMessage === 'Request was aborted.') {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -85,7 +88,8 @@ export const isAbortError = (error: any): boolean => {
|
||||
if (
|
||||
error &&
|
||||
typeof error === 'object' &&
|
||||
(error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason'))
|
||||
errorMessage &&
|
||||
(errorMessage === 'Request was aborted.' || errorMessage.includes('signal is aborted without reason'))
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user