diff --git a/.oxlintrc.json b/.oxlintrc.json index 329d08c04..7d18f83c7 100644 --- a/.oxlintrc.json +++ b/.oxlintrc.json @@ -51,6 +51,12 @@ "node": true }, "files": ["src/preload/**"] + }, + { + "files": ["packages/ai-sdk-provider/**"], + "globals": { + "fetch": "readonly" + } } ], "plugins": ["unicorn", "typescript", "oxc", "import"], diff --git a/electron.vite.config.ts b/electron.vite.config.ts index b4914539c..172d48ca9 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -95,7 +95,8 @@ export default defineConfig({ '@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'), '@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'), '@cherrystudio/ai-core': resolve('packages/aiCore/src'), - '@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src') + '@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src'), + '@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src') } }, optimizeDeps: { diff --git a/packages/ai-sdk-provider/README.md b/packages/ai-sdk-provider/README.md new file mode 100644 index 000000000..ecd9df292 --- /dev/null +++ b/packages/ai-sdk-provider/README.md @@ -0,0 +1,39 @@ +# @cherrystudio/ai-sdk-provider + +CherryIN provider bundle for the [Vercel AI SDK](https://ai-sdk.dev/). +It exposes the CherryIN OpenAI-compatible entrypoints and dynamically routes Anthropic and Gemini model ids to their CherryIN upstream equivalents. + +## Installation + +```bash +npm install ai @cherrystudio/ai-sdk-provider @ai-sdk/anthropic @ai-sdk/google @ai-sdk/openai +# or +yarn add ai @cherrystudio/ai-sdk-provider @ai-sdk/anthropic @ai-sdk/google @ai-sdk/openai +``` + +> **Note**: This package requires peer dependencies `ai`, `@ai-sdk/anthropic`, `@ai-sdk/google`, and `@ai-sdk/openai` to be installed. + +## Usage + +```ts +import { createCherryIn, cherryIn } from '@cherrystudio/ai-sdk-provider' + +const cherryInProvider = createCherryIn({ + apiKey: process.env.CHERRYIN_API_KEY, + // optional overrides: + // baseURL: 'https://open.cherryin.net/v1', + // anthropicBaseURL: 'https://open.cherryin.net/anthropic', + // geminiBaseURL: 'https://open.cherryin.net/gemini/v1beta', +}) + +// Chat models will auto-route based on the model id prefix: +const openaiModel = cherryInProvider.chat('gpt-4o-mini') +const anthropicModel = cherryInProvider.chat('claude-3-5-sonnet-latest') +const geminiModel = cherryInProvider.chat('gemini-2.0-pro-exp') + +const { text } = await openaiModel.invoke('Hello CherryIN!') +``` + +The provider also exposes `completion`, `responses`, `embedding`, `image`, `transcription`, and `speech` helpers aligned with the upstream APIs. + +See [AI SDK docs](https://ai-sdk.dev/providers/community-providers/custom-providers) for configuring custom providers. diff --git a/packages/ai-sdk-provider/package.json b/packages/ai-sdk-provider/package.json new file mode 100644 index 000000000..fd0aac264 --- /dev/null +++ b/packages/ai-sdk-provider/package.json @@ -0,0 +1,64 @@ +{ + "name": "@cherrystudio/ai-sdk-provider", + "version": "0.1.0", + "description": "Cherry Studio AI SDK provider bundle with CherryIN routing.", + "keywords": [ + "ai-sdk", + "provider", + "cherryin", + "vercel-ai-sdk", + "cherry-studio" + ], + "author": "Cherry Studio", + "license": "MIT", + "homepage": "https://github.com/CherryHQ/cherry-studio", + "repository": { + "type": "git", + "url": "git+https://github.com/CherryHQ/cherry-studio.git", + "directory": "packages/ai-sdk-provider" + }, + "bugs": { + "url": "https://github.com/CherryHQ/cherry-studio/issues" + }, + "type": "module", + "main": "dist/index.cjs", + "module": "dist/index.js", + "types": "dist/index.d.ts", + "files": [ + "dist" + ], + "scripts": { + "build": "tsdown", + "dev": "tsc -w", + "clean": "rm -rf dist", + "test": "vitest run", + "test:watch": "vitest" + }, + "peerDependencies": { + "@ai-sdk/anthropic": "^2.0.29", + "@ai-sdk/google": "^2.0.23", + "@ai-sdk/openai": "^2.0.64", + "ai": "^5.0.26" + }, + "dependencies": { + "@ai-sdk/provider": "^2.0.0", + "@ai-sdk/provider-utils": "^3.0.12" + }, + "devDependencies": { + "tsdown": "^0.13.3", + "typescript": "^5.8.2", + "vitest": "^3.2.4" + }, + "sideEffects": false, + "engines": { + "node": ">=18.0.0" + }, + "exports": { + ".": { + "types": "./dist/index.d.ts", + "import": "./dist/index.js", + "require": "./dist/index.cjs", + "default": "./dist/index.js" + } + } +} diff --git a/packages/ai-sdk-provider/src/cherryin-provider.ts b/packages/ai-sdk-provider/src/cherryin-provider.ts new file mode 100644 index 000000000..478380a41 --- /dev/null +++ b/packages/ai-sdk-provider/src/cherryin-provider.ts @@ -0,0 +1,319 @@ +import { AnthropicMessagesLanguageModel } from '@ai-sdk/anthropic/internal' +import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal' +import type { OpenAIProviderSettings } from '@ai-sdk/openai' +import { + OpenAIChatLanguageModel, + OpenAICompletionLanguageModel, + OpenAIEmbeddingModel, + OpenAIImageModel, + OpenAIResponsesLanguageModel, + OpenAISpeechModel, + OpenAITranscriptionModel +} from '@ai-sdk/openai/internal' +import { + type EmbeddingModelV2, + type ImageModelV2, + type LanguageModelV2, + type ProviderV2, + type SpeechModelV2, + type TranscriptionModelV2 +} from '@ai-sdk/provider' +import type { FetchFunction } from '@ai-sdk/provider-utils' +import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils' + +export const CHERRYIN_PROVIDER_NAME = 'cherryin' as const +export const DEFAULT_CHERRYIN_BASE_URL = 'https://open.cherryin.net/v1' +export const DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL = 'https://open.cherryin.net/v1' +export const DEFAULT_CHERRYIN_GEMINI_BASE_URL = 'https://open.cherryin.net/v1beta/models' + +const ANTHROPIC_PREFIX = /^anthropic\//i +const GEMINI_PREFIX = /^google\//i +// const GEMINI_EXCLUDED_SUFFIXES = ['-nothink', '-search'] + +type HeaderValue = string | undefined + +type HeadersInput = Record | (() => Record) + +export interface CherryInProviderSettings { + /** + * CherryIN API key. + * + * If omitted, the provider will read the `CHERRYIN_API_KEY` environment variable. + */ + apiKey?: string + /** + * Optional custom fetch implementation. + */ + fetch?: FetchFunction + /** + * Base URL for OpenAI-compatible CherryIN endpoints. + * + * Defaults to `https://open.cherryin.net/v1`. + */ + baseURL?: string + /** + * Base URL for Anthropic-compatible endpoints. + * + * Defaults to `https://open.cherryin.net/anthropic`. + */ + anthropicBaseURL?: string + /** + * Base URL for Gemini-compatible endpoints. + * + * Defaults to `https://open.cherryin.net/gemini/v1beta`. + */ + geminiBaseURL?: string + /** + * Optional static headers applied to every request. + */ + headers?: HeadersInput +} + +export interface CherryInProvider extends ProviderV2 { + (modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2 + languageModel(modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2 + chat(modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2 + responses(modelId: string): LanguageModelV2 + completion(modelId: string, settings?: OpenAIProviderSettings): LanguageModelV2 + embedding(modelId: string, settings?: OpenAIProviderSettings): EmbeddingModelV2 + textEmbedding(modelId: string, settings?: OpenAIProviderSettings): EmbeddingModelV2 + textEmbeddingModel(modelId: string, settings?: OpenAIProviderSettings): EmbeddingModelV2 + image(modelId: string, settings?: OpenAIProviderSettings): ImageModelV2 + imageModel(modelId: string, settings?: OpenAIProviderSettings): ImageModelV2 + transcription(modelId: string): TranscriptionModelV2 + transcriptionModel(modelId: string): TranscriptionModelV2 + speech(modelId: string): SpeechModelV2 + speechModel(modelId: string): SpeechModelV2 +} + +const resolveApiKey = (options: CherryInProviderSettings): string => + loadApiKey({ + apiKey: options.apiKey, + environmentVariableName: 'CHERRYIN_API_KEY', + description: 'CherryIN' + }) + +const isAnthropicModel = (modelId: string) => ANTHROPIC_PREFIX.test(modelId) +const isGeminiModel = (modelId: string) => GEMINI_PREFIX.test(modelId) + +const createCustomFetch = (originalFetch?: any) => { + return async (url: string, options: any) => { + if (options?.body) { + try { + const body = JSON.parse(options.body) + if (body.tools && Array.isArray(body.tools) && body.tools.length === 0 && body.tool_choice) { + delete body.tool_choice + options.body = JSON.stringify(body) + } + } catch (error) { + // ignore error + } + } + + return originalFetch ? originalFetch(url, options) : fetch(url, options) + } +} +class CherryInOpenAIChatLanguageModel extends OpenAIChatLanguageModel { + constructor(modelId: string, settings: any) { + super(modelId, { + ...settings, + fetch: createCustomFetch(settings.fetch) + }) + } +} + +const resolveConfiguredHeaders = (headers?: HeadersInput): Record => { + if (typeof headers === 'function') { + return { ...headers() } + } + return headers ? { ...headers } : {} +} + +const toBearerToken = (authorization?: string) => (authorization ? authorization.replace(/^Bearer\s+/i, '') : undefined) + +const createJsonHeadersGetter = (options: CherryInProviderSettings): (() => Record) => { + return () => ({ + Authorization: `Bearer ${resolveApiKey(options)}`, + 'Content-Type': 'application/json', + ...resolveConfiguredHeaders(options.headers) + }) +} + +const createAuthHeadersGetter = (options: CherryInProviderSettings): (() => Record) => { + return () => ({ + Authorization: `Bearer ${resolveApiKey(options)}`, + ...resolveConfiguredHeaders(options.headers) + }) +} + +export const createCherryIn = (options: CherryInProviderSettings = {}): CherryInProvider => { + const { + baseURL = DEFAULT_CHERRYIN_BASE_URL, + anthropicBaseURL = DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL, + geminiBaseURL = DEFAULT_CHERRYIN_GEMINI_BASE_URL, + fetch + } = options + + const getJsonHeaders = createJsonHeadersGetter(options) + const getAuthHeaders = createAuthHeadersGetter(options) + + const url = ({ path }: { path: string; modelId: string }) => `${withoutTrailingSlash(baseURL)}${path}` + + const createAnthropicModel = (modelId: string) => + new AnthropicMessagesLanguageModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.anthropic`, + baseURL: anthropicBaseURL, + headers: () => { + const headers = getJsonHeaders() + const apiKey = toBearerToken(headers.Authorization) + return { + ...headers, + 'x-api-key': apiKey + } + }, + fetch, + supportedUrls: () => ({ + 'image/*': [/^https?:\/\/.*$/] + }) + }) + + const createGeminiModel = (modelId: string) => + new GoogleGenerativeAILanguageModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.google`, + baseURL: geminiBaseURL, + headers: () => { + const headers = getJsonHeaders() + const apiKey = toBearerToken(headers.Authorization) + return { + ...headers, + 'x-goog-api-key': apiKey + } + }, + fetch, + generateId: () => `${CHERRYIN_PROVIDER_NAME}-${Date.now()}`, + supportedUrls: () => ({}) + }) + + const createOpenAIChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => + new CherryInOpenAIChatLanguageModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.openai-chat`, + url, + headers: () => ({ + ...getJsonHeaders(), + ...settings.headers + }), + fetch + }) + + const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => { + if (isAnthropicModel(modelId)) { + return createAnthropicModel(modelId) + } + if (isGeminiModel(modelId)) { + return createGeminiModel(modelId) + } + return new OpenAIResponsesLanguageModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.openai`, + url, + headers: () => ({ + ...getJsonHeaders(), + ...settings.headers + }), + fetch + }) + } + + const createCompletionModel = (modelId: string, settings: OpenAIProviderSettings = {}) => + new OpenAICompletionLanguageModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.completion`, + url, + headers: () => ({ + ...getJsonHeaders(), + ...settings.headers + }), + fetch + }) + + const createEmbeddingModel = (modelId: string, settings: OpenAIProviderSettings = {}) => + new OpenAIEmbeddingModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.embeddings`, + url, + headers: () => ({ + ...getJsonHeaders(), + ...settings.headers + }), + fetch + }) + + const createResponsesModel = (modelId: string) => + new OpenAIResponsesLanguageModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.responses`, + url, + headers: () => ({ + ...getJsonHeaders() + }), + fetch + }) + + const createImageModel = (modelId: string, settings: OpenAIProviderSettings = {}) => + new OpenAIImageModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.image`, + url, + headers: () => ({ + ...getJsonHeaders(), + ...settings.headers + }), + fetch + }) + + const createTranscriptionModel = (modelId: string) => + new OpenAITranscriptionModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.transcription`, + url, + headers: () => ({ + ...getAuthHeaders() + }), + fetch + }) + + const createSpeechModel = (modelId: string) => + new OpenAISpeechModel(modelId, { + provider: `${CHERRYIN_PROVIDER_NAME}.speech`, + url, + headers: () => ({ + ...getJsonHeaders() + }), + fetch + }) + + const provider: CherryInProvider = function (modelId: string, settings?: OpenAIProviderSettings) { + if (new.target) { + throw new Error('CherryIN provider function cannot be called with the new keyword.') + } + + return createChatModel(modelId, settings) + } + + provider.languageModel = createChatModel + provider.chat = createOpenAIChatModel + + provider.responses = createResponsesModel + provider.completion = createCompletionModel + + provider.embedding = createEmbeddingModel + provider.textEmbedding = createEmbeddingModel + provider.textEmbeddingModel = createEmbeddingModel + + provider.image = createImageModel + provider.imageModel = createImageModel + + provider.transcription = createTranscriptionModel + provider.transcriptionModel = createTranscriptionModel + + provider.speech = createSpeechModel + provider.speechModel = createSpeechModel + + return provider +} + +export const cherryIn = createCherryIn() diff --git a/packages/ai-sdk-provider/src/index.ts b/packages/ai-sdk-provider/src/index.ts new file mode 100644 index 000000000..d397dd5af --- /dev/null +++ b/packages/ai-sdk-provider/src/index.ts @@ -0,0 +1 @@ +export * from './cherryin-provider' diff --git a/packages/ai-sdk-provider/tsconfig.json b/packages/ai-sdk-provider/tsconfig.json new file mode 100644 index 000000000..26ee731bb --- /dev/null +++ b/packages/ai-sdk-provider/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "allowSyntheticDefaultImports": true, + "declaration": true, + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "module": "ESNext", + "moduleResolution": "bundler", + "noEmitOnError": false, + "outDir": "./dist", + "resolveJsonModule": true, + "rootDir": "./src", + "skipLibCheck": true, + "strict": true, + "target": "ES2020" + }, + "exclude": ["node_modules", "dist"], + "include": ["src/**/*"] +} diff --git a/packages/ai-sdk-provider/tsdown.config.ts b/packages/ai-sdk-provider/tsdown.config.ts new file mode 100644 index 000000000..0e07d34ca --- /dev/null +++ b/packages/ai-sdk-provider/tsdown.config.ts @@ -0,0 +1,12 @@ +import { defineConfig } from 'tsdown' + +export default defineConfig({ + entry: { + index: 'src/index.ts' + }, + outDir: 'dist', + format: ['esm', 'cjs'], + clean: true, + dts: true, + tsconfig: 'tsconfig.json' +}) diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index c8b12c489..3973bd9af 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -44,6 +44,7 @@ "@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider-utils": "^3.0.16", "@ai-sdk/xai": "^2.0.31", + "@cherrystudio/ai-sdk-provider": "workspace:*", "zod": "^4.1.5" }, "devDependencies": { diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index 42bd17e09..a50356130 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -1,9 +1,10 @@ -import type { anthropic } from '@ai-sdk/anthropic' -import type { google } from '@ai-sdk/google' -import type { openai } from '@ai-sdk/openai' +import { anthropic } from '@ai-sdk/anthropic' +import { google } from '@ai-sdk/google' +import { openai } from '@ai-sdk/openai' import type { InferToolInput, InferToolOutput } from 'ai' import { type Tool } from 'ai' +import { createOpenRouterOptions, createXaiOptions, mergeProviderOptions } from '../../../options' import type { ProviderOptionsMap } from '../../../options/types' import type { OpenRouterSearchConfig } from './openrouter' @@ -95,3 +96,56 @@ export type WebSearchToolInputSchema = { google: InferToolInput 'openai-chat': InferToolInput } + +export const switchWebSearchTool = (providerId: string, config: WebSearchPluginConfig, params: any) => { + switch (providerId) { + case 'openai': { + if (config.openai) { + if (!params.tools) params.tools = {} + params.tools.web_search = openai.tools.webSearch(config.openai) + } + break + } + case 'openai-chat': { + if (config['openai-chat']) { + if (!params.tools) params.tools = {} + params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat']) + } + break + } + + case 'anthropic': { + if (config.anthropic) { + if (!params.tools) params.tools = {} + params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) + } + break + } + + case 'google': { + // case 'google-vertex': + if (!params.tools) params.tools = {} + params.tools.web_search = google.tools.googleSearch(config.google || {}) + break + } + + case 'xai': { + if (config.xai) { + const searchOptions = createXaiOptions({ + searchParameters: { ...config.xai, mode: 'on' } + }) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) + } + break + } + + case 'openrouter': { + if (config.openrouter) { + const searchOptions = createOpenRouterOptions(config.openrouter) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) + } + break + } + } + return params +} diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts index 34eba7963..23ea95232 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -2,15 +2,11 @@ * Web Search Plugin * 提供统一的网络搜索能力,支持多个 AI Provider */ -import { anthropic } from '@ai-sdk/anthropic' -import { google } from '@ai-sdk/google' -import { openai } from '@ai-sdk/openai' -import { createOpenRouterOptions, createXaiOptions, mergeProviderOptions } from '../../../options' import { definePlugin } from '../../' import type { AiRequestContext } from '../../types' import type { WebSearchPluginConfig } from './helper' -import { DEFAULT_WEB_SEARCH_CONFIG } from './helper' +import { DEFAULT_WEB_SEARCH_CONFIG, switchWebSearchTool } from './helper' /** * 网络搜索插件 @@ -24,56 +20,13 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR transformParams: async (params: any, context: AiRequestContext) => { const { providerId } = context - switch (providerId) { - case 'openai': { - if (config.openai) { - if (!params.tools) params.tools = {} - params.tools.web_search = openai.tools.webSearch(config.openai) - } - break - } - case 'openai-chat': { - if (config['openai-chat']) { - if (!params.tools) params.tools = {} - params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat']) - } - break - } + switchWebSearchTool(providerId, config, params) - case 'anthropic': { - if (config.anthropic) { - if (!params.tools) params.tools = {} - params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) - } - break - } - - case 'google': { - // case 'google-vertex': - if (!params.tools) params.tools = {} - params.tools.web_search = google.tools.googleSearch(config.google || {}) - break - } - - case 'xai': { - if (config.xai) { - const searchOptions = createXaiOptions({ - searchParameters: { ...config.xai, mode: 'on' } - }) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) - } - break - } - - case 'openrouter': { - if (config.openrouter) { - const searchOptions = createOpenRouterOptions(config.openrouter) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) - } - break - } + if (providerId === 'cherryin' || providerId === 'cherryin-chat') { + // cherryin.gemini + const _providerId = params.model.provider.split('.')[1] + switchWebSearchTool(_providerId, config, params) } - return params } }) diff --git a/packages/aiCore/src/core/providers/schemas.ts b/packages/aiCore/src/core/providers/schemas.ts index 7ca4f6b0c..778b1b705 100644 --- a/packages/aiCore/src/core/providers/schemas.ts +++ b/packages/aiCore/src/core/providers/schemas.ts @@ -12,6 +12,7 @@ import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai' import { createOpenAICompatible } from '@ai-sdk/openai-compatible' import type { LanguageModelV2 } from '@ai-sdk/provider' import { createXai } from '@ai-sdk/xai' +import { type CherryInProviderSettings, createCherryIn } from '@cherrystudio/ai-sdk-provider' import { createOpenRouter } from '@openrouter/ai-sdk-provider' import type { Provider } from 'ai' import { customProvider } from 'ai' @@ -31,6 +32,8 @@ export const baseProviderIds = [ 'azure-responses', 'deepseek', 'openrouter', + 'cherryin', + 'cherryin-chat', 'huggingface' ] as const @@ -136,6 +139,26 @@ export const baseProviders = [ creator: createOpenRouter, supportsImageGeneration: true }, + { + id: 'cherryin', + name: 'CherryIN', + creator: createCherryIn, + supportsImageGeneration: true + }, + { + id: 'cherryin-chat', + name: 'CherryIN Chat', + creator: (options: CherryInProviderSettings) => { + const provider = createCherryIn(options) + return customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.chat(modelId) + } + }) + }, + supportsImageGeneration: true + }, { id: 'huggingface', name: 'HuggingFace', diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index 4cdbfb6d4..569b5628c 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -84,6 +84,8 @@ export async function createAiSdkProvider(config) { config.providerId = `${config.providerId}-chat` } else if (config.providerId === 'azure' && config.options?.mode === 'responses') { config.providerId = `${config.providerId}-responses` + } else if (config.providerId === 'cherryin' && config.options?.mode === 'chat') { + config.providerId = 'cherryin-chat' } localProvider = await createProviderCore(config.providerId, config.options) diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 7f279a389..4eb1ffeed 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -171,7 +171,7 @@ export function providerToAiSdkConfig( extraOptions.endpoint = endpoint if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { extraOptions.mode = 'responses' - } else if (aiSdkProviderId === 'openai') { + } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) { extraOptions.mode = 'chat' } diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 60d9b1e09..9e296597c 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -113,6 +113,9 @@ export function buildProviderOptions( } break } + case 'cherryin': + providerSpecificOptions = buildCherryInProviderOptions(assistant, model, capabilities, actualProvider) + break default: throw new Error(`Unsupported base provider ${baseProviderId}`) } @@ -270,6 +273,34 @@ function buildXAIProviderOptions( return providerOptions } +function buildCherryInProviderOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + }, + actualProvider: Provider +): Record { + const serviceTierSetting = getServiceTier(model, actualProvider) + + switch (actualProvider.type) { + case 'openai': + return { + ...buildOpenAIProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } + + case 'anthropic': + return buildAnthropicProviderOptions(assistant, model, capabilities) + + case 'gemini': + return buildGeminiProviderOptions(assistant, model, capabilities) + } + return {} +} + /** * Build Bedrock providerOptions */ diff --git a/src/renderer/src/aiCore/utils/websearch.ts b/src/renderer/src/aiCore/utils/websearch.ts index fde4ff534..02619b54c 100644 --- a/src/renderer/src/aiCore/utils/websearch.ts +++ b/src/renderer/src/aiCore/utils/websearch.ts @@ -107,6 +107,11 @@ export function buildProviderBuiltinWebSearchConfig( } } } + case 'cherryin': { + const _providerId = + { 'openai-response': 'openai', openai: 'openai-chat' }[model?.endpoint_type ?? ''] ?? model?.endpoint_type + return buildProviderBuiltinWebSearchConfig(_providerId, webSearchConfig, model) + } default: { return {} } diff --git a/tsconfig.web.json b/tsconfig.web.json index 120419225..2d91fe026 100644 --- a/tsconfig.web.json +++ b/tsconfig.web.json @@ -9,7 +9,8 @@ "packages/mcp-trace/**/*", "packages/aiCore/src/**/*", "src/main/integration/cherryai/index.js", - "packages/extension-table-plus/**/*" + "packages/extension-table-plus/**/*", + "packages/ai-sdk-provider/**/*" ], "compilerOptions": { "composite": true, @@ -27,7 +28,8 @@ "@cherrystudio/ai-core/built-in/plugins": ["./packages/aiCore/src/core/plugins/built-in/index.ts"], "@cherrystudio/ai-core/*": ["./packages/aiCore/src/*"], "@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"], - "@cherrystudio/extension-table-plus": ["./packages/extension-table-plus/src/index.ts"] + "@cherrystudio/extension-table-plus": ["./packages/extension-table-plus/src/index.ts"], + "@cherrystudio/ai-sdk-provider": ["./packages/ai-sdk-provider/src/index.ts"] }, "experimentalDecorators": true, "emitDecoratorMetadata": true, diff --git a/yarn.lock b/yarn.lock index 4dfbee786..ac9f3c830 100644 --- a/yarn.lock +++ b/yarn.lock @@ -278,7 +278,7 @@ __metadata: languageName: node linkType: hard -"@ai-sdk/provider-utils@npm:3.0.10, @ai-sdk/provider-utils@npm:^3.0.10": +"@ai-sdk/provider-utils@npm:3.0.10": version: 3.0.10 resolution: "@ai-sdk/provider-utils@npm:3.0.10" dependencies: @@ -304,6 +304,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider-utils@npm:^3.0.10, @ai-sdk/provider-utils@npm:^3.0.12": + version: 3.0.17 + resolution: "@ai-sdk/provider-utils@npm:3.0.17" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@standard-schema/spec": "npm:^1.0.0" + eventsource-parser: "npm:^3.0.6" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/1bae6dc4cacd0305b6aa152f9589bbd61c29f150155482c285a77f83d7ed416d52bc2aa7fdaba2e5764530392d9e8f799baea34a63dce6c72ecd3de364dc62d1 + languageName: node + linkType: hard + "@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0": version: 2.0.0 resolution: "@ai-sdk/provider@npm:2.0.0" @@ -1815,6 +1828,7 @@ __metadata: "@ai-sdk/provider": "npm:^2.0.0" "@ai-sdk/provider-utils": "npm:^3.0.16" "@ai-sdk/xai": "npm:^2.0.31" + "@cherrystudio/ai-sdk-provider": "workspace:*" tsdown: "npm:^0.12.9" typescript: "npm:^5.0.0" vitest: "npm:^3.2.4" @@ -1824,6 +1838,23 @@ __metadata: languageName: unknown linkType: soft +"@cherrystudio/ai-sdk-provider@workspace:*, @cherrystudio/ai-sdk-provider@workspace:packages/ai-sdk-provider": + version: 0.0.0-use.local + resolution: "@cherrystudio/ai-sdk-provider@workspace:packages/ai-sdk-provider" + dependencies: + "@ai-sdk/provider": "npm:^2.0.0" + "@ai-sdk/provider-utils": "npm:^3.0.12" + tsdown: "npm:^0.13.3" + typescript: "npm:^5.8.2" + vitest: "npm:^3.2.4" + peerDependencies: + "@ai-sdk/anthropic": ^2.0.29 + "@ai-sdk/google": ^2.0.23 + "@ai-sdk/openai": ^2.0.64 + ai: ^5.0.26 + languageName: unknown + linkType: soft + "@cherrystudio/embedjs-interfaces@npm:0.1.30": version: 0.1.30 resolution: "@cherrystudio/embedjs-interfaces@npm:0.1.30" @@ -24851,6 +24882,16 @@ __metadata: languageName: node linkType: hard +"typescript@npm:^5.8.2": + version: 5.9.3 + resolution: "typescript@npm:5.9.3" + bin: + tsc: bin/tsc + tsserver: bin/tsserver + checksum: 10c0/6bd7552ce39f97e711db5aa048f6f9995b53f1c52f7d8667c1abdc1700c68a76a308f579cd309ce6b53646deb4e9a1be7c813a93baaf0a28ccd536a30270e1c5 + languageName: node + linkType: hard + "typescript@patch:typescript@npm%3A^5.0.0#optional!builtin": version: 5.9.2 resolution: "typescript@patch:typescript@npm%3A5.9.2#optional!builtin::version=5.9.2&hash=5786d5" @@ -24871,6 +24912,16 @@ __metadata: languageName: node linkType: hard +"typescript@patch:typescript@npm%3A^5.8.2#optional!builtin": + version: 5.9.3 + resolution: "typescript@patch:typescript@npm%3A5.9.3#optional!builtin::version=5.9.3&hash=5786d5" + bin: + tsc: bin/tsc + tsserver: bin/tsserver + checksum: 10c0/ad09fdf7a756814dce65bc60c1657b40d44451346858eea230e10f2e95a289d9183b6e32e5c11e95acc0ccc214b4f36289dcad4bf1886b0adb84d711d336a430 + languageName: node + linkType: hard + "ua-parser-js@npm:^1.0.35": version: 1.0.40 resolution: "ua-parser-js@npm:1.0.40"