mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: add openai-compatible provider and enhance provider configuration
- Introduced the @ai-sdk/openai-compatible package to support compatibility with OpenAI. - Added a new ProviderConfigFactory and ProviderConfigBuilder for streamlined provider configuration. - Updated the provider registry to include the new Google Vertex AI import path. - Enhanced the index.ts to export new provider configuration utilities for better type safety and usability. - Refactored ApiService and middleware to integrate the new provider configurations effectively.
This commit is contained in:
parent
c99a2fedb7
commit
8ca6341609
@ -35,6 +35,7 @@
|
||||
"@ai-sdk/groq": "^1.2.9",
|
||||
"@ai-sdk/mistral": "^1.2.8",
|
||||
"@ai-sdk/openai": "^1.3.22",
|
||||
"@ai-sdk/openai-compatible": "^0.2.14",
|
||||
"@ai-sdk/perplexity": "^1.1.9",
|
||||
"@ai-sdk/replicate": "^0.2.8",
|
||||
"@ai-sdk/togetherai": "^0.2.14",
|
||||
|
||||
1022
packages/aiCore/pnpm-lock.yaml
Normal file
1022
packages/aiCore/pnpm-lock.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@ -116,6 +116,15 @@ export type {
|
||||
export { createClient as createApiClient, getClientInfo, getSupportedProviders } from './clients/ApiClientFactory'
|
||||
export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './providers/registry'
|
||||
|
||||
// ==================== Provider 配置工厂 ====================
|
||||
export {
|
||||
BaseProviderConfig,
|
||||
createProviderConfig,
|
||||
ProviderConfigBuilder,
|
||||
providerConfigBuilder,
|
||||
ProviderConfigFactory
|
||||
} from './providers/factory'
|
||||
|
||||
// ==================== 包信息 ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
|
||||
330
packages/aiCore/src/providers/factory.ts
Normal file
330
packages/aiCore/src/providers/factory.ts
Normal file
@ -0,0 +1,330 @@
|
||||
/**
|
||||
* AI Provider 配置工厂
|
||||
* 提供类型安全的 Provider 配置构建器
|
||||
*/
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from './registry'
|
||||
|
||||
/**
|
||||
* 通用配置基础类型,包含所有 Provider 共有的属性
|
||||
*/
|
||||
export interface BaseProviderConfig {
|
||||
apiKey?: string
|
||||
baseURL?: string
|
||||
timeout?: number
|
||||
headers?: Record<string, string>
|
||||
fetch?: typeof globalThis.fetch
|
||||
}
|
||||
|
||||
/**
|
||||
* 完整的配置类型,结合基础配置、AI SDK 配置和特定 Provider 配置
|
||||
*/
|
||||
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
|
||||
|
||||
type ConfigHandler<T extends ProviderId> = (
|
||||
builder: ProviderConfigBuilder<T>,
|
||||
provider: CompleteProviderConfig<T>
|
||||
) => void
|
||||
|
||||
const configHandlers: {
|
||||
[K in ProviderId]?: ConfigHandler<K>
|
||||
} = {
|
||||
azure: (builder, provider) => {
|
||||
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
|
||||
const azureProvider = provider as CompleteProviderConfig<'azure'>
|
||||
azureBuilder.withAzureConfig({
|
||||
apiVersion: azureProvider.apiVersion,
|
||||
resourceName: azureProvider.resourceName
|
||||
})
|
||||
},
|
||||
'google-vertex': (builder, provider) => {
|
||||
const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
|
||||
const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
|
||||
vertexBuilder.withGoogleVertexConfig({
|
||||
project: vertexProvider.project,
|
||||
location: vertexProvider.location
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
|
||||
|
||||
constructor(private providerId: T) {}
|
||||
|
||||
/**
|
||||
* 设置 API Key
|
||||
*/
|
||||
withApiKey(apiKey: string): this
|
||||
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
|
||||
withApiKey(apiKey: string, options?: any): this {
|
||||
this.config.apiKey = apiKey
|
||||
|
||||
// 类型安全的 OpenAI 特定配置
|
||||
if (this.providerId === 'openai' && options) {
|
||||
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
|
||||
if (options.organization) {
|
||||
openaiConfig.organization = options.organization
|
||||
}
|
||||
if (options.project) {
|
||||
openaiConfig.project = options.project
|
||||
}
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置基础 URL
|
||||
*/
|
||||
withBaseURL(baseURL: string) {
|
||||
this.config.baseURL = baseURL
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置请求配置
|
||||
*/
|
||||
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
|
||||
if (options.headers) {
|
||||
this.config.headers = { ...this.config.headers, ...options.headers }
|
||||
}
|
||||
if (options.fetch) {
|
||||
this.config.fetch = options.fetch
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Azure OpenAI 特定配置
|
||||
*/
|
||||
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
|
||||
withAzureConfig(options: any): any {
|
||||
if (this.providerId === 'azure') {
|
||||
const azureConfig = this.config as CompleteProviderConfig<'azure'>
|
||||
if (options.apiVersion) {
|
||||
azureConfig.apiVersion = options.apiVersion
|
||||
}
|
||||
if (options.resourceName) {
|
||||
azureConfig.resourceName = options.resourceName
|
||||
}
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Google 特定配置
|
||||
*/
|
||||
withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
|
||||
withGoogleVertexConfig(options: any): any {
|
||||
if (this.providerId === 'google-vertex') {
|
||||
const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
||||
if (options.project) {
|
||||
googleConfig.project = options.project
|
||||
}
|
||||
if (options.location) {
|
||||
googleConfig.location = options.location
|
||||
}
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
withGoogleCredentials(credentials: {
|
||||
clientEmail: string
|
||||
privateKey: string
|
||||
}): T extends 'google-vertex' ? this : never
|
||||
withGoogleCredentials(credentials: any): any {
|
||||
if (this.providerId === 'google-vertex') {
|
||||
const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
||||
vertexConfig.googleCredentials = credentials
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置自定义参数
|
||||
*/
|
||||
withCustomParams(params: Record<string, any>) {
|
||||
Object.assign(this.config, params)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终配置
|
||||
*/
|
||||
build(): ProviderSettingsMap[T] {
|
||||
return this.config as ProviderSettingsMap[T]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider 配置工厂
|
||||
* 提供便捷的配置创建方法
|
||||
*/
|
||||
export class ProviderConfigFactory {
|
||||
/**
|
||||
* 创建配置构建器
|
||||
*/
|
||||
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
|
||||
return new ProviderConfigBuilder(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 从通用Provider对象创建配置 - 使用更优雅的处理器模式
|
||||
*/
|
||||
static fromProvider<T extends ProviderId>(
|
||||
providerId: T,
|
||||
provider: CompleteProviderConfig<T>,
|
||||
options?: {
|
||||
headers?: Record<string, string>
|
||||
[key: string]: any
|
||||
}
|
||||
): ProviderSettingsMap[T] {
|
||||
const builder = new ProviderConfigBuilder<T>(providerId)
|
||||
|
||||
// 设置基本配置
|
||||
if (provider.apiKey) {
|
||||
builder.withApiKey(provider.apiKey)
|
||||
}
|
||||
|
||||
if (provider.baseURL) {
|
||||
builder.withBaseURL(provider.baseURL)
|
||||
}
|
||||
|
||||
// 设置请求配置
|
||||
if (options?.headers) {
|
||||
builder.withRequestConfig({
|
||||
headers: options.headers
|
||||
})
|
||||
}
|
||||
|
||||
// 使用配置处理器模式 - 更加优雅和可扩展
|
||||
const handler = configHandlers[providerId]
|
||||
if (handler) {
|
||||
handler(builder, provider)
|
||||
}
|
||||
|
||||
// 添加其他自定义参数
|
||||
if (options) {
|
||||
const customOptions = { ...options }
|
||||
delete customOptions.headers // 已经处理过了
|
||||
if (Object.keys(customOptions).length > 0) {
|
||||
builder.withCustomParams(customOptions)
|
||||
}
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 OpenAI 配置
|
||||
*/
|
||||
static createOpenAI(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
organization?: string
|
||||
project?: string
|
||||
}
|
||||
) {
|
||||
const builder = this.builder('openai')
|
||||
|
||||
// 使用类型安全的重载
|
||||
if (options?.organization || options?.project) {
|
||||
builder.withApiKey(apiKey, {
|
||||
organization: options.organization,
|
||||
project: options.project
|
||||
})
|
||||
} else {
|
||||
builder.withApiKey(apiKey)
|
||||
}
|
||||
|
||||
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Anthropic 配置
|
||||
*/
|
||||
static createAnthropic(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('anthropic')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Azure OpenAI 配置
|
||||
*/
|
||||
static createAzureOpenAI(
|
||||
apiKey: string,
|
||||
options: {
|
||||
baseURL: string
|
||||
apiVersion?: string
|
||||
resourceName?: string
|
||||
deploymentName?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('azure')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options.baseURL)
|
||||
.withAzureConfig({
|
||||
apiVersion: options.apiVersion,
|
||||
resourceName: options.resourceName
|
||||
})
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Google 配置
|
||||
*/
|
||||
static createGoogle(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
projectId?: string
|
||||
location?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('google')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Vertex AI 配置
|
||||
*/
|
||||
static createVertexAI(
|
||||
credentials: {
|
||||
clientEmail: string
|
||||
privateKey: string
|
||||
},
|
||||
options?: {
|
||||
project?: string
|
||||
location?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('google-vertex')
|
||||
.withGoogleCredentials(credentials)
|
||||
.withGoogleVertexConfig({
|
||||
project: options?.project,
|
||||
location: options?.location
|
||||
})
|
||||
.build()
|
||||
}
|
||||
|
||||
static createOpenAICompatible(baseURL: string, apiKey: string) {
|
||||
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷的配置创建函数
|
||||
*/
|
||||
export const createProviderConfig = ProviderConfigFactory.fromProvider
|
||||
export const providerConfigBuilder = ProviderConfigFactory.builder
|
||||
@ -14,7 +14,7 @@ import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||
import { type FalProviderSettings } from '@ai-sdk/fal'
|
||||
import { type FireworksProviderSettings } from '@ai-sdk/fireworks'
|
||||
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex'
|
||||
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
|
||||
import { type GroqProviderSettings } from '@ai-sdk/groq'
|
||||
import { type MistralProviderSettings } from '@ai-sdk/mistral'
|
||||
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
@ -123,7 +123,7 @@ export class AiProviderRegistry {
|
||||
{
|
||||
id: 'google-vertex',
|
||||
name: 'Google Vertex AI',
|
||||
import: () => import('@ai-sdk/google-vertex'),
|
||||
import: () => import('@ai-sdk/google-vertex/edge'),
|
||||
creatorFunctionName: 'createVertex',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
@ -270,7 +270,6 @@ export class AiProviderRegistry {
|
||||
}
|
||||
]
|
||||
|
||||
// 注册所有 providers (总计24个)
|
||||
providers.forEach((config) => {
|
||||
this.registry.set(config.id, config)
|
||||
})
|
||||
|
||||
@ -10,76 +10,50 @@
|
||||
|
||||
import {
|
||||
AiClient,
|
||||
AiCore,
|
||||
createClient,
|
||||
type OpenAICompatibleProviderSettings,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap,
|
||||
smoothStream,
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
|
||||
// 引入适配器
|
||||
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
|
||||
// 引入原有的AiProvider作为fallback
|
||||
import LegacyAiProvider from './index'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
// 引入参数转换模块
|
||||
|
||||
/**
|
||||
* 将现有 Provider 类型映射到 AI SDK 的 Provider ID
|
||||
* 根据 registry.ts 中的支持列表进行映射
|
||||
*/
|
||||
function mapProviderTypeToAiSdkId(providerType: string): string {
|
||||
// Cherry Studio Provider Type -> AI SDK Provider ID 映射表
|
||||
const typeMapping: Record<string, string> = {
|
||||
// 需要转换的映射
|
||||
grok: 'xai', // grok -> xai
|
||||
'azure-openai': 'azure', // azure-openai -> azure
|
||||
gemini: 'google', // gemini -> google
|
||||
vertexai: 'google-vertex' // vertexai -> google-vertex
|
||||
}
|
||||
|
||||
return typeMapping[providerType]
|
||||
}
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
import { getTimeout } from './transformParameters'
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
*/
|
||||
function providerToAiSdkConfig(provider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: any
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
console.log('provider', provider)
|
||||
// 1. 先映射 provider 类型到 AI SDK ID
|
||||
const mappedProviderId = mapProviderTypeToAiSdkId(provider.id)
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
// 2. 检查映射后的 provider ID 是否在 AI SDK 注册表中
|
||||
const isSupported = AiCore.isSupported(mappedProviderId)
|
||||
if (aiSdkProviderId !== 'openai-compatible') {
|
||||
const defaultModel = getDefaultModel()
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, provider, {
|
||||
timeout: getTimeout(defaultModel)
|
||||
})
|
||||
|
||||
console.log(`Provider mapping: ${provider.type} -> ${mappedProviderId}, supported: ${isSupported}`)
|
||||
|
||||
// 3. 如果映射的 provider 不支持,则使用 openai-compatible
|
||||
if (isSupported) {
|
||||
return {
|
||||
providerId: mappedProviderId as ProviderId,
|
||||
options: {
|
||||
apiKey: provider.apiKey
|
||||
}
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
options
|
||||
}
|
||||
} else {
|
||||
console.log(`Using openai-compatible fallback for provider: ${provider.type}`)
|
||||
const compatibleConfig: OpenAICompatibleProviderSettings = {
|
||||
name: provider.name || provider.type,
|
||||
apiKey: provider.apiKey,
|
||||
baseURL: provider.apiHost
|
||||
}
|
||||
const options = ProviderConfigFactory.createOpenAICompatible(provider.apiHost, provider.apiKey)
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
options: compatibleConfig
|
||||
options
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { AiPlugin, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
|
||||
import { AiPlugin, extractReasoningMiddleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
|
||||
import { isReasoningModel } from '@renderer/config/models'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
@ -140,7 +140,10 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
|
||||
// Anthropic特定中间件
|
||||
break
|
||||
case 'openai':
|
||||
// OpenAI特定中间件
|
||||
builder.add({
|
||||
name: 'thinking-tag-extraction',
|
||||
aiSdkMiddlewares: [extractReasoningMiddleware({ tagName: 'think', separator: '\n', startWithReasoning: true })]
|
||||
})
|
||||
break
|
||||
case 'gemini':
|
||||
// Gemini特定中间件
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
|
||||
import { ChunkType, ThinkingCompleteChunk } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
* 一个用于统计 LLM "思考时间"(Time to First Token)的 AI SDK 中间件。
|
||||
@ -47,8 +46,8 @@ export default function thinkingTimeMiddleware(): LanguageModelV1Middleware {
|
||||
// 如果流的末尾都是 reasoning,也需要发送 complete 事件
|
||||
if (hasThinkingContent && thinkingStartTime > 0) {
|
||||
const thinkingTime = Date.now() - thinkingStartTime
|
||||
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
const thinkingCompleteChunk = {
|
||||
type: 'reasoning-signature',
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: thinkingTime
|
||||
}
|
||||
|
||||
24
src/renderer/src/aiCore/provider/factory.ts
Normal file
24
src/renderer/src/aiCore/provider/factory.ts
Normal file
@ -0,0 +1,24 @@
|
||||
import { AiCore, ProviderId } from '@cherrystudio/ai-core'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
const PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
anthropic: 'anthropic',
|
||||
gemini: 'google',
|
||||
vertexai: 'google-vertex',
|
||||
'azure-openai': 'azure',
|
||||
'openai-response': 'openai'
|
||||
}
|
||||
|
||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
||||
const providerId = PROVIDER_MAPPING[provider.type]
|
||||
|
||||
if (providerId) {
|
||||
return providerId
|
||||
}
|
||||
|
||||
if (AiCore.isSupported(provider.id)) {
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
|
||||
return 'openai-compatible'
|
||||
}
|
||||
@ -3,7 +3,7 @@
|
||||
* 统一管理从各个 apiClient 提取的参数处理和转换功能
|
||||
*/
|
||||
|
||||
import type { CoreMessage, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { type CoreMessage, type StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
|
||||
@ -14,6 +14,7 @@ import {
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import { Assistant, MCPTool, Model, Provider } from '@renderer/types'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user