diff --git a/.github/workflows/dispatch-docs-update.yml b/.github/workflows/dispatch-docs-update.yml index b9457faec6..bb33c60b33 100644 --- a/.github/workflows/dispatch-docs-update.yml +++ b/.github/workflows/dispatch-docs-update.yml @@ -19,7 +19,7 @@ jobs: echo "tag=${{ github.event.release.tag_name }}" >> $GITHUB_OUTPUT - name: Dispatch update-download-version workflow to cherry-studio-docs - uses: peter-evans/repository-dispatch@v3 + uses: peter-evans/repository-dispatch@v4 with: token: ${{ secrets.REPO_DISPATCH_TOKEN }} repository: CherryHQ/cherry-studio-docs diff --git a/package.json b/package.json index 3f95aee6d5..fd5eb0151d 100644 --- a/package.json +++ b/package.json @@ -318,6 +318,7 @@ "motion": "^12.10.5", "notion-helper": "^1.3.22", "npx-scope-finder": "^1.2.0", + "ollama-ai-provider-v2": "^1.5.5", "oxlint": "^1.22.0", "oxlint-tsgolint": "^0.2.0", "p-queue": "^8.1.0", diff --git a/packages/ai-sdk-provider/package.json b/packages/ai-sdk-provider/package.json index ba232937f3..25864f3b1f 100644 --- a/packages/ai-sdk-provider/package.json +++ b/packages/ai-sdk-provider/package.json @@ -41,6 +41,7 @@ "ai": "^5.0.26" }, "dependencies": { + "@ai-sdk/openai-compatible": "^1.0.28", "@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider-utils": "^3.0.17" }, diff --git a/packages/ai-sdk-provider/src/cherryin-provider.ts b/packages/ai-sdk-provider/src/cherryin-provider.ts index d045fdc505..33ec1a2a3a 100644 --- a/packages/ai-sdk-provider/src/cherryin-provider.ts +++ b/packages/ai-sdk-provider/src/cherryin-provider.ts @@ -2,7 +2,6 @@ import { AnthropicMessagesLanguageModel } from '@ai-sdk/anthropic/internal' import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal' import type { OpenAIProviderSettings } from '@ai-sdk/openai' import { - OpenAIChatLanguageModel, OpenAICompletionLanguageModel, OpenAIEmbeddingModel, OpenAIImageModel, @@ -10,6 +9,7 @@ import { OpenAISpeechModel, OpenAITranscriptionModel } from '@ai-sdk/openai/internal' +import { OpenAICompatibleChatLanguageModel } from '@ai-sdk/openai-compatible' import { type EmbeddingModelV2, type ImageModelV2, @@ -118,7 +118,7 @@ const createCustomFetch = (originalFetch?: any) => { return originalFetch ? originalFetch(url, options) : fetch(url, options) } } -class CherryInOpenAIChatLanguageModel extends OpenAIChatLanguageModel { +class CherryInOpenAIChatLanguageModel extends OpenAICompatibleChatLanguageModel { constructor(modelId: string, settings: any) { super(modelId, { ...settings, diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index df106158c1..a648dcf3c7 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -41,7 +41,7 @@ "dependencies": { "@ai-sdk/anthropic": "^2.0.49", "@ai-sdk/azure": "^2.0.74", - "@ai-sdk/deepseek": "^1.0.29", + "@ai-sdk/deepseek": "^1.0.31", "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch", "@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider-utils": "^3.0.17", diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index a50356130d..61e6f49b81 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -35,7 +35,6 @@ export interface WebSearchPluginConfig { anthropic?: AnthropicSearchConfig xai?: ProviderOptionsMap['xai']['searchParameters'] google?: GoogleSearchConfig - 'google-vertex'?: GoogleSearchConfig openrouter?: OpenRouterSearchConfig } @@ -44,7 +43,6 @@ export interface WebSearchPluginConfig { */ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = { google: {}, - 'google-vertex': {}, openai: {}, 'openai-chat': {}, xai: { @@ -97,55 +95,28 @@ export type WebSearchToolInputSchema = { 'openai-chat': InferToolInput } -export const switchWebSearchTool = (providerId: string, config: WebSearchPluginConfig, params: any) => { - switch (providerId) { - case 'openai': { - if (config.openai) { - if (!params.tools) params.tools = {} - params.tools.web_search = openai.tools.webSearch(config.openai) - } - break - } - case 'openai-chat': { - if (config['openai-chat']) { - if (!params.tools) params.tools = {} - params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat']) - } - break - } - - case 'anthropic': { - if (config.anthropic) { - if (!params.tools) params.tools = {} - params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) - } - break - } - - case 'google': { - // case 'google-vertex': - if (!params.tools) params.tools = {} - params.tools.web_search = google.tools.googleSearch(config.google || {}) - break - } - - case 'xai': { - if (config.xai) { - const searchOptions = createXaiOptions({ - searchParameters: { ...config.xai, mode: 'on' } - }) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) - } - break - } - - case 'openrouter': { - if (config.openrouter) { - const searchOptions = createOpenRouterOptions(config.openrouter) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) - } - break - } +export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any) => { + if (config.openai) { + if (!params.tools) params.tools = {} + params.tools.web_search = openai.tools.webSearch(config.openai) + } else if (config['openai-chat']) { + if (!params.tools) params.tools = {} + params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat']) + } else if (config.anthropic) { + if (!params.tools) params.tools = {} + params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) + } else if (config.google) { + // case 'google-vertex': + if (!params.tools) params.tools = {} + params.tools.web_search = google.tools.googleSearch(config.google || {}) + } else if (config.xai) { + const searchOptions = createXaiOptions({ + searchParameters: { ...config.xai, mode: 'on' } + }) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) + } else if (config.openrouter) { + const searchOptions = createOpenRouterOptions(config.openrouter) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) } return params } diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts index 75692cdf36..a46df7dd4c 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -4,7 +4,6 @@ */ import { definePlugin } from '../../' -import type { AiRequestContext } from '../../types' import type { WebSearchPluginConfig } from './helper' import { DEFAULT_WEB_SEARCH_CONFIG, switchWebSearchTool } from './helper' @@ -18,15 +17,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR name: 'webSearch', enforce: 'pre', - transformParams: async (params: any, context: AiRequestContext) => { - const { providerId } = context - switchWebSearchTool(providerId, config, params) - - if (providerId === 'cherryin' || providerId === 'cherryin-chat') { - // cherryin.gemini - const _providerId = params.model.provider.split('.')[1] - switchWebSearchTool(_providerId, config, params) - } + transformParams: async (params: any) => { + switchWebSearchTool(config, params) return params } }) diff --git a/packages/shared/api/index.ts b/packages/shared/api/index.ts index dbc8c627a6..0566fd6551 100644 --- a/packages/shared/api/index.ts +++ b/packages/shared/api/index.ts @@ -102,6 +102,17 @@ export function formatVertexApiHost( return formatApiHost(trimmedHost) } +/** + * 格式化 Ollama 的 API 主机地址。 + */ +export function formatOllamaApiHost(host: string): string { + const normalizedHost = withoutTrailingSlash(host) + ?.replace(/\/v1$/, '') + ?.replace(/\/api$/, '') + ?.replace(/\/chat$/, '') + return formatApiHost(normalizedHost + '/api', false) +} + /** * Formats an API host URL by normalizing it and optionally appending an API version. * diff --git a/packages/shared/config/constant.ts b/packages/shared/config/constant.ts index c05fde902c..1e02ce7706 100644 --- a/packages/shared/config/constant.ts +++ b/packages/shared/config/constant.ts @@ -7,6 +7,11 @@ export const documentExts = ['.pdf', '.doc', '.docx', '.pptx', '.xlsx', '.odt', export const thirdPartyApplicationExts = ['.draftsExport'] export const bookExts = ['.epub'] +export const API_SERVER_DEFAULTS = { + HOST: '127.0.0.1', + PORT: 23333 +} + /** * A flat array of all file extensions known by the linguist database. * This is the primary source for identifying code files. diff --git a/packages/shared/provider/detection.ts b/packages/shared/provider/detection.ts index 19fff2dff9..8e76218fc8 100644 --- a/packages/shared/provider/detection.ts +++ b/packages/shared/provider/detection.ts @@ -52,11 +52,12 @@ export function isAwsBedrockProvider

(provider: P): bo return provider.type === 'aws-bedrock' } -/** - * Check if provider is AI Gateway type - */ export function isAIGatewayProvider

