diff --git a/package.json b/package.json
index 3f95aee6d5..fd5eb0151d 100644
--- a/package.json
+++ b/package.json
@@ -318,6 +318,7 @@
     "motion": "^12.10.5",
     "notion-helper": "^1.3.22",
     "npx-scope-finder": "^1.2.0",
+    "ollama-ai-provider-v2": "^1.5.5",
     "oxlint": "^1.22.0",
     "oxlint-tsgolint": "^0.2.0",
     "p-queue": "^8.1.0",
diff --git a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
index 8a780d5618..e9f459fd6c 100644
--- a/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
+++ b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts
@@ -19,19 +19,9 @@ export default class EmbeddingsFactory {
       })
     }
     if (provider === 'ollama') {
-      if (baseURL.includes('v1/')) {
-        return new OllamaEmbeddings({
-          model: model,
-          baseUrl: baseURL.replace('v1/', ''),
-          requestOptions: {
-            // @ts-ignore expected
-            'encoding-format': 'float'
-          }
-        })
-      }
       return new OllamaEmbeddings({
         model: model,
-        baseUrl: baseURL,
+        baseUrl: baseURL.replace(/\/api$/, ''),
         requestOptions: {
           // @ts-ignore expected
           'encoding-format': 'float'
diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts
index 9a8d5f8383..dc97e74a3c 100644
--- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts
@@ -11,7 +11,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
 import { getAssistantSettings } from '@renderer/services/AssistantService'
 import store from '@renderer/store'
 import type { SettingsState } from '@renderer/store/settings'
-import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
+import { type Assistant, type GenerateImageParams, type Model, type Provider } from '@renderer/types'
 import type {
   OpenAIResponseSdkMessageParam,
   OpenAIResponseSdkParams,
@@ -25,7 +25,8 @@
   OpenAISdkRawOutput,
   ReasoningEffortOptionalParams
 } from '@renderer/types/sdk'
-import { formatApiHost } from '@renderer/utils/api'
+import { formatApiHost, withoutTrailingSlash } from '@renderer/utils/api'
+import { isOllamaProvider } from '@renderer/utils/provider'

 import { BaseApiClient } from '../BaseApiClient'

@@ -115,6 +116,34 @@ export abstract class OpenAIBaseClient<
         }))
         .filter(isSupportedModel)
     }
+
+    if (isOllamaProvider(this.provider)) {
+      const baseUrl = withoutTrailingSlash(this.getBaseURL(false))
+        .replace(/\/v1$/, '')
+        .replace(/\/api$/, '')
+      const response = await fetch(`${baseUrl}/api/tags`, {
+        headers: {
+          Authorization: `Bearer ${this.apiKey}`,
+          ...this.defaultHeaders(),
+          ...this.provider.extra_headers
+        }
+      })
+
+      if (!response.ok) {
+        throw new Error(`Ollama server returned ${response.status} ${response.statusText}`)
+      }
+
+      const data = await response.json()
+      if (!data?.models || !Array.isArray(data.models)) {
+        throw new Error('Invalid response from Ollama API: missing models array')
+      }
+
+      return data.models.map((model) => ({
+        id: model.name,
+        object: 'model',
+        owned_by: 'ollama'
+      }))
+    }
     const response = await sdk.models.list()
     if (this.provider.id === 'together') {
       // @ts-ignore key is not typed
diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
index b314ddd737..10a4d59384 100644
--- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
@@ -4,7 +4,7 @@ import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/con
 import type { MCPTool } from '@renderer/types'
 import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
-import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
+import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
 import type { LanguageModelMiddleware } from 'ai'
 import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
 import { isEmpty } from 'lodash'
@@ -240,6 +240,7 @@ function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: Ai
   // Use /think or /no_think suffix to control thinking mode
   if (
     config.provider &&
+    !isOllamaProvider(config.provider) &&
     isSupportedThinkingTokenQwenModel(config.model) &&
     !isSupportEnableThinkingProvider(config.provider)
   ) {
diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts
index 528cc8f660..a5a84fccae 100644
--- a/src/renderer/src/aiCore/provider/providerConfig.ts
+++ b/src/renderer/src/aiCore/provider/providerConfig.ts
@@ -11,17 +11,24 @@ import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useV
 import { getProviderByModel } from '@renderer/services/AssistantService'
 import store from '@renderer/store'
 import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
-import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
+import {
+  formatApiHost,
+  formatAzureOpenAIApiHost,
+  formatOllamaApiHost,
+  formatVertexApiHost,
+  routeToEndpoint
+} from '@renderer/utils/api'
 import {
   isAnthropicProvider,
   isAzureOpenAIProvider,
   isCherryAIProvider,
   isGeminiProvider,
   isNewApiProvider,
+  isOllamaProvider,
   isPerplexityProvider,
   isVertexProvider
 } from '@renderer/utils/provider'
-import { cloneDeep } from 'lodash'
+import { cloneDeep, isEmpty } from 'lodash'

 import type { AiSdkConfig } from '../types'
 import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
@@ -99,6 +106,8 @@ export function formatProviderApiHost(provider: Provider): Provider {
     }
   } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
     formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else if (isOllamaProvider(formatted)) {
+    formatted.apiHost = formatOllamaApiHost(formatted.apiHost)
   } else if (isGeminiProvider(formatted)) {
     formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
   } else if (isAzureOpenAIProvider(formatted)) {
@@ -183,6 +192,19 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
     }
   }

+  if (isOllamaProvider(actualProvider)) {
+    return {
+      providerId: 'ollama',
+      options: {
+        ...baseConfig,
+        headers: {
+          ...actualProvider.extra_headers,
+          Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined
+        }
+      }
+    }
+  }
+
   // Handle OpenAI mode
   const extraOptions: any = {}
   extraOptions.endpoint = endpoint
diff --git a/src/renderer/src/aiCore/provider/providerInitialization.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts
index 2e4b9fced2..a42e2ac659 100644
--- a/src/renderer/src/aiCore/provider/providerInitialization.ts
+++ b/src/renderer/src/aiCore/provider/providerInitialization.ts
@@ -94,6 +94,13 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
     import: () => import('@ai-sdk/cerebras'),
     creatorFunctionName: 'createCerebras',
     supportsImageGeneration: false
+  },
+  {
+    id: 'ollama',
+    name: 'Ollama',
+    import: () => import('ollama-ai-provider-v2'),
+    creatorFunctionName: 'createOllama',
+    supportsImageGeneration: false
   }
 ] as const

diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts
index a1352a801a..e39837dc27 100644
--- a/src/renderer/src/aiCore/utils/options.ts
+++ b/src/renderer/src/aiCore/utils/options.ts
@@ -29,12 +29,14 @@ import {
   type OpenAIServiceTier,
   OpenAIServiceTiers,
   type Provider,
-  type ServiceTier
+  type ServiceTier,
+  SystemProviderIds
 } from '@renderer/types'
 import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
 import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider'
 import type { JSONValue } from 'ai'
 import { t } from 'i18next'
+import type { OllamaCompletionProviderOptions } from 'ollama-ai-provider-v2'

 import { addAnthropicHeaders } from '../prepareParams/header'
 import { getAiSdkProviderId } from '../provider/factory'
@@ -236,6 +238,9 @@ export function buildProviderOptions(
     case 'huggingface':
       providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
       break
+    case SystemProviderIds.ollama:
+      providerSpecificOptions = buildOllamaProviderOptions(assistant, capabilities)
+      break
     default:
       // For other providers, fall back to the generic build logic
       providerSpecificOptions = {
@@ -478,6 +483,23 @@ function buildBedrockProviderOptions(
   return providerOptions
 }

+function buildOllamaProviderOptions(
+  assistant: Assistant,
+  capabilities: {
+    enableReasoning: boolean
+    enableWebSearch: boolean
+    enableGenerateImage: boolean
+  }
+): OllamaCompletionProviderOptions {
+  const { enableReasoning } = capabilities
+  const providerOptions: OllamaCompletionProviderOptions = {}
+  const reasoningEffort = assistant.settings?.reasoning_effort
+  if (enableReasoning) {
+    providerOptions.think = !['none', undefined].includes(reasoningEffort)
+  }
+  return providerOptions
+}
+
 /**
  * Build the generic providerOptions (for other providers)
  */
diff --git a/src/renderer/src/pages/settings/ProviderSettings/AddProviderPopup.tsx b/src/renderer/src/pages/settings/ProviderSettings/AddProviderPopup.tsx
index e4923de1ba..b6d7145c91 100644
--- a/src/renderer/src/pages/settings/ProviderSettings/AddProviderPopup.tsx
+++ b/src/renderer/src/pages/settings/ProviderSettings/AddProviderPopup.tsx
@@ -259,7 +259,8 @@ const PopupContainer: React.FC = ({ provider, resolve }) => {
             { label: 'Anthropic', value: 'anthropic' },
             { label: 'Azure OpenAI', value: 'azure-openai' },
             { label: 'New API', value: 'new-api' },
-            { label: 'CherryIN', value: 'cherryin-type' }
+            { label: 'CherryIN', value: 'cherryin-type' },
+            { label: 'Ollama', value: 'ollama' }
           ]}
         />

diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx
index da05409683..c3b8904cef 100644
--- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx
+++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx
@@ -29,6 +29,7 @@ import {
   isAzureOpenAIProvider,
   isGeminiProvider,
   isNewApiProvider,
+  isOllamaProvider,
   isOpenAICompatibleProvider,
   isOpenAIProvider,
   isVertexProvider
@@ -277,6 +278,10 @@ const ProviderSetting: FC = ({ providerId }) => {
   const hostPreview = () => {
     const formattedApiHost = adaptProvider({ provider: { ...provider, apiHost } }).apiHost

+    if (isOllamaProvider(provider)) {
+      return formattedApiHost + '/chat'
+    }
+
     if (isOpenAICompatibleProvider(provider)) {
       return formattedApiHost + '/chat/completions'
     }
diff --git a/src/renderer/src/services/KnowledgeService.ts b/src/renderer/src/services/KnowledgeService.ts
index ef35027ff5..e2f2e6fc15 100644
--- a/src/renderer/src/services/KnowledgeService.ts
+++ b/src/renderer/src/services/KnowledgeService.ts
@@ -6,12 +6,13 @@ import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@
 import { getEmbeddingMaxContext } from '@renderer/config/embedings'
 import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
 import store from '@renderer/store'
-import type {
-  FileMetadata,
-  KnowledgeBase,
-  KnowledgeBaseParams,
-  KnowledgeReference,
-  KnowledgeSearchResult
+import {
+  type FileMetadata,
+  type KnowledgeBase,
+  type KnowledgeBaseParams,
+  type KnowledgeReference,
+  type KnowledgeSearchResult,
+  SystemProviderIds
 } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
 import { ChunkType } from '@renderer/types/chunk'
@@ -50,6 +51,9 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
     baseURL = baseURL + '/openai'
   } else if (isAzureOpenAIProvider(actualProvider)) {
     baseURL = baseURL + '/v1'
+  } else if (actualProvider.id === SystemProviderIds.ollama) {
+    // The LangChain ecosystem does not need the /api-suffixed URL
+    baseURL = baseURL.replace(/\/api$/, '')
   }

   logger.info(`Knowledge base ${base.name} using baseURL: ${baseURL}`)
diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts
index 5c562885bb..94b51474b9 100644
--- a/src/renderer/src/store/index.ts
+++ b/src/renderer/src/store/index.ts
@@ -67,7 +67,7 @@ const persistedReducer = persistReducer(
   {
     key: 'cherry-studio',
     storage,
-    version: 179,
+    version: 180,
     blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs', 'toolPermissions'],
     migrate
   },
diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts
index b2993baf5d..6f05c2b348 100644
--- a/src/renderer/src/store/migrate.ts
+++ b/src/renderer/src/store/migrate.ts
@@ -2917,6 +2917,11 @@ const migrateConfig = {
     if (state.settings.openAI.verbosity === 'undefined') {
       state.settings.openAI.verbosity = undefined
     }
+    state.llm.providers.forEach((provider) => {
+      if (provider.id === SystemProviderIds.ollama) {
+        provider.type = 'ollama'
+      }
+    })
     logger.info('migrate 180 success')
     return state
   } catch (error) {
diff --git a/src/renderer/src/types/provider.ts b/src/renderer/src/types/provider.ts
index 9d948f16d0..aea72fa287 100644
--- a/src/renderer/src/types/provider.ts
+++ b/src/renderer/src/types/provider.ts
@@ -15,7 +15,8 @@ export const ProviderTypeSchema = z.enum([
   'aws-bedrock',
   'vertex-anthropic',
   'new-api',
-  'ai-gateway'
+  'ai-gateway',
+  'ollama'
 ])

 export type ProviderType = z.infer<typeof ProviderTypeSchema>
diff --git a/src/renderer/src/utils/__tests__/api.test.ts b/src/renderer/src/utils/__tests__/api.test.ts
index c00c2e0f60..fe34dcf26e 100644
--- a/src/renderer/src/utils/__tests__/api.test.ts
+++ b/src/renderer/src/utils/__tests__/api.test.ts
@@ -6,6 +6,7 @@ import {
   formatApiHost,
   formatApiKeys,
   formatAzureOpenAIApiHost,
+  formatOllamaApiHost,
   formatVertexApiHost,
   getTrailingApiVersion,
   hasAPIVersion,
@@ -341,6 +342,73 @@ describe('api', () => {
     })
   })

+  describe('formatOllamaApiHost', () => {
+    it('removes trailing slash and appends /api for basic hosts', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com/')).toBe('https://api.ollama.com/api')
+      expect(formatOllamaApiHost('http://localhost:11434/')).toBe('http://localhost:11434/api')
+    })
+
+    it('appends /api when no suffix is present', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com')).toBe('https://api.ollama.com/api')
+      expect(formatOllamaApiHost('http://localhost:11434')).toBe('http://localhost:11434/api')
+    })
+
+    it('removes /v1 suffix and appends /api', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com/v1')).toBe('https://api.ollama.com/api')
+      expect(formatOllamaApiHost('http://localhost:11434/v1/')).toBe('http://localhost:11434/api')
+    })
+
+    it('removes /api suffix and keeps /api', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com/api')).toBe('https://api.ollama.com/api')
+      expect(formatOllamaApiHost('http://localhost:11434/api/')).toBe('http://localhost:11434/api')
+    })
+
+    it('removes /chat suffix and appends /api', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com/chat')).toBe('https://api.ollama.com/api')
+      expect(formatOllamaApiHost('http://localhost:11434/chat/')).toBe('http://localhost:11434/api')
+    })
+
+    it('handles multiple suffix combinations correctly', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com/v1/chat')).toBe('https://api.ollama.com/v1/api')
+      expect(formatOllamaApiHost('https://api.ollama.com/chat/v1')).toBe('https://api.ollama.com/api')
+      expect(formatOllamaApiHost('https://api.ollama.com/api/chat')).toBe('https://api.ollama.com/api/api')
+    })
+
+    it('preserves complex paths while handling suffixes', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com/custom/path')).toBe('https://api.ollama.com/custom/path/api')
+      expect(formatOllamaApiHost('https://api.ollama.com/custom/path/')).toBe('https://api.ollama.com/custom/path/api')
+      expect(formatOllamaApiHost('https://api.ollama.com/custom/path/v1')).toBe(
+        'https://api.ollama.com/custom/path/api'
+      )
+    })
+
+    it('handles edge cases with multiple slashes', () => {
+      expect(formatOllamaApiHost('https://api.ollama.com//')).toBe('https://api.ollama.com//api')
+      expect(formatOllamaApiHost('https://api.ollama.com///v1///')).toBe('https://api.ollama.com///v1///api')
+    })
+
+    it('handles localhost with different ports', () => {
+      expect(formatOllamaApiHost('http://localhost:3000')).toBe('http://localhost:3000/api')
+      expect(formatOllamaApiHost('http://127.0.0.1:11434/')).toBe('http://127.0.0.1:11434/api')
+      expect(formatOllamaApiHost('https://localhost:8080/v1')).toBe('https://localhost:8080/api')
+    })
+
+    it('handles IP addresses', () => {
+      expect(formatOllamaApiHost('http://192.168.1.100:11434')).toBe('http://192.168.1.100:11434/api')
+      expect(formatOllamaApiHost('https://10.0.0.1:8080/v1/')).toBe('https://10.0.0.1:8080/api')
+    })
+
+    it('handles empty strings and edge cases', () => {
+      expect(formatOllamaApiHost('')).toBe('/api')
+      expect(formatOllamaApiHost('/')).toBe('/api')
+    })
+
+    it('preserves protocol and handles mixed case', () => {
+      expect(formatOllamaApiHost('HTTPS://API.OLLAMA.COM')).toBe('HTTPS://API.OLLAMA.COM/api')
+      expect(formatOllamaApiHost('HTTP://localhost:11434/V1/')).toBe('HTTP://localhost:11434/V1/api')
+    })
+  })
+
   describe('getTrailingApiVersion', () => {
     it('extracts trailing API version from URL', () => {
       expect(getTrailingApiVersion('https://api.example.com/v1')).toBe('v1')
diff --git a/src/renderer/src/utils/api.ts b/src/renderer/src/utils/api.ts
index efadb7813c..10f31ae5c0 100644
--- a/src/renderer/src/utils/api.ts
+++ b/src/renderer/src/utils/api.ts
@@ -110,6 +110,17 @@ export function formatApiHost(host?: string, supportApiVersion: boolean = true,
   }
 }

+/**
+ * Format the Ollama API host address.
+ */
+export function formatOllamaApiHost(host: string): string {
+  const normalizedHost = withoutTrailingSlash(host)
+    ?.replace(/\/v1$/, '')
+    ?.replace(/\/api$/, '')
+    ?.replace(/\/chat$/, '')
+  return formatApiHost(normalizedHost + '/api', false)
+}
+
 /**
  * Format the Azure OpenAI API host address.
  */
diff --git a/src/renderer/src/utils/provider.ts b/src/renderer/src/utils/provider.ts
index fae0aababa..0af511b97e 100644
--- a/src/renderer/src/utils/provider.ts
+++ b/src/renderer/src/utils/provider.ts
@@ -175,6 +175,10 @@ export function isAIGatewayProvider(provider: Provider): boolean {
   return provider.type === 'ai-gateway'
 }

+export function isOllamaProvider(provider: Provider): boolean {
+  return provider.type === 'ollama'
+}
+
 const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]

 export const isSupportAPIVersionProvider = (provider: Provider) => {
diff --git a/yarn.lock b/yarn.lock
index 22b6c581db..c832447198 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -10232,6 +10232,7 @@ __metadata:
     notion-helper: "npm:^1.3.22"
     npx-scope-finder: "npm:^1.2.0"
     officeparser: "npm:^4.2.0"
+    ollama-ai-provider-v2: "npm:^1.5.5"
     os-proxy-config: "npm:^1.1.2"
     oxlint: "npm:^1.22.0"
     oxlint-tsgolint: "npm:^0.2.0"
@@ -19934,6 +19935,18 @@
   languageName: node
   linkType: hard

+"ollama-ai-provider-v2@npm:^1.5.5":
+  version: 1.5.5
+  resolution: "ollama-ai-provider-v2@npm:1.5.5"
+  dependencies:
+    "@ai-sdk/provider": "npm:^2.0.0"
+    "@ai-sdk/provider-utils": "npm:^3.0.17"
+  peerDependencies:
+    zod: ^4.0.16
+  checksum: 10c0/da40c8097bd8205c46eccfbd13e77c51a6ce97a29b886adfc9e1b8444460b558138d1ed4428491fcc9378d46f649dd0a9b1e5b13cf6bbc8f5385e8b321734e72
+  languageName: node
+  linkType: hard
+
 "ollama@npm:^0.5.12":
   version: 0.5.16
   resolution: "ollama@npm:0.5.16"
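
Reviewer notes — the three sketches below illustrate the patch but are not part of it.

Model listing: the `OpenAIBaseClient` change queries Ollama's native `GET /api/tags` endpoint instead of the OpenAI-compatible `/v1/models` route, after normalizing the base URL. A minimal standalone sketch of that flow, assuming a local Ollama server on its default port; the `listOllamaModels` helper name and the response interface are illustrative, while the endpoint and the `models[].name` field come from the patch and Ollama's API:

```ts
// Standalone sketch of the model-listing flow added in OpenAIBaseClient.
// Assumes a local Ollama server on the default port; the helper name and
// the minimal response type are illustrative.
interface OllamaTagsResponse {
  models: Array<{ name: string }>
}

async function listOllamaModels(baseUrl = 'http://localhost:11434'): Promise<string[]> {
  // Same endpoint the client change targets: GET {host}/api/tags
  const response = await fetch(`${baseUrl}/api/tags`)
  if (!response.ok) {
    throw new Error(`Ollama server returned ${response.status} ${response.statusText}`)
  }
  const data = (await response.json()) as OllamaTagsResponse
  if (!Array.isArray(data?.models)) {
    throw new Error('Invalid response from Ollama API: missing models array')
  }
  return data.models.map((model) => model.name)
}

listOllamaModels().then((names) => console.log(names))
```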
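
Host normalization: `formatOllamaApiHost` strips one trailing slash, removes `/v1`, `/api`, and `/chat` suffixes in that fixed order, then re-appends `/api` — which is why `/chat/v1` collapses to `/api` while `/v1/chat` becomes `/v1/api` in the tests above. A self-contained sketch of the same logic, assuming the final `formatApiHost(..., false)` call returns an already protocol-qualified host unchanged; `normalizeOllamaHost` is an illustrative name:

```ts
// Self-contained sketch of the normalization inside formatOllamaApiHost; the
// real function delegates its last step to formatApiHost(..., false), assumed
// here to pass a protocol-qualified host through unchanged.
function normalizeOllamaHost(host: string): string {
  const withoutSlash = host.endsWith('/') ? host.slice(0, -1) : host // one trailing slash only
  const normalized = withoutSlash
    .replace(/\/v1$/, '') // applied in sequence, so '/chat/v1' collapses fully...
    .replace(/\/api$/, '')
    .replace(/\/chat$/, '')
  return normalized + '/api' // ...while '/v1/chat' keeps its '/v1' prefix
}

console.log(normalizeOllamaHost('http://localhost:11434')) // http://localhost:11434/api
console.log(normalizeOllamaHost('http://localhost:11434/v1/')) // http://localhost:11434/api
console.log(normalizeOllamaHost('https://api.ollama.com/v1/chat')) // https://api.ollama.com/v1/api
```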
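
Reasoning flag: `buildOllamaProviderOptions` only sets ollama-ai-provider-v2's `think` option when reasoning is enabled, leaving the model's default behavior untouched otherwise. A small sketch of that mapping, with the `ReasoningEffort` union and `toThinkFlag` helper simplified for illustration:

```ts
// Sketch of the reasoning_effort -> think mapping in buildOllamaProviderOptions;
// the effort union below is a simplified illustration, not the app's real type.
type ReasoningEffort = 'none' | 'low' | 'medium' | 'high' | undefined

function toThinkFlag(enableReasoning: boolean, reasoningEffort: ReasoningEffort): { think?: boolean } {
  if (!enableReasoning) return {} // leave the model's default thinking behavior alone
  // Any concrete effort level enables thinking; 'none' or unset disables it explicitly.
  return { think: !['none', undefined].includes(reasoningEffort) }
}

console.log(toThinkFlag(true, 'high')) // { think: true }
console.log(toThinkFlag(true, 'none')) // { think: false }
console.log(toThinkFlag(false, 'medium')) // {}
```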