diff --git a/src/renderer/src/aiCore/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/clients/ApiClientFactory.ts index b0fbe3e479..14e342da83 100644 --- a/src/renderer/src/aiCore/clients/ApiClientFactory.ts +++ b/src/renderer/src/aiCore/clients/ApiClientFactory.ts @@ -5,6 +5,7 @@ import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' import { BaseApiClient } from './BaseApiClient' import { GeminiAPIClient } from './gemini/GeminiAPIClient' import { VertexAPIClient } from './gemini/VertexAPIClient' +import { NewAPIClient } from './NewAPIClient' import { OpenAIAPIClient } from './openai/OpenAIApiClient' import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' import { PPIOAPIClient } from './ppio/PPIOAPIClient' @@ -32,6 +33,11 @@ export class ApiClientFactory { instance = new AihubmixAPIClient(provider) as BaseApiClient return instance } + if (provider.id === 'new-api') { + console.log(`[ApiClientFactory] Creating NewAPIClient for provider: ${provider.id}`) + instance = new NewAPIClient(provider) as BaseApiClient + return instance + } if (provider.id === 'ppio') { console.log(`[ApiClientFactory] Creating PPIOAPIClient for provider: ${provider.id}`) instance = new PPIOAPIClient(provider) as BaseApiClient diff --git a/src/renderer/src/aiCore/clients/NewAPIClient.ts b/src/renderer/src/aiCore/clients/NewAPIClient.ts new file mode 100644 index 0000000000..3162cad0fe --- /dev/null +++ b/src/renderer/src/aiCore/clients/NewAPIClient.ts @@ -0,0 +1,233 @@ +import { isSupportedModel } from '@renderer/config/models' +import { + GenerateImageParams, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, + Model, + Provider, + ToolCallResponse +} from '@renderer/types' +import { + NewApiModel, + RequestOptions, + SdkInstance, + SdkMessageParam, + SdkParams, + SdkRawChunk, + SdkRawOutput, + SdkTool, + SdkToolCall +} from '@renderer/types/sdk' + +import { CompletionsContext } from '../middleware/types' +import { AnthropicAPIClient } from 
'./anthropic/AnthropicAPIClient' +import { BaseApiClient } from './BaseApiClient' +import { GeminiAPIClient } from './gemini/GeminiAPIClient' +import { OpenAIAPIClient } from './openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient' +import { RequestTransformer, ResponseChunkTransformer } from './types' + +export class NewAPIClient extends BaseApiClient { + // 使用联合类型而不是any,保持类型安全 + private clients: Map = + new Map() + private defaultClient: OpenAIAPIClient + private currentClient: BaseApiClient + + constructor(provider: Provider) { + super(provider) + + const claudeClient = new AnthropicAPIClient(provider) + const geminiClient = new GeminiAPIClient(provider) + const openaiClient = new OpenAIAPIClient(provider) + const openaiResponseClient = new OpenAIResponseAPIClient(provider) + + this.clients.set('claude', claudeClient) + this.clients.set('gemini', geminiClient) + this.clients.set('openai', openaiClient) + this.clients.set('openai-response', openaiResponseClient) + + // 设置默认client + this.defaultClient = openaiClient + this.currentClient = this.defaultClient as BaseApiClient + } + + override getBaseURL(): string { + if (!this.currentClient) { + return this.provider.apiHost + } + return this.currentClient.getBaseURL() + } + + /** + * 类型守卫:确保client是BaseApiClient的实例 + */ + private isValidClient(client: unknown): client is BaseApiClient { + return ( + client !== null && + client !== undefined && + typeof client === 'object' && + 'createCompletions' in client && + 'getRequestTransformer' in client && + 'getResponseChunkTransformer' in client + ) + } + + /** + * 根据模型获取合适的client + */ + private getClient(model: Model): BaseApiClient { + if (!model.endpoint_type) { + throw new Error('Model endpoint type is not defined') + } + + if (model.endpoint_type === 'anthropic') { + const client = this.clients.get('claude') + if (!client || !this.isValidClient(client)) { + throw new Error('Failed to get claude client') + } + return client 
+ } + + if (model.endpoint_type === 'openai-response') { + const client = this.clients.get('openai-response') + if (!client || !this.isValidClient(client)) { + throw new Error('Failed to get openai-response client') + } + return client + } + + if (model.endpoint_type === 'gemini') { + const client = this.clients.get('gemini') + if (!client || !this.isValidClient(client)) { + throw new Error('Failed to get gemini client') + } + return client + } + + if (model.endpoint_type === 'openai') { + const client = this.clients.get('openai') + if (!client || !this.isValidClient(client)) { + throw new Error('Failed to get openai client') + } + return client + } + + throw new Error('Invalid model endpoint type: ' + model.endpoint_type) + } + + /** + * 根据模型选择合适的client并委托调用 + */ + public getClientForModel(model: Model): BaseApiClient { + this.currentClient = this.getClient(model) + return this.currentClient + } + + // ============ BaseApiClient 抽象方法实现 ============ + + async createCompletions(payload: SdkParams, options?: RequestOptions): Promise { + // 尝试从payload中提取模型信息来选择client + const modelId = this.extractModelFromPayload(payload) + if (modelId) { + const modelObj = this.provider.models.find((m) => m.id === modelId) + const targetClient = modelObj?.endpoint_type ? this.getClient(modelObj) : this.currentClient + return targetClient.createCompletions(payload, options) + } + + // 如果无法从payload中提取模型,使用当前设置的client + return this.currentClient.createCompletions(payload, options) + } + + /** + * 从SDK payload中提取模型ID + */ + private extractModelFromPayload(payload: SdkParams): string | null { + // 不同的SDK可能有不同的字段名 + if ('model' in payload && typeof payload.model === 'string') { + return payload.model + } + return null + } + + async generateImage(params: GenerateImageParams): Promise { + return this.currentClient.generateImage(params) + } + + async getEmbeddingDimensions(model?: Model): Promise { + const client = model ? 
this.getClient(model) : this.currentClient + return client.getEmbeddingDimensions(model) + } + + override async listModels(): Promise { + try { + const sdk = await this.defaultClient.getSdkInstance() + // Explicitly type the expected response shape so that `data` is recognised. + const response = await sdk.request<{ data: NewApiModel[] }>({ + method: 'get', + path: '/models' + }) + const models: NewApiModel[] = response.data ?? [] + + models.forEach((model) => { + model.id = model.id.trim() + }) + + return models.filter(isSupportedModel) + } catch (error) { + console.error('Error listing models:', error) + return [] + } + } + + async getSdkInstance(): Promise { + return this.currentClient.getSdkInstance() + } + + getRequestTransformer(): RequestTransformer { + return this.currentClient.getRequestTransformer() + } + + getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer { + return this.currentClient.getResponseChunkTransformer(ctx) + } + + convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] { + return this.currentClient.convertMcpToolsToSdkTools(mcpTools) + } + + convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined { + return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools) + } + + convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse { + return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool) + } + + buildSdkMessages( + currentReqMessages: SdkMessageParam[], + output: SdkRawOutput | string, + toolResults: SdkMessageParam[], + toolCalls?: SdkToolCall[] + ): SdkMessageParam[] { + return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls) + } + + convertMcpToolResponseToSdkMessageParam( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): SdkMessageParam | undefined { + const client = this.getClient(model) + return 
client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model) + } + + extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] { + return this.currentClient.extractMessagesFromSdkPayload(sdkPayload) + } + + estimateMessageTokens(message: SdkMessageParam): number { + return this.currentClient.estimateMessageTokens(message) + } +} diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts index 5b1bb5e181..18bf2e8524 100644 --- a/src/renderer/src/aiCore/index.ts +++ b/src/renderer/src/aiCore/index.ts @@ -8,6 +8,7 @@ import { isEnabledToolUse } from '@renderer/utils/mcp-tools' import { OpenAIAPIClient } from './clients' import { AihubmixAPIClient } from './clients/AihubmixAPIClient' import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient' +import { NewAPIClient } from './clients/NewAPIClient' import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' import { CompletionsMiddlewareBuilder } from './middleware/builder' import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' @@ -48,6 +49,11 @@ export default class AiProvider { if (client instanceof OpenAIResponseAPIClient) { client = client.getClient(model) as BaseApiClient } + } else if (this.apiClient instanceof NewAPIClient) { + client = this.apiClient.getClientForModel(model) + if (client instanceof OpenAIResponseAPIClient) { + client = client.getClient(model) as BaseApiClient + } } else if (this.apiClient instanceof OpenAIResponseAPIClient) { // OpenAIResponseAPIClient: 根据模型特征选择API类型 client = this.apiClient.getClient(model) as BaseApiClient diff --git a/src/renderer/src/assets/images/providers/newapi.png b/src/renderer/src/assets/images/providers/newapi.png new file mode 100644 index 0000000000..f62bfd57f1 Binary files /dev/null and b/src/renderer/src/assets/images/providers/newapi.png differ diff --git a/src/renderer/src/config/models.ts 
b/src/renderer/src/config/models.ts index 94a92c53ca..e185072e6b 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -2235,7 +2235,8 @@ export const SYSTEM_MODELS: Record = { group: 'DeepSeek' } ], - lanyun: [] + lanyun: [], + 'new-api': [] } export const TEXT_TO_IMAGES_MODELS = [ diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 209eca2e84..3775c115a2 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -28,6 +28,7 @@ import MinimaxProviderLogo from '@renderer/assets/images/providers/minimax.png' import MistralProviderLogo from '@renderer/assets/images/providers/mistral.png' import ModelScopeProviderLogo from '@renderer/assets/images/providers/modelscope.png' import MoonshotProviderLogo from '@renderer/assets/images/providers/moonshot.png' +import NewAPIProviderLogo from '@renderer/assets/images/providers/newapi.png' import NvidiaProviderLogo from '@renderer/assets/images/providers/nvidia.png' import O3ProviderLogo from '@renderer/assets/images/providers/o3.png' import OcoolAiProviderLogo from '@renderer/assets/images/providers/ocoolai.png' @@ -104,7 +105,8 @@ const PROVIDER_LOGO_MAP = { tokenflux: TokenFluxProviderLogo, cephalon: CephalonProviderLogo, lanyun: LanyunProviderLogo, - vertexai: VertexAIProviderLogo + vertexai: VertexAIProviderLogo, + 'new-api': NewAPIProviderLogo } as const export function getProviderLogo(providerId: string) { @@ -678,5 +680,14 @@ export const PROVIDER_CONFIG = { docs: 'https://cloud.google.com/vertex-ai/generative-ai/docs', models: 'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models' } + }, + 'new-api': { + api: { + url: 'http://localhost:3000' + }, + websites: { + official: 'https://docs.newapi.pro/', + docs: 'https://docs.newapi.pro' + } } } diff --git a/src/renderer/src/hooks/useDynamicLabelWidth.ts b/src/renderer/src/hooks/useDynamicLabelWidth.ts new file mode 100644 index 
0000000000..70bf2423a4 --- /dev/null +++ b/src/renderer/src/hooks/useDynamicLabelWidth.ts @@ -0,0 +1,35 @@ +import { useMemo } from 'react' + +/** + * Compute a width string that fits the longest label text within a form. + * This is useful when using Ant Design `Form` with `labelCol` so that the layout + * adapts across different languages where label lengths vary. + * + * @param labels Array of label strings to measure. These should already be translated. + * @param extraPadding Extra pixels added to the measured width to provide spacing. + * Defaults to 40px of spacing after the measured text. + * @returns A width string that can be used in CSS, e.g. "140px". + */ +export const useDynamicLabelWidth = (labels: string[], extraPadding = 40): string => { + return useMemo(() => { + if (typeof window === 'undefined' || !labels || labels.length === 0) return '170px' + + // Create a hidden span for text measurement + const span = document.createElement('span') + span.style.visibility = 'hidden' + span.style.position = 'absolute' + span.style.whiteSpace = 'nowrap' + span.style.fontSize = getComputedStyle(document.body).fontSize || 
'14px' + document.body.appendChild(span) + + let maxWidth = 0 + labels.forEach((text) => { + span.textContent = text + maxWidth = Math.max(maxWidth, span.offsetWidth) + }) + + document.body.removeChild(span) + + return `${maxWidth + extraPadding}px` + }, [extraPadding, labels]) +} diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 5a7a930a79..e76ee35701 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -1039,6 +1039,7 @@ "modelscope": "ModelScope", "moonshot": "Moonshot", "nvidia": "Nvidia", + "new-api": "New API", "o3": "O3", "ocoolai": "ocoolAI", "ollama": "Ollama", @@ -1616,6 +1617,7 @@ "messages.use_serif_font": "Use serif font", "model": "Default Model", "models.add.add_model": "Add Model", + "models.add.batch_add_models": "Batch Add Models", "models.add.group_name": "Group Name", "models.add.group_name.placeholder": "Optional e.g. ChatGPT", "models.add.group_name.tooltip": "Optional e.g. ChatGPT", @@ -1625,6 +1627,10 @@ "models.add.model_id.tooltip": "Example: gpt-3.5-turbo", "models.add.model_name": "Model Name", "models.add.model_name.tooltip": "Optional e.g. GPT-4", + "models.add.endpoint_type": "Endpoint Type", + "models.add.endpoint_type.placeholder": "Select endpoint type", + "models.add.endpoint_type.tooltip": "Select the API endpoint type format", + "models.add.endpoint_type.required": "Please select an endpoint type", "models.add.model_name.placeholder": "Optional e.g. 
GPT-4", "models.check.all": "All", "models.check.all_models_passed": "All models check passed", diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index 6c10d00f0f..d735329335 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -1038,6 +1038,7 @@ "modelscope": "ModelScope", "moonshot": "月の暗面", "nvidia": "NVIDIA", + "new-api": "New API", "o3": "O3", "ocoolai": "ocoolAI", "ollama": "Ollama", @@ -1604,6 +1605,7 @@ "messages.use_serif_font": "セリフフォントを使用", "model": "デフォルトモデル", "models.add.add_model": "モデルを追加", + "models.add.batch_add_models": "モデルを一括追加", "models.add.group_name": "グループ名", "models.add.group_name.placeholder": "例:ChatGPT", "models.add.group_name.tooltip": "例:ChatGPT", @@ -1613,6 +1615,10 @@ "models.add.model_id.tooltip": "例:gpt-3.5-turbo", "models.add.model_name": "モデル名", "models.add.model_name.tooltip": "例:GPT-4", + "models.add.endpoint_type": "エンドポイントタイプ", + "models.add.endpoint_type.placeholder": "エンドポイントタイプを選択", + "models.add.endpoint_type.tooltip": "APIエンドポイントタイプフォーマットを選択", + "models.add.endpoint_type.required": "エンドポイントタイプを選択してください", "models.add.model_name.placeholder": "例:GPT-4", "models.check.all": "すべて", "models.check.all_models_passed": "すべてのモデルチェックが成功しました", diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 72859a3d52..fd664fa871 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -1039,6 +1039,7 @@ "modelscope": "ModelScope", "moonshot": "Moonshot", "nvidia": "Nvidia", + "new-api": "New API", "o3": "O3", "ocoolai": "ocoolAI", "ollama": "Ollama", @@ -1604,6 +1605,7 @@ "messages.use_serif_font": "Использовать serif шрифт", "model": "Модель по умолчанию", "models.add.add_model": "Добавить модель", + "models.add.batch_add_models": "Пакетное добавление моделей", "models.add.group_name": "Имя группы", "models.add.group_name.placeholder": 
"Необязательно, например, ChatGPT", "models.add.group_name.tooltip": "Необязательно, например, ChatGPT", @@ -1613,6 +1615,10 @@ "models.add.model_id.tooltip": "Пример: gpt-3.5-turbo", "models.add.model_name": "Имя модели", "models.add.model_name.tooltip": "Необязательно, например, GPT-4", + "models.add.endpoint_type": "Тип конечной точки", + "models.add.endpoint_type.placeholder": "Выберите тип конечной точки", + "models.add.endpoint_type.tooltip": "Выберите формат типа конечной точки API", + "models.add.endpoint_type.required": "Пожалуйста, выберите тип конечной точки", "models.add.model_name.placeholder": "Необязательно, например, GPT-4", "models.check.all": "Все", "models.check.all_models_passed": "Все модели прошли проверку", diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index 953ea0e796..fee1a37288 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -1039,6 +1039,7 @@ "modelscope": "ModelScope 魔搭", "moonshot": "月之暗面", "nvidia": "英伟达", + "new-api": "New API", "o3": "O3", "ocoolai": "ocoolAI", "ollama": "Ollama", @@ -1616,6 +1617,7 @@ "messages.use_serif_font": "使用衬线字体", "model": "默认模型", "models.add.add_model": "添加模型", + "models.add.batch_add_models": "批量添加模型", "models.add.group_name": "分组名称", "models.add.group_name.placeholder": "例如 ChatGPT", "models.add.group_name.tooltip": "例如 ChatGPT", @@ -1626,6 +1628,10 @@ "models.add.model_name": "模型名称", "models.add.model_name.placeholder": "例如 GPT-4", "models.add.model_name.tooltip": "例如 GPT-4", + "models.add.endpoint_type": "端点类型", + "models.add.endpoint_type.placeholder": "选择端点类型", + "models.add.endpoint_type.tooltip": "选择 API 的端点类型格式", + "models.add.endpoint_type.required": "请选择端点类型", "models.check.all": "所有", "models.check.all_models_passed": "所有模型检测通过", "models.check.button_caption": "健康检测", diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 0098e45368..d16130ef52 
100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -1039,6 +1039,7 @@ "modelscope": "ModelScope 魔搭", "moonshot": "月之暗面", "nvidia": "輝達", + "new-api": "New API", "o3": "O3", "ocoolai": "ocoolAI", "ollama": "Ollama", @@ -1607,6 +1608,7 @@ "messages.use_serif_font": "使用襯線字型", "model": "預設模型", "models.add.add_model": "新增模型", + "models.add.batch_add_models": "批量新增模型", "models.add.group_name": "群組名稱", "models.add.group_name.placeholder": "選填,例如 ChatGPT", "models.add.group_name.tooltip": "選填,例如 ChatGPT", @@ -1617,6 +1619,10 @@ "models.add.model_name": "模型名稱", "models.add.model_name.placeholder": "選填,例如 GPT-4", "models.add.model_name.tooltip": "例如 GPT-4", + "models.add.endpoint_type": "端點類型", + "models.add.endpoint_type.placeholder": "選擇端點類型", + "models.add.endpoint_type.tooltip": "選擇 API 的端點類型格式", + "models.add.endpoint_type.required": "請選擇端點類型", "models.check.all": "所有", "models.check.all_models_passed": "所有模型檢查通過", "models.check.button_caption": "健康檢查", diff --git a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx index eb3639931a..f19f23b4b6 100644 --- a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx +++ b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx @@ -113,7 +113,7 @@ const PopupContainer: React.FC = ({ title, resolve }) => { } return isMac ? 
[preprocessOptions, ocrOptions] : [preprocessOptions] - }, [ocrProviders, preprocessProviders]) + }, [ocrProviders, preprocessProviders, t]) const onOk = async () => { try { diff --git a/src/renderer/src/pages/knowledge/components/StatusIcon.tsx b/src/renderer/src/pages/knowledge/components/StatusIcon.tsx index 69435d4e14..465a7cef0f 100644 --- a/src/renderer/src/pages/knowledge/components/StatusIcon.tsx +++ b/src/renderer/src/pages/knowledge/components/StatusIcon.tsx @@ -28,7 +28,7 @@ const StatusIcon: FC = ({ const errorText = item?.processingError console.log('[StatusIcon] Rendering for item:', item?.id, 'Status:', status, 'Progress:', progress) - const statusDisplay = useMemo(() => { + return useMemo(() => { if (!status) { if (item?.uniqueId) { if (isPreprocessed && item.type === 'file') { @@ -83,9 +83,7 @@ const StatusIcon: FC = ({ default: return null } - }, [status, item?.uniqueId, type, progress, errorText, t]) - - return statusDisplay + }, [status, item?.uniqueId, item?.type, t, isPreprocessed, errorText, type, progress]) } const StatusDot = styled.div<{ $status: 'pending' | 'processing' | 'new' }>` diff --git a/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx b/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx index c12844df32..4dcb67bb70 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/EditModelsPopup.tsx @@ -16,6 +16,8 @@ import { } from '@renderer/config/models' import { useProvider } from '@renderer/hooks/useProvider' import FileItem from '@renderer/pages/files/FileItem' +import NewApiAddModelPopup from '@renderer/pages/settings/ProviderSettings/NewApiAddModelPopup' +import NewApiBatchAddModelPopup from '@renderer/pages/settings/ProviderSettings/NewApiBatchAddModelPopup' import { fetchModels } from '@renderer/services/ApiService' import { Model, Provider } from '@renderer/types' import { getDefaultGroupName, isFreeModel, 
runAsyncFunction } from '@renderer/utils' @@ -43,6 +45,10 @@ const isModelInProvider = (provider: Provider, modelId: string): boolean => { return provider.models.some((m) => m.id === modelId) } +const isValidNewApiModel = (model: Model): boolean => { + return !!(model.supported_endpoint_types && model.supported_endpoint_types.length > 0) +} + const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { const [open, setOpen] = useState(true) const { provider, models, addModel, removeModel } = useProvider(_provider.id) @@ -129,10 +135,21 @@ const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { const onAddModel = useCallback( (model: Model) => { if (!isEmpty(model.name)) { - addModel(model) + if (provider.id === 'new-api') { + if (model.supported_endpoint_types && model.supported_endpoint_types.length > 0) { + addModel({ + ...model, + endpoint_type: model.supported_endpoint_types[0] + }) + } else { + NewApiAddModelPopup.show({ title: t('settings.models.add.add_model'), provider, model }) + } + } else { + addModel(model) + } } }, - [addModel] + [addModel, provider, t] ) const onRemoveModel = useCallback((model: Model) => removeModel(model), [removeModel]) @@ -155,7 +172,9 @@ const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { // @ts-ignore description description: model?.description || '', // @ts-ignore owned_by - owned_by: model?.owned_by || '' + owned_by: model?.owned_by || '', + // @ts-ignore supported_endpoint_types + supported_endpoint_types: model?.supported_endpoint_types })) .filter((model) => !isEmpty(model.name)) ) @@ -207,14 +226,27 @@ const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { if (isAllFilteredInProvider) { list.filter((model) => isModelInProvider(provider, model.id)).forEach(onRemoveModel) } else { - list.filter((model) => !isModelInProvider(provider, model.id)).forEach(onAddModel) + const wouldAddModel = list.filter((model) => !isModelInProvider(provider, model.id)) + if 
(provider.id === 'new-api') { + if (wouldAddModel.every(isValidNewApiModel)) { + wouldAddModel.forEach(onAddModel) + } else { + NewApiBatchAddModelPopup.show({ + title: t('settings.models.add.batch_add_models'), + batchModels: wouldAddModel, + provider + }) + } + } else { + wouldAddModel.forEach(onAddModel) + } } }} disabled={list.length === 0} /> ) - }, [list, provider, onAddModel, onRemoveModel, t]) + }, [list, t, provider, onRemoveModel, onAddModel]) const renderGroupTools = useCallback( (group: string) => { @@ -237,7 +269,20 @@ const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { if (isAllInProvider) { modelGroups[group].filter((model) => isModelInProvider(provider, model.id)).forEach(onRemoveModel) } else { - modelGroups[group].filter((model) => !isModelInProvider(provider, model.id)).forEach(onAddModel) + const wouldAddModel = modelGroups[group].filter((model) => !isModelInProvider(provider, model.id)) + if (provider.id === 'new-api') { + if (wouldAddModel.every(isValidNewApiModel)) { + wouldAddModel.forEach(onAddModel) + } else { + NewApiBatchAddModelPopup.show({ + title: t('settings.models.add.batch_add_models'), + batchModels: wouldAddModel, + provider + }) + } + } else { + wouldAddModel.forEach(onAddModel) + } } }} /> diff --git a/src/renderer/src/pages/settings/ProviderSettings/ModelEditContent.tsx b/src/renderer/src/pages/settings/ProviderSettings/ModelEditContent.tsx index 5236fd1cc5..b66c554858 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ModelEditContent.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ModelEditContent.tsx @@ -6,14 +6,17 @@ import { isVisionModel, isWebSearchModel } from '@renderer/config/models' -import { Model, ModelType } from '@renderer/types' +import { useDynamicLabelWidth } from '@renderer/hooks/useDynamicLabelWidth' +import { Model, ModelType, Provider } from '@renderer/types' import { getDefaultGroupName } from '@renderer/utils' import { Button, Checkbox, Divider, Flex, Form, 
Input, InputNumber, message, Modal, Select } from 'antd' import { ChevronDown, ChevronUp } from 'lucide-react' import { FC, useState } from 'react' import { useTranslation } from 'react-i18next' import styled from 'styled-components' + interface ModelEditContentProps { + provider: Provider model: Model onUpdateModel: (model: Model) => void open: boolean @@ -21,13 +24,15 @@ interface ModelEditContentProps { } const symbols = ['$', '¥', '€', '£'] -const ModelEditContent: FC = ({ model, onUpdateModel, open, onClose }) => { +const ModelEditContent: FC = ({ provider, model, onUpdateModel, open, onClose }) => { const [form] = Form.useForm() const { t } = useTranslation() const [showMoreSettings, setShowMoreSettings] = useState(false) const [currencySymbol, setCurrencySymbol] = useState(model.pricing?.currencySymbol || '$') const [isCustomCurrency, setIsCustomCurrency] = useState(!symbols.includes(model.pricing?.currencySymbol || '$')) + const labelWidth = useDynamicLabelWidth([t('settings.models.add.endpoint_type')]) + const onFinish = (values: any) => { const finalCurrencySymbol = isCustomCurrency ? values.customCurrencySymbol : values.currencySymbol const updatedModel = { @@ -35,6 +40,7 @@ const ModelEditContent: FC = ({ model, onUpdateModel, ope id: values.id || model.id, name: values.name || model.name, group: values.group || model.group, + endpoint_type: provider.id === 'new-api' ? values.endpointType : model.endpoint_type, pricing: { input_per_million_tokens: Number(values.input_per_million_tokens) || 0, output_per_million_tokens: Number(values.output_per_million_tokens) || 0, @@ -74,7 +80,7 @@ const ModelEditContent: FC = ({ model, onUpdateModel, ope }}>
= ({ model, onUpdateModel, ope id: model.id, name: model.name, group: model.group, + endpointType: model.endpoint_type, input_per_million_tokens: model.pricing?.input_per_million_tokens ?? 0, output_per_million_tokens: model.pricing?.output_per_million_tokens ?? 0, currencySymbol: symbols.includes(model.pricing?.currencySymbol || '$') @@ -133,6 +140,21 @@ const ModelEditContent: FC = ({ model, onUpdateModel, ope tooltip={t('settings.models.add.group_name.tooltip')}> + {provider.id === 'new-api' && ( + + + + )} + + + + + ) +} + +export default class NewApiAddModelPopup { + static topviewId = 0 + static hide() { + TopView.hide('NewApiAddModelPopup') + } + static show(props: ShowParams) { + return new Promise((resolve) => { + TopView.show( + { + resolve(v) + this.hide() + }} + />, + 'NewApiAddModelPopup' + ) + }) + } +} diff --git a/src/renderer/src/pages/settings/ProviderSettings/NewApiBatchAddModelPopup.tsx b/src/renderer/src/pages/settings/ProviderSettings/NewApiBatchAddModelPopup.tsx new file mode 100644 index 0000000000..0b94c0cf0a --- /dev/null +++ b/src/renderer/src/pages/settings/ProviderSettings/NewApiBatchAddModelPopup.tsx @@ -0,0 +1,124 @@ +import { TopView } from '@renderer/components/TopView' +import { useDynamicLabelWidth } from '@renderer/hooks/useDynamicLabelWidth' +import { useProvider } from '@renderer/hooks/useProvider' +import { EndpointType, Model, Provider } from '@renderer/types' +import { Button, Flex, Form, FormProps, Modal, Select } from 'antd' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' + +interface ShowParams { + title: string + provider: Provider + batchModels: Model[] +} + +interface Props extends ShowParams { + resolve: (data: any) => void +} + +type FieldType = { + provider: string + group?: string + endpointType?: EndpointType +} + +const PopupContainer: React.FC = ({ title, provider, resolve, batchModels }) => { + const [open, setOpen] = useState(true) + const [form] = Form.useForm() + const { 
addModel } = useProvider(provider.id) + const { t } = useTranslation() + + const onOk = () => { + setOpen(false) + } + + const onCancel = () => { + setOpen(false) + } + + const onClose = () => { + resolve({}) + } + + const onAddModel = (values: FieldType) => { + batchModels.forEach((model) => { + addModel({ + ...model, + endpoint_type: values.endpointType + }) + }) + return true + } + + const onFinish: FormProps['onFinish'] = (values) => { + if (onAddModel(values)) { + resolve({}) + } + } + + return ( + +
+ + + + + + + + +
+
+ ) +} + +export default class NewApiBatchAddModelPopup { + static topviewId = 0 + static hide() { + TopView.hide('NewApiBatchAddModelPopup') + } + static show(props: ShowParams) { + return new Promise((resolve) => { + TopView.show( + { + resolve(v) + this.hide() + }} + />, + 'NewApiBatchAddModelPopup' + ) + }) + } +} diff --git a/src/renderer/src/pages/settings/ProviderSettings/index.tsx b/src/renderer/src/pages/settings/ProviderSettings/index.tsx index 65f0886f5f..d4aa5447ea 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/index.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/index.tsx @@ -259,7 +259,7 @@ const ProvidersList: FC = () => { window.message.error(t('settings.models.provider_key_add_failed_by_invalid_data')) window.navigate('/settings/provider') } - }, [searchParams]) + }, [addProvider, providers, searchParams, t, updateProvider]) const onDragEnd = (result: DropResult) => { setDragging(false) diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index d40139670e..f0c0cb2680 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -54,7 +54,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 118, + version: 119, blacklist: ['runtime', 'messages', 'messageBlocks'], migrate }, diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index 9f3a5a0e59..c70fd6a0da 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -194,6 +194,16 @@ export const INITIAL_PROVIDERS: Provider[] = [ isSystem: true, enabled: false }, + { + id: 'new-api', + name: 'New API', + type: 'openai', + apiKey: '', + apiHost: 'http://localhost:3000', + models: SYSTEM_MODELS['new-api'], + isSystem: true, + enabled: false + }, { id: 'lmstudio', name: 'LM Studio', diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 1cc830a2bb..935da99054 100644 --- a/src/renderer/src/store/migrate.ts +++ 
b/src/renderer/src/store/migrate.ts @@ -1712,6 +1712,15 @@ const migrateConfig = { } }) + return state + } catch (error) { + return state + } + }, + '119': (state: RootState) => { + try { + addProvider(state, 'new-api') + state.llm.providers = moveProvider(state.llm.providers, 'new-api', 16) return state } catch (error) { return state diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 02502c586f..b5cd425f03 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -178,6 +178,8 @@ export type ProviderType = export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' +export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'jina-rerank' + export type ModelPricing = { input_per_million_tokens: number output_per_million_tokens: number @@ -193,6 +195,8 @@ export type Model = { description?: string type?: ModelType[] pricing?: ModelPricing + endpoint_type?: EndpointType + supported_endpoint_types?: EndpointType[] } export type Suggestion = { diff --git a/src/renderer/src/types/sdk.ts b/src/renderer/src/types/sdk.ts index c7eeb9500c..ae7d823d4f 100644 --- a/src/renderer/src/types/sdk.ts +++ b/src/renderer/src/types/sdk.ts @@ -21,6 +21,8 @@ import { import OpenAI, { AzureOpenAI } from 'openai' import { Stream } from 'openai/streaming' +import { EndpointType } from './index' + export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | GoogleGenAI export type SdkParams = OpenAISdkParams | OpenAIResponseSdkParams | AnthropicSdkParams | GeminiSdkParams export type SdkRawChunk = OpenAISdkRawChunk | OpenAIResponseSdkRawChunk | AnthropicSdkRawChunk | GeminiSdkRawChunk @@ -36,7 +38,7 @@ export type SdkToolCall = | FunctionCall | OpenAIResponseSdkToolCall export type SdkTool = OpenAI.Chat.Completions.ChatCompletionTool | ToolUnion | Tool | OpenAIResponseSdkTool -export type SdkModel = OpenAI.Models.Model | Anthropic.ModelInfo | GeminiModel 
+export type SdkModel = OpenAI.Models.Model | Anthropic.ModelInfo | GeminiModel | NewApiModel export type RequestOptions = Anthropic.RequestOptions | OpenAI.RequestOptions | GeminiOptions @@ -106,3 +108,10 @@ export type GeminiOptions = { signal?: AbortSignal timeout?: number } + +/** + * New API + */ +export interface NewApiModel extends OpenAI.Models.Model { + supported_endpoint_types?: EndpointType[] +}