(provider: P): boolean { - return provider.type === 'ai-gateway' + return provider.type === 'gateway' +} + +export function isOllamaProvider

(provider: P): boolean { + return provider.type === 'ollama' } /** diff --git a/packages/shared/provider/format.ts b/packages/shared/provider/format.ts index 72e768d9b3..2a0e468eaa 100644 --- a/packages/shared/provider/format.ts +++ b/packages/shared/provider/format.ts @@ -9,6 +9,7 @@ import { formatApiHost, formatAzureOpenAIApiHost, + formatOllamaApiHost, formatVertexApiHost, routeToEndpoint, withoutTrailingSlash @@ -18,6 +19,7 @@ import { isAzureOpenAIProvider, isCherryAIProvider, isGeminiProvider, + isOllamaProvider, isPerplexityProvider, isVertexProvider } from './detection' @@ -77,6 +79,8 @@ export function formatProviderApiHost(provider: T, co } } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) { formatted.apiHost = formatApiHost(formatted.apiHost, false) + } else if (isOllamaProvider(formatted)) { + formatted.apiHost = formatOllamaApiHost(formatted.apiHost) } else if (isGeminiProvider(formatted)) { formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta') } else if (isAzureOpenAIProvider(formatted)) { diff --git a/packages/shared/provider/index.ts b/packages/shared/provider/index.ts index f0b9b11d10..53f132acf1 100644 --- a/packages/shared/provider/index.ts +++ b/packages/shared/provider/index.ts @@ -19,6 +19,7 @@ export { isCherryAIProvider, isGeminiProvider, isNewApiProvider, + isOllamaProvider, isOpenAICompatibleProvider, isOpenAIProvider, isPerplexityProvider, diff --git a/packages/shared/provider/initialization.ts b/packages/shared/provider/initialization.ts index fbb5fba54f..c21ecd8732 100644 --- a/packages/shared/provider/initialization.ts +++ b/packages/shared/provider/initialization.ts @@ -79,12 +79,12 @@ export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [ aliases: ['hf', 'hugging-face'] }, { - id: 'ai-gateway', - name: 'AI Gateway', + id: 'gateway', + name: 'Vercel AI Gateway', import: () => import('@ai-sdk/gateway'), creatorFunctionName: 'createGateway', supportsImageGeneration: true, - aliases: ['gateway'] + aliases: ['ai-gateway'] }, { id: 'cerebras', @@ -92,6 +92,13 @@ export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [ import: () => import('@ai-sdk/cerebras'), creatorFunctionName: 'createCerebras', supportsImageGeneration: false + }, + { + id: 'ollama', + name: 'Ollama', + import: () => import('ollama-ai-provider-v2'), + creatorFunctionName: 'createOllama', + supportsImageGeneration: false } ] as const diff --git a/packages/shared/provider/sdk-config.ts b/packages/shared/provider/sdk-config.ts index 91b3c8d54e..d332cce8ba 100644 --- a/packages/shared/provider/sdk-config.ts +++ b/packages/shared/provider/sdk-config.ts @@ -6,8 +6,10 @@ */ import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider' +import { isEmpty } from 'lodash' import { routeToEndpoint } from '../api' +import { isOllamaProvider } from './detection' import { getAiSdkProviderId } from './mapping' import type { MinimalProvider } from './types' import { SystemProviderIds } from './types' @@ -157,6 +159,19 @@ export function providerToAiSdkConfig( } } + if (isOllamaProvider(provider)) { + return { + providerId: 'ollama', + options: { + ...baseConfig, + headers: { + ...provider.extra_headers, + Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined + } + } + } + } + // Build extra options const extraOptions: Record = {} if (endpoint) { diff --git a/packages/shared/provider/types.ts b/packages/shared/provider/types.ts index b9745f9d3a..763ed210c4 100644 --- a/packages/shared/provider/types.ts +++ b/packages/shared/provider/types.ts @@ -11,7 +11,8 @@ export const ProviderTypeSchema = z.enum([ 'aws-bedrock', 'vertex-anthropic', 'new-api', - 'ai-gateway' + 'gateway', + 'ollama' ]) export type ProviderType = z.infer @@ -98,7 +99,7 @@ export const SystemProviderIdSchema = z.enum([ 'longcat', 'huggingface', 'sophnet', - 'ai-gateway', + 'gateway', 'cerebras' ]) @@ -167,7 +168,7 @@ export const SystemProviderIds = { aionly: 'aionly', longcat: 'longcat', huggingface: 'huggingface', - 'ai-gateway': 'ai-gateway', + gateway: 'gateway', cerebras: 'cerebras' } as const satisfies Record diff --git a/packages/shared/utils/naming.ts b/packages/shared/utils/naming.ts index a8b4f5501d..c9aaf55c36 100644 --- a/packages/shared/utils/naming.ts +++ b/packages/shared/utils/naming.ts @@ -27,5 +27,10 @@ export const getLowerBaseModelName = (id: string, delimiter: string = '/'): stri if (baseModelName.endsWith(':free')) { return baseModelName.replace(':free', '') } + + // for cherryin + if (baseModelName.endsWith('(free)')) { + return baseModelName.replace('(free)', '') + } return baseModelName } diff --git a/scripts/feishu-notify.js b/scripts/feishu-notify.js index aae9004a48..d238dedb90 100644 --- a/scripts/feishu-notify.js +++ b/scripts/feishu-notify.js @@ -91,23 +91,6 @@ function createIssueCard(issueData) { return { elements: [ - { - tag: 'div', - text: { - tag: 'lark_md', - content: `**🐛 New GitHub Issue #${issueNumber}**` - } - }, - { - tag: 'hr' - }, - { - tag: 'div', - text: { - tag: 'lark_md', - content: `**📝 Title:** ${issueTitle}` - } - }, { tag: 'div', text: { @@ -158,7 +141,7 @@ function createIssueCard(issueData) { template: 'blue', title: { tag: 'plain_text', - content: '🆕 Cherry Studio - New Issue' + content: `#${issueNumber} - ${issueTitle}` } } } diff --git a/src/main/apiServer/config.ts b/src/main/apiServer/config.ts index 60b1986be9..0966827a7b 100644 --- a/src/main/apiServer/config.ts +++ b/src/main/apiServer/config.ts @@ -1,3 +1,4 @@ +import { API_SERVER_DEFAULTS } from '@shared/config/constant' import type { ApiServerConfig } from '@types' import { v4 as uuidv4 } from 'uuid' @@ -6,9 +7,6 @@ import { reduxService } from '../services/ReduxService' const logger = loggerService.withContext('ApiServerConfig') -const defaultHost = 'localhost' -const defaultPort = 23333 - class ConfigManager { private _config: ApiServerConfig | null = null @@ -30,8 +28,8 @@ class ConfigManager { } this._config = { enabled: serverSettings?.enabled ?? false, - port: serverSettings?.port ?? defaultPort, - host: defaultHost, + port: serverSettings?.port ?? API_SERVER_DEFAULTS.PORT, + host: serverSettings?.host ?? API_SERVER_DEFAULTS.HOST, apiKey: apiKey } return this._config @@ -39,8 +37,8 @@ class ConfigManager { logger.warn('Failed to load config from Redux, using defaults', { error }) this._config = { enabled: false, - port: defaultPort, - host: defaultHost, + port: API_SERVER_DEFAULTS.PORT, + host: API_SERVER_DEFAULTS.HOST, apiKey: this.generateApiKey() } return this._config diff --git a/src/main/apiServer/middleware/openapi.ts b/src/main/apiServer/middleware/openapi.ts index ff01005bd9..6b374901ca 100644 --- a/src/main/apiServer/middleware/openapi.ts +++ b/src/main/apiServer/middleware/openapi.ts @@ -20,8 +20,8 @@ const swaggerOptions: swaggerJSDoc.Options = { }, servers: [ { - url: 'http://localhost:23333', - description: 'Local development server' + url: '/', + description: 'Current server' } ], components: { diff --git a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts index 8a780d5618..e9f459fd6c 100644 --- a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts +++ b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts @@ -19,19 +19,9 @@ export default class EmbeddingsFactory { }) } if (provider === 'ollama') { - if (baseURL.includes('v1/')) { - return new OllamaEmbeddings({ - model: model, - baseUrl: baseURL.replace('v1/', ''), - requestOptions: { - // @ts-ignore expected - 'encoding-format': 'float' - } - }) - } return new OllamaEmbeddings({ model: model, - baseUrl: baseURL, + baseUrl: baseURL.replace(/\/api$/, ''), requestOptions: { // @ts-ignore expected 'encoding-format': 'float' diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index ed92d4ddd8..8c031f7754 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -189,7 +189,7 @@ export default class ModernAiProvider { config: ModernAiProviderConfig ): Promise { // ai-gateway不是image/generation 端点,所以就先不走legacy了 - if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds['ai-gateway']) { + if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) { // 使用 legacy 实现处理图像生成(支持图片编辑等高级功能) if (!config.uiMessages) { throw new Error('uiMessages is required for image generation endpoint') @@ -480,7 +480,7 @@ export default class ModernAiProvider { // 代理其他方法到原有实现 public async models() { - if (this.actualProvider.id === SystemProviderIds['ai-gateway']) { + if (this.actualProvider.id === SystemProviderIds.gateway) { const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] { return models.map((m) => ({ id: m.id, diff --git a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts index e755ce3f20..92f24b4abe 100644 --- a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts @@ -9,6 +9,7 @@ import { import { REFERENCE_PROMPT } from '@renderer/config/prompts' import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' import { getAssistantSettings } from '@renderer/services/AssistantService' +import type { RootState } from '@renderer/store' import type { Assistant, GenerateImageParams, @@ -245,23 +246,20 @@ export abstract class BaseApiClient< protected getVerbosity(model?: Model): OpenAIVerbosity { try { - const state = window.store?.getState() + const state = window.store?.getState() as RootState const verbosity = state?.settings?.openAI?.verbosity - if (verbosity && ['low', 'medium', 'high'].includes(verbosity)) { - // If model is provided, check if the verbosity is supported by the model - if (model) { - const supportedVerbosity = getModelSupportedVerbosity(model) - // Use user's verbosity if supported, otherwise use the first supported option - return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0] - } - return verbosity + // If model is provided, check if the verbosity is supported by the model + if (model) { + const supportedVerbosity = getModelSupportedVerbosity(model) + // Use user's verbosity if supported, otherwise use the first supported option + return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0] } + return verbosity } catch (error) { - logger.warn('Failed to get verbosity from state:', error as Error) + logger.warn('Failed to get verbosity from state. Fallback to undefined.', error as Error) + return undefined } - - return 'medium' } protected getTimeout(model: Model) { diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts index ea50680ea4..cfc9087545 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts @@ -32,7 +32,6 @@ import { isSupportedThinkingTokenModel, isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenZhipuModel, - isSupportVerbosityModel, isVisionModel, MODEL_SUPPORTED_REASONING_EFFORT, ZHIPU_RESULT_TOKENS @@ -714,13 +713,8 @@ export class OpenAIAPIClient extends OpenAIBaseClient< ...modalities, // groq 有不同的 service tier 配置,不符合 openai 接口类型 service_tier: this.getServiceTier(model) as OpenAIServiceTier, - ...(isSupportVerbosityModel(model) - ? { - text: { - verbosity: this.getVerbosity(model) - } - } - : {}), + // verbosity. getVerbosity ensures the returned value is valid. + verbosity: this.getVerbosity(model), ...this.getProviderSpecificParameters(assistant, model), ...reasoningEffort, // ...getOpenAIWebSearchParams(model, enableWebSearch), diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts index 9a8d5f8383..dc97e74a3c 100644 --- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts @@ -11,7 +11,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings' import { getAssistantSettings } from '@renderer/services/AssistantService' import store from '@renderer/store' import type { SettingsState } from '@renderer/store/settings' -import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' +import { type Assistant, type GenerateImageParams, type Model, type Provider } from '@renderer/types' import type { OpenAIResponseSdkMessageParam, OpenAIResponseSdkParams, @@ -25,7 +25,8 @@ import type { OpenAISdkRawOutput, ReasoningEffortOptionalParams } from '@renderer/types/sdk' -import { formatApiHost } from '@renderer/utils/api' +import { formatApiHost, withoutTrailingSlash } from '@renderer/utils/api' +import { isOllamaProvider } from '@renderer/utils/provider' import { BaseApiClient } from '../BaseApiClient' @@ -115,6 +116,34 @@ export abstract class OpenAIBaseClient< })) .filter(isSupportedModel) } + + if (isOllamaProvider(this.provider)) { + const baseUrl = withoutTrailingSlash(this.getBaseURL(false)) + .replace(/\/v1$/, '') + .replace(/\/api$/, '') + const response = await fetch(`${baseUrl}/api/tags`, { + headers: { + Authorization: `Bearer ${this.apiKey}`, + ...this.defaultHeaders(), + ...this.provider.extra_headers + } + }) + + if (!response.ok) { + throw new Error(`Ollama server returned ${response.status} ${response.statusText}`) + } + + const data = await response.json() + if (!data?.models || !Array.isArray(data.models)) { + throw new Error('Invalid response from Ollama API: missing models array') + } + + return data.models.map((model) => ({ + id: model.name, + object: 'model', + owned_by: 'ollama' + })) + } const response = await sdk.models.list() if (this.provider.id === 'together') { // @ts-ignore key is not typed diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index 82e1c32465..6f1ec709b8 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -4,7 +4,7 @@ import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/con import type { MCPTool } from '@renderer/types' import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' -import { isSupportEnableThinkingProvider } from '@renderer/utils/provider' +import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider' import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware' import type { LanguageModelMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' @@ -239,6 +239,7 @@ function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: Ai // Use /think or /no_think suffix to control thinking mode if ( config.provider && + !isOllamaProvider(config.provider) && isSupportedThinkingTokenQwenModel(config.model) && !isSupportEnableThinkingProvider(config.provider) ) { diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index 6eccbf8bc5..cba7fcdb10 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -11,12 +11,16 @@ import { vertex } from '@ai-sdk/google-vertex/edge' import { combineHeaders } from '@ai-sdk/provider-utils' import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins' import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas' +import type { BaseProviderId } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import { isAnthropicModel, + isFixedReasoningModel, + isGeminiModel, isGenerateImageModel, + isGrokModel, + isOpenAIModel, isOpenRouterBuiltInWebSearchModel, - isReasoningModel, isSupportedReasoningEffortModel, isSupportedThinkingTokenModel, isWebSearchModel @@ -24,11 +28,12 @@ import { import { getDefaultModel } from '@renderer/services/AssistantService' import store from '@renderer/store' import type { CherryWebSearchConfig } from '@renderer/store/websearch' -import { type Assistant, type MCPTool, type Provider } from '@renderer/types' +import type { Model } from '@renderer/types' +import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types' import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' import { replacePromptVariables } from '@renderer/utils/prompt' -import { isAwsBedrockProvider } from '@renderer/utils/provider' +import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider' import type { ModelMessage, Tool } from 'ai' import { stepCountIs } from 'ai' @@ -43,6 +48,25 @@ const logger = loggerService.withContext('parameterBuilder') type ProviderDefinedTool = Extract, { type: 'provider-defined' }> +function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | undefined { + if (isAnthropicModel(model)) { + return 'anthropic' + } + if (isGeminiModel(model)) { + return 'google' + } + if (isGrokModel(model)) { + return 'xai' + } + if (isOpenAIModel(model)) { + return 'openai' + } + logger.warn( + `[mapVertexAIGatewayModelToProviderId] Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.` + ) + return undefined +} + /** * 构建 AI SDK 流式参数 * 这是主要的参数构建函数,整合所有转换逻辑 @@ -83,7 +107,7 @@ export async function buildStreamTextParams( const enableReasoning = ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && assistant.settings?.reasoning_effort !== undefined) || - (isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model))) + isFixedReasoningModel(model) // 判断是否使用内置搜索 // 条件:没有外部搜索提供商 && (用户开启了内置搜索 || 模型强制使用内置搜索) @@ -117,6 +141,11 @@ export async function buildStreamTextParams( if (enableWebSearch) { if (isBaseProvider(aiSdkProviderId)) { webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model) + } else if (isAIGatewayProvider(provider) || SystemProviderIds.gateway === provider.id) { + const aiSdkProviderId = mapVertexAIGatewayModelToProviderId(model) + if (aiSdkProviderId) { + webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model) + } } if (!tools) { tools = {} diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 0d2db2ebdd..c02036b042 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -218,7 +218,6 @@ export async function prepareSpecialProviderConfig( ...(config.options.headers ? config.options.headers : {}), 'Content-Type': 'application/json', 'anthropic-version': '2023-06-01', - 'anthropic-beta': 'oauth-2025-04-20', Authorization: `Bearer ${oauthToken}` }, baseURL: 'https://api.anthropic.com/v1', diff --git a/src/renderer/src/aiCore/utils/__tests__/options.test.ts b/src/renderer/src/aiCore/utils/__tests__/options.test.ts index ca6b883d74..9eeeac725b 100644 --- a/src/renderer/src/aiCore/utils/__tests__/options.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/options.test.ts @@ -27,7 +27,8 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => { 'xai', 'deepseek', 'openrouter', - 'openai-compatible' + 'openai-compatible', + 'cherryin' ] if (baseProviders.includes(id)) { return { success: true, data: id } @@ -37,7 +38,15 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => { }, customProviderIdSchema: { safeParse: vi.fn((id) => { - const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock'] + const customProviders = [ + 'google-vertex', + 'google-vertex-anthropic', + 'bedrock', + 'gateway', + 'aihubmix', + 'newapi', + 'ollama' + ] if (customProviders.includes(id)) { return { success: true, data: id } } @@ -47,20 +56,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => { } }) -vi.mock('../provider/factory', () => ({ - getAiSdkProviderId: vi.fn((provider) => { - // Simulate the provider ID mapping - const mapping: Record = { - [SystemProviderIds.gemini]: 'google', - [SystemProviderIds.openai]: 'openai', - [SystemProviderIds.anthropic]: 'anthropic', - [SystemProviderIds.grok]: 'xai', - [SystemProviderIds.deepseek]: 'deepseek', - [SystemProviderIds.openrouter]: 'openrouter' - } - return mapping[provider.id] || provider.id - }) -})) +// Don't mock getAiSdkProviderId - use real implementation for more accurate tests vi.mock('@renderer/config/models', async (importOriginal) => ({ ...(await importOriginal()), @@ -179,8 +175,11 @@ describe('options utils', () => { provider: SystemProviderIds.openai } as Model - beforeEach(() => { + beforeEach(async () => { vi.clearAllMocks() + // Reset getCustomParameters to return empty object by default + const { getCustomParameters } = await import('../reasoning') + vi.mocked(getCustomParameters).mockReturnValue({}) }) describe('buildProviderOptions', () => { @@ -391,7 +390,6 @@ describe('options utils', () => { enableWebSearch: false, enableGenerateImage: false }) - expect(result.providerOptions).toHaveProperty('deepseek') expect(result.providerOptions.deepseek).toBeDefined() }) @@ -461,10 +459,14 @@ describe('options utils', () => { } ) - expect(result.providerOptions.openai).toHaveProperty('custom_param') - expect(result.providerOptions.openai.custom_param).toBe('custom_value') - expect(result.providerOptions.openai).toHaveProperty('another_param') - expect(result.providerOptions.openai.another_param).toBe(123) + expect(result.providerOptions).toStrictEqual({ + openai: { + custom_param: 'custom_value', + another_param: 123, + serviceTier: undefined, + textVerbosity: undefined + } + }) }) it('should extract AI SDK standard params from custom parameters', async () => { @@ -696,5 +698,459 @@ describe('options utils', () => { }) }) }) + + describe('AI Gateway provider', () => { + const gatewayProvider: Provider = { + id: SystemProviderIds.gateway, + name: 'Vercel AI Gateway', + type: 'gateway', + apiKey: 'test-key', + apiHost: 'https://gateway.vercel.com', + isSystem: true + } as Provider + + it('should build OpenAI options for OpenAI models through gateway', () => { + const openaiModel: Model = { + id: 'openai/gpt-4', + name: 'GPT-4', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('openai') + expect(result.providerOptions.openai).toBeDefined() + }) + + it('should build Anthropic options for Anthropic models through gateway', () => { + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('anthropic') + expect(result.providerOptions.anthropic).toBeDefined() + }) + + it('should build Google options for Gemini models through gateway', () => { + const geminiModel: Model = { + id: 'google/gemini-2.0-flash-exp', + name: 'Gemini 2.0 Flash', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, geminiModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('google') + expect(result.providerOptions.google).toBeDefined() + }) + + it('should build xAI options for Grok models through gateway', () => { + const grokModel: Model = { + id: 'xai/grok-2-latest', + name: 'Grok 2', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, grokModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('xai') + expect(result.providerOptions.xai).toBeDefined() + }) + + it('should include reasoning parameters for Anthropic models when enabled', () => { + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: true, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions.anthropic).toHaveProperty('thinking') + expect(result.providerOptions.anthropic.thinking).toEqual({ + type: 'enabled', + budgetTokens: 5000 + }) + }) + + it('should merge gateway routing options from custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + vi.mocked(getCustomParameters).mockReturnValue({ + gateway: { + order: ['vertex', 'anthropic'], + only: ['vertex', 'anthropic'] + } + }) + + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should have both anthropic provider options and gateway routing options + expect(result.providerOptions).toHaveProperty('anthropic') + expect(result.providerOptions).toHaveProperty('gateway') + expect(result.providerOptions.gateway).toEqual({ + order: ['vertex', 'anthropic'], + only: ['vertex', 'anthropic'] + }) + }) + + it('should combine provider-specific options with gateway routing options', async () => { + const { getCustomParameters } = await import('../reasoning') + + vi.mocked(getCustomParameters).mockReturnValue({ + gateway: { + order: ['openai', 'anthropic'] + } + }) + + const openaiModel: Model = { + id: 'openai/gpt-4', + name: 'GPT-4', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, { + enableReasoning: true, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should have OpenAI provider options with reasoning + expect(result.providerOptions.openai).toBeDefined() + expect(result.providerOptions.openai).toHaveProperty('reasoningEffort') + + // Should also have gateway routing options + expect(result.providerOptions.gateway).toBeDefined() + expect(result.providerOptions.gateway.order).toEqual(['openai', 'anthropic']) + }) + + it('should build generic options for unknown model types through gateway', () => { + const unknownModel: Model = { + id: 'unknown-provider/model-name', + name: 'Unknown Model', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, unknownModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('openai-compatible') + expect(result.providerOptions['openai-compatible']).toBeDefined() + }) + }) + + describe('Proxy provider custom parameters mapping', () => { + it('should map cherryin provider ID to actual AI SDK provider ID (Google)', async () => { + const { getCustomParameters } = await import('../reasoning') + + // Mock Cherry In provider that uses Google SDK + const cherryinProvider = { + id: 'cherryin', + name: 'Cherry In', + type: 'gemini', // Using Google SDK + apiKey: 'test-key', + apiHost: 'https://cherryin.com', + models: [] as Model[] + } as Provider + + const geminiModel: Model = { + id: 'gemini-2.0-flash-exp', + name: 'Gemini 2.0 Flash', + provider: 'cherryin' + } as Model + + // User provides custom parameters with Cherry Studio provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + cherryin: { + customOption1: 'value1', + customOption2: 'value2' + } + }) + + const result = buildProviderOptions(mockAssistant, geminiModel, cherryinProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should map to 'google' AI SDK provider, not 'cherryin' + expect(result.providerOptions).toHaveProperty('google') + expect(result.providerOptions).not.toHaveProperty('cherryin') + expect(result.providerOptions.google).toMatchObject({ + customOption1: 'value1', + customOption2: 'value2' + }) + }) + + it('should map cherryin provider ID to actual AI SDK provider ID (OpenAI)', async () => { + const { getCustomParameters } = await import('../reasoning') + + // Mock Cherry In provider that uses OpenAI SDK + const cherryinProvider = { + id: 'cherryin', + name: 'Cherry In', + type: 'openai-response', // Using OpenAI SDK + apiKey: 'test-key', + apiHost: 'https://cherryin.com', + models: [] as Model[] + } as Provider + + const openaiModel: Model = { + id: 'gpt-4', + name: 'GPT-4', + provider: 'cherryin' + } as Model + + // User provides custom parameters with Cherry Studio provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + cherryin: { + customOpenAIOption: 'openai_value' + } + }) + + const result = buildProviderOptions(mockAssistant, openaiModel, cherryinProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should map to 'openai' AI SDK provider, not 'cherryin' + expect(result.providerOptions).toHaveProperty('openai') + expect(result.providerOptions).not.toHaveProperty('cherryin') + expect(result.providerOptions.openai).toMatchObject({ + customOpenAIOption: 'openai_value' + }) + }) + + it('should allow direct AI SDK provider ID in custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + const geminiProvider = { + id: SystemProviderIds.gemini, + name: 'Google', + type: 'gemini', + apiKey: 'test-key', + apiHost: 'https://generativelanguage.googleapis.com', + models: [] as Model[] + } as Provider + + const geminiModel: Model = { + id: 'gemini-2.0-flash-exp', + name: 'Gemini 2.0 Flash', + provider: SystemProviderIds.gemini + } as Model + + // User provides custom parameters directly with AI SDK provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + google: { + directGoogleOption: 'google_value' + } + }) + + const result = buildProviderOptions(mockAssistant, geminiModel, geminiProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should merge directly to 'google' provider + expect(result.providerOptions.google).toMatchObject({ + directGoogleOption: 'google_value' + }) + }) + + it('should map gateway provider custom parameters to actual AI SDK provider', async () => { + const { getCustomParameters } = await import('../reasoning') + + const gatewayProvider: Provider = { + id: SystemProviderIds.gateway, + name: 'Vercel AI Gateway', + type: 'gateway', + apiKey: 'test-key', + apiHost: 'https://gateway.vercel.com', + isSystem: true + } as Provider + + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + // User provides both gateway routing options and gateway-scoped custom parameters + vi.mocked(getCustomParameters).mockReturnValue({ + gateway: { + order: ['vertex', 'anthropic'], + only: ['vertex'] + }, + customParam: 'should_go_to_anthropic' + }) + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Gateway routing options should be preserved + expect(result.providerOptions.gateway).toEqual({ + order: ['vertex', 'anthropic'], + only: ['vertex'] + }) + + // Custom parameters should go to the actual AI SDK provider (anthropic) + expect(result.providerOptions.anthropic).toMatchObject({ + customParam: 'should_go_to_anthropic' + }) + }) + + it('should handle mixed custom parameters (AI SDK provider ID + custom params)', async () => { + const { getCustomParameters } = await import('../reasoning') + + const openaiProvider: Provider = { + id: SystemProviderIds.openai, + name: 'OpenAI', + type: 'openai-response', + apiKey: 'test-key', + apiHost: 'https://api.openai.com/v1', + isSystem: true + } as Provider + + // User provides both direct AI SDK provider params and custom params + vi.mocked(getCustomParameters).mockReturnValue({ + openai: { + providerSpecific: 'value1' + }, + customParam1: 'value2', + customParam2: 123 + }) + + const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should merge both into 'openai' provider options + expect(result.providerOptions.openai).toMatchObject({ + providerSpecific: 'value1', + customParam1: 'value2', + customParam2: 123 + }) + }) + + // Note: For proxy providers like aihubmix/newapi, users should write AI SDK provider ID (google/anthropic) + // instead of the Cherry Studio provider ID for custom parameters to work correctly + + it('should handle cherryin fallback to openai-compatible with custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + // Mock cherryin provider that falls back to openai-compatible (default case) + const cherryinProvider = { + id: 'cherryin', + name: 'Cherry In', + type: 'openai', + apiKey: 'test-key', + apiHost: 'https://cherryin.com', + models: [] as Model[] + } as Provider + + const testModel: Model = { + id: 'some-model', + name: 'Some Model', + provider: 'cherryin' + } as Model + + // User provides custom parameters with cherryin provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + customCherryinOption: 'cherryin_value' + }) + + const result = buildProviderOptions(mockAssistant, testModel, cherryinProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // When cherryin falls back to default case, it should use rawProviderId (cherryin) + // User's cherryin params should merge with the provider options + expect(result.providerOptions).toHaveProperty('cherryin') + expect(result.providerOptions.cherryin).toMatchObject({ + customCherryinOption: 'cherryin_value' + }) + }) + + it('should handle cross-provider configurations', async () => { + const { getCustomParameters } = await import('../reasoning') + + const openaiProvider: Provider = { + id: SystemProviderIds.openai, + name: 'OpenAI', + type: 'openai-response', + apiKey: 'test-key', + apiHost: 'https://api.openai.com/v1', + isSystem: true + } as Provider + + // User provides parameters for multiple providers + // In real usage, anthropic/google params would be treated as regular params for openai provider + vi.mocked(getCustomParameters).mockReturnValue({ + openai: { + openaiSpecific: 'openai_value' + }, + customParam: 'value' + }) + + const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should have openai provider options with both scoped and custom params + expect(result.providerOptions).toHaveProperty('openai') + expect(result.providerOptions.openai).toMatchObject({ + openaiSpecific: 'openai_value', + customParam: 'value' + }) + }) + }) }) }) diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index a1352a801a..8ec46c9df2 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -1,5 +1,5 @@ import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock' -import type { AnthropicProviderOptions } from '@ai-sdk/anthropic' +import { type AnthropicProviderOptions } from '@ai-sdk/anthropic' import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai' import type { XaiProviderOptions } from '@ai-sdk/xai' @@ -7,6 +7,9 @@ import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-c import { loggerService } from '@logger' import { getModelSupportedVerbosity, + isAnthropicModel, + isGeminiModel, + isGrokModel, isOpenAIModel, isQwenMTModel, isSupportFlexServiceTierModel, @@ -29,12 +32,14 @@ import { type OpenAIServiceTier, OpenAIServiceTiers, type Provider, - type ServiceTier + type ServiceTier, + SystemProviderIds } from '@renderer/types' import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes' import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider' import type { JSONValue } from 'ai' import { t } from 'i18next' +import type { OllamaCompletionProviderOptions } from 'ollama-ai-provider-v2' import { addAnthropicHeaders } from '../prepareParams/header' import { getAiSdkProviderId } from '../provider/factory' @@ -156,8 +161,8 @@ export function buildProviderOptions( providerOptions: Record> standardParams: Partial> } { - logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities }) const rawProviderId = getAiSdkProviderId(actualProvider) + logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities, rawProviderId }) // 构建 provider 特定的选项 let providerSpecificOptions: Record = {} const serviceTier = getServiceTier(model, actualProvider) @@ -172,14 +177,13 @@ export function buildProviderOptions( case 'azure': case 'azure-responses': { - const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions( + providerSpecificOptions = buildOpenAIProviderOptions( assistant, model, capabilities, serviceTier, textVerbosity ) - providerSpecificOptions = options } break case 'anthropic': @@ -197,10 +201,13 @@ export function buildProviderOptions( case 'openrouter': case 'openai-compatible': { // 对于其他 provider,使用通用的构建逻辑 + const genericOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities) providerSpecificOptions = { - ...buildGenericProviderOptions(assistant, model, capabilities), - serviceTier, - textVerbosity + [rawProviderId]: { + ...genericOptions[rawProviderId], + serviceTier, + textVerbosity + } } break } @@ -236,48 +243,108 @@ export function buildProviderOptions( case 'huggingface': providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier) break + case SystemProviderIds.ollama: + providerSpecificOptions = buildOllamaProviderOptions(assistant, capabilities) + break + case SystemProviderIds.gateway: + providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity) + break default: // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities) + // Merge serviceTier and textVerbosity providerSpecificOptions = { - ...buildGenericProviderOptions(assistant, model, capabilities), - serviceTier, - textVerbosity + ...providerSpecificOptions, + [rawProviderId]: { + ...providerSpecificOptions[rawProviderId], + serviceTier, + textVerbosity + } } } } else { throw error } } - - // 获取自定义参数并分离标准参数和 provider 特定参数 + logger.debug('Built providerSpecificOptions', { providerSpecificOptions }) + /** + * Retrieve custom parameters and separate standard parameters from provider-specific parameters. + */ const customParams = getCustomParameters(assistant) const { standardParams, providerParams } = extractAiSdkStandardParams(customParams) + logger.debug('Extracted standardParams and providerParams', { standardParams, providerParams }) - // 合并 provider 特定的自定义参数到 providerSpecificOptions - providerSpecificOptions = { - ...providerSpecificOptions, - ...providerParams - } - - let rawProviderKey = - { - 'google-vertex': 'google', - 'google-vertex-anthropic': 'anthropic', - 'azure-anthropic': 'anthropic', - 'ai-gateway': 'gateway', - azure: 'openai', - 'azure-responses': 'openai' - }[rawProviderId] || rawProviderId - - if (rawProviderKey === 'cherryin') { - rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type + /** + * Get the actual AI SDK provider ID(s) from the already-built providerSpecificOptions. + * For proxy providers (cherryin, aihubmix, newapi), this will be the actual SDK provider (e.g., 'google', 'openai', 'anthropic') + * For regular providers, this will be the provider itself + */ + const actualAiSdkProviderIds = Object.keys(providerSpecificOptions) + const primaryAiSdkProviderId = actualAiSdkProviderIds[0] // Use the first one as primary for non-scoped params + + /** + * Merge custom parameters into providerSpecificOptions. + * Simple logic: + * 1. If key is in actualAiSdkProviderIds → merge directly (user knows the actual AI SDK provider ID) + * 2. If key == rawProviderId: + * - If it's gateway/ollama → preserve (they need their own config for routing/options) + * - Otherwise → map to primary (this is a proxy provider like cherryin) + * 3. Otherwise → treat as regular parameter, merge to primary provider + * + * Example: + * - User writes `cherryin: { opt: 'val' }` → mapped to `google: { opt: 'val' }` (case 2, proxy) + * - User writes `gateway: { order: [...] }` → stays as `gateway: { order: [...] }` (case 2, routing config) + * - User writes `google: { opt: 'val' }` → stays as `google: { opt: 'val' }` (case 1) + * - User writes `customKey: 'val'` → merged to `google: { customKey: 'val' }` (case 3) + */ + for (const key of Object.keys(providerParams)) { + if (actualAiSdkProviderIds.includes(key)) { + // Case 1: Key is an actual AI SDK provider ID - merge directly + providerSpecificOptions = { + ...providerSpecificOptions, + [key]: { + ...providerSpecificOptions[key], + ...providerParams[key] + } + } + } else if (key === rawProviderId && !actualAiSdkProviderIds.includes(rawProviderId)) { + // Case 2: Key is the current provider (not in actualAiSdkProviderIds, so it's a proxy or special provider) + // Gateway is special: it needs routing config preserved + if (key === SystemProviderIds.gateway) { + // Preserve gateway config for routing + providerSpecificOptions = { + ...providerSpecificOptions, + [key]: { + ...providerSpecificOptions[key], + ...providerParams[key] + } + } + } else { + // Proxy provider (cherryin, etc.) - map to actual AI SDK provider + providerSpecificOptions = { + ...providerSpecificOptions, + [primaryAiSdkProviderId]: { + ...providerSpecificOptions[primaryAiSdkProviderId], + ...providerParams[key] + } + } + } + } else { + // Case 3: Regular parameter - merge to primary provider + providerSpecificOptions = { + ...providerSpecificOptions, + [primaryAiSdkProviderId]: { + ...providerSpecificOptions[primaryAiSdkProviderId], + [key]: providerParams[key] + } + } + } } + logger.debug('Final providerSpecificOptions after merging providerParams', { providerSpecificOptions }) // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数 return { - providerOptions: { - [rawProviderKey]: providerSpecificOptions - }, + providerOptions: providerSpecificOptions, standardParams } } @@ -295,7 +362,7 @@ function buildOpenAIProviderOptions( }, serviceTier: OpenAIServiceTier, textVerbosity?: OpenAIVerbosity -): OpenAIResponsesProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: OpenAIResponsesProviderOptions = {} // OpenAI 推理参数 @@ -334,7 +401,9 @@ function buildOpenAIProviderOptions( textVerbosity } - return providerOptions + return { + openai: providerOptions + } } /** @@ -348,7 +417,7 @@ function buildAnthropicProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): AnthropicProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: AnthropicProviderOptions = {} @@ -361,7 +430,11 @@ function buildAnthropicProviderOptions( } } - return providerOptions + return { + anthropic: { + ...providerOptions + } + } } /** @@ -375,7 +448,7 @@ function buildGeminiProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): GoogleGenerativeAIProviderOptions { +): Record { const { enableReasoning, enableGenerateImage } = capabilities let providerOptions: GoogleGenerativeAIProviderOptions = {} @@ -395,7 +468,11 @@ function buildGeminiProviderOptions( } } - return providerOptions + return { + google: { + ...providerOptions + } + } } function buildXAIProviderOptions( @@ -406,7 +483,7 @@ function buildXAIProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): XaiProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: Record = {} @@ -418,7 +495,11 @@ function buildXAIProviderOptions( } } - return providerOptions + return { + xai: { + ...providerOptions + } + } } function buildCherryInProviderOptions( @@ -432,19 +513,20 @@ function buildCherryInProviderOptions( actualProvider: Provider, serviceTier: OpenAIServiceTier, textVerbosity: OpenAIVerbosity -): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions { +): Record { switch (actualProvider.type) { case 'openai': + return buildGenericProviderOptions('cherryin', assistant, model, capabilities) case 'openai-response': return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity) - case 'anthropic': return buildAnthropicProviderOptions(assistant, model, capabilities) - case 'gemini': return buildGeminiProviderOptions(assistant, model, capabilities) + + default: + return buildGenericProviderOptions('cherryin', assistant, model, capabilities) } - return {} } /** @@ -458,7 +540,7 @@ function buildBedrockProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): BedrockProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: BedrockProviderOptions = {} @@ -475,13 +557,35 @@ function buildBedrockProviderOptions( providerOptions.anthropicBeta = betaHeaders } - return providerOptions + return { + bedrock: providerOptions + } +} + +function buildOllamaProviderOptions( + assistant: Assistant, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + } +): Record { + const { enableReasoning } = capabilities + const providerOptions: OllamaCompletionProviderOptions = {} + const reasoningEffort = assistant.settings?.reasoning_effort + if (enableReasoning) { + providerOptions.think = !['none', undefined].includes(reasoningEffort) + } + return { + ollama: providerOptions + } } /** * 构建通用的 providerOptions(用于其他 provider) */ function buildGenericProviderOptions( + providerId: string, assistant: Assistant, model: Model, capabilities: { @@ -524,5 +628,37 @@ function buildGenericProviderOptions( } } - return providerOptions + return { + [providerId]: providerOptions + } +} + +function buildAIGatewayOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + }, + serviceTier: OpenAIServiceTier, + textVerbosity?: OpenAIVerbosity +): Record< + string, + | OpenAIResponsesProviderOptions + | AnthropicProviderOptions + | GoogleGenerativeAIProviderOptions + | Record +> { + if (isAnthropicModel(model)) { + return buildAnthropicProviderOptions(assistant, model, capabilities) + } else if (isOpenAIModel(model)) { + return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity) + } else if (isGeminiModel(model)) { + return buildGeminiProviderOptions(assistant, model, capabilities) + } else if (isGrokModel(model)) { + return buildXAIProviderOptions(assistant, model, capabilities) + } else { + return buildGenericProviderOptions('openai-compatible', assistant, model, capabilities) + } } diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index f320a9f5d9..46350b085f 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -250,9 +250,25 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin enable_thinking: true, incremental_output: true } + // TODO: 支持 new-api类型 + case SystemProviderIds['new-api']: + case SystemProviderIds.cherryin: { + return { + extra_body: { + thinking: { + type: 'enabled' // auto is invalid + } + } + } + } case SystemProviderIds.hunyuan: case SystemProviderIds['tencent-cloud-ti']: case SystemProviderIds.doubao: + case SystemProviderIds.deepseek: + case SystemProviderIds.aihubmix: + case SystemProviderIds.sophnet: + case SystemProviderIds.ppio: + case SystemProviderIds.dmxapi: return { thinking: { type: 'enabled' // auto is invalid @@ -274,8 +290,6 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin logger.warn( `Skipping thinking options for provider ${provider.name} as DeepSeek v3.1 thinking control method is unknown` ) - case SystemProviderIds.silicon: - // specially handled before } } } diff --git a/src/renderer/src/components/CodeBlockView/view.tsx b/src/renderer/src/components/CodeBlockView/view.tsx index 2ba94d0ef2..8b251b9603 100644 --- a/src/renderer/src/components/CodeBlockView/view.tsx +++ b/src/renderer/src/components/CodeBlockView/view.tsx @@ -264,9 +264,10 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave expanded={shouldExpand} wrapped={shouldWrap} maxHeight={`${MAX_COLLAPSED_CODE_HEIGHT}px`} + onRequestExpand={codeCollapsible ? () => setExpandOverride(true) : undefined} /> ), - [children, codeEditor.enabled, handleHeightChange, language, onSave, shouldExpand, shouldWrap] + [children, codeCollapsible, codeEditor.enabled, handleHeightChange, language, onSave, shouldExpand, shouldWrap] ) // 特殊视图组件映射 diff --git a/src/renderer/src/components/CodeToolbar/__tests__/__snapshots__/CodeToolbar.test.tsx.snap b/src/renderer/src/components/CodeToolbar/__tests__/__snapshots__/CodeToolbar.test.tsx.snap index c2b4028e32..56fb14ccc4 100644 --- a/src/renderer/src/components/CodeToolbar/__tests__/__snapshots__/CodeToolbar.test.tsx.snap +++ b/src/renderer/src/components/CodeToolbar/__tests__/__snapshots__/CodeToolbar.test.tsx.snap @@ -64,7 +64,11 @@ exports[`CodeToolbar > basic rendering > should match snapshot with mixed tools data-title="code_block.more" >