feat: add openai-compatible provider and enhance provider configuration

- Introduced the @ai-sdk/openai-compatible package to support OpenAI-compatible providers, used as the fallback when no dedicated AI SDK integration exists.
- Added a new ProviderConfigFactory and ProviderConfigBuilder for streamlined provider configuration.
- Updated the provider registry to load Google Vertex AI from the @ai-sdk/google-vertex/edge import path.
- Enhanced index.ts to export the new provider configuration utilities for better type safety and usability.
- Refactored ApiService and middleware to integrate the new provider configurations effectively.
suyao 2025-06-21 12:48:53 +08:00
parent c99a2fedb7
commit 8ca6341609
11 changed files with 1413 additions and 51 deletions
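
As an overview of the new fallback path, a minimal sketch of building settings for a provider that has no dedicated AI SDK package (the host and key below are placeholders, not values from this commit):

import { ProviderConfigFactory } from '@cherrystudio/ai-core'

// Combine a base URL and API key into settings for the openai-compatible provider
const options = ProviderConfigFactory.createOpenAICompatible('https://my-gateway.example.com/v1', 'sk-placeholder')
// These settings are then handed to the client factory together with the 'openai-compatible' provider id.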

View File

@@ -35,6 +35,7 @@
"@ai-sdk/groq": "^1.2.9",
"@ai-sdk/mistral": "^1.2.8",
"@ai-sdk/openai": "^1.3.22",
"@ai-sdk/openai-compatible": "^0.2.14",
"@ai-sdk/perplexity": "^1.1.9",
"@ai-sdk/replicate": "^0.2.8",
"@ai-sdk/togetherai": "^0.2.14",

File diff suppressed because it is too large.

View File

@@ -116,6 +116,15 @@ export type {
export { createClient as createApiClient, getClientInfo, getSupportedProviders } from './clients/ApiClientFactory'
export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './providers/registry'
// ==================== Provider configuration factory ====================
export {
BaseProviderConfig,
createProviderConfig,
ProviderConfigBuilder,
providerConfigBuilder,
ProviderConfigFactory
} from './providers/factory'
// ==================== Package info ====================
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherrystudio/ai-core'
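
A short sketch of how a consumer of the package might use these new exports; the Azure values are illustrative placeholders:

import { createProviderConfig } from '@cherrystudio/ai-core'

// fromProvider-style helper: base fields plus Azure-specific settings routed through the config handler
const azureSettings = createProviderConfig('azure', {
  apiKey: 'key-placeholder',
  baseURL: 'https://my-resource.openai.azure.com',
  apiVersion: '2024-02-01',
  resourceName: 'my-resource'
})
// -> { apiKey, baseURL, apiVersion, resourceName }, typed as ProviderSettingsMap['azure']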

View File

@@ -0,0 +1,330 @@
/**
* AI Provider configuration factory
* Provides typed builders and helpers for constructing provider settings
*/
import type { ProviderId, ProviderSettingsMap } from './registry'
/**
* Base configuration shared by all providers
*/
export interface BaseProviderConfig {
apiKey?: string
baseURL?: string
timeout?: number
headers?: Record<string, string>
fetch?: typeof globalThis.fetch
}
/**
* Complete provider configuration: base fields plus provider-specific AI SDK settings
*/
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
type ConfigHandler<T extends ProviderId> = (
builder: ProviderConfigBuilder<T>,
provider: CompleteProviderConfig<T>
) => void
const configHandlers: {
[K in ProviderId]?: ConfigHandler<K>
} = {
azure: (builder, provider) => {
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
const azureProvider = provider as CompleteProviderConfig<'azure'>
azureBuilder.withAzureConfig({
apiVersion: azureProvider.apiVersion,
resourceName: azureProvider.resourceName
})
},
'google-vertex': (builder, provider) => {
const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
vertexBuilder.withGoogleVertexConfig({
project: vertexProvider.project,
location: vertexProvider.location
})
}
}
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
constructor(private providerId: T) {}
/**
* Set the API key (with OpenAI-specific options when applicable)
*/
withApiKey(apiKey: string): this
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
withApiKey(apiKey: string, options?: any): this {
this.config.apiKey = apiKey
// Type-safe OpenAI-specific configuration
if (this.providerId === 'openai' && options) {
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
if (options.organization) {
openaiConfig.organization = options.organization
}
if (options.project) {
openaiConfig.project = options.project
}
}
return this
}
/**
* Set the base URL
*/
withBaseURL(baseURL: string) {
this.config.baseURL = baseURL
return this
}
/**
* Set request-level options (headers, custom fetch)
*/
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
if (options.headers) {
this.config.headers = { ...this.config.headers, ...options.headers }
}
if (options.fetch) {
this.config.fetch = options.fetch
}
return this
}
/**
* Azure OpenAI specific configuration
*/
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
withAzureConfig(options: any): any {
if (this.providerId === 'azure') {
const azureConfig = this.config as CompleteProviderConfig<'azure'>
if (options.apiVersion) {
azureConfig.apiVersion = options.apiVersion
}
if (options.resourceName) {
azureConfig.resourceName = options.resourceName
}
}
return this
}
/**
* Google Vertex AI specific configuration
*/
withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
withGoogleVertexConfig(options: any): any {
if (this.providerId === 'google-vertex') {
const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
if (options.project) {
googleConfig.project = options.project
}
if (options.location) {
googleConfig.location = options.location
}
}
return this
}
withGoogleCredentials(credentials: {
clientEmail: string
privateKey: string
}): T extends 'google-vertex' ? this : never
withGoogleCredentials(credentials: any): any {
if (this.providerId === 'google-vertex') {
const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
vertexConfig.googleCredentials = credentials
}
return this
}
/**
* Merge arbitrary custom parameters into the configuration
*/
withCustomParams(params: Record<string, any>) {
Object.assign(this.config, params)
return this
}
/**
* Build the final provider settings object
*/
build(): ProviderSettingsMap[T] {
return this.config as ProviderSettingsMap[T]
}
}
/**
* Provider configuration factory
* Convenience creators for common providers
*/
export class ProviderConfigFactory {
/**
* Create a configuration builder for the given provider
*/
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
return new ProviderConfigBuilder(providerId)
}
/**
* Create a configuration from a Provider object - the recommended usage
*/
static fromProvider<T extends ProviderId>(
providerId: T,
provider: CompleteProviderConfig<T>,
options?: {
headers?: Record<string, string>
[key: string]: any
}
): ProviderSettingsMap[T] {
const builder = new ProviderConfigBuilder<T>(providerId)
// Basic configuration
if (provider.apiKey) {
builder.withApiKey(provider.apiKey)
}
if (provider.baseURL) {
builder.withBaseURL(provider.baseURL)
}
// Request configuration
if (options?.headers) {
builder.withRequestConfig({
headers: options.headers
})
}
// Use the config-handler pattern - more elegant and extensible
const handler = configHandlers[providerId]
if (handler) {
handler(builder, provider)
}
// Apply any remaining custom parameters
if (options) {
const customOptions = { ...options }
delete customOptions.headers // already handled above
if (Object.keys(customOptions).length > 0) {
builder.withCustomParams(customOptions)
}
}
return builder.build()
}
/**
* Create an OpenAI configuration
*/
static createOpenAI(
apiKey: string,
options?: {
baseURL?: string
organization?: string
project?: string
}
) {
const builder = this.builder('openai')
// Use the type-safe overload
if (options?.organization || options?.project) {
builder.withApiKey(apiKey, {
organization: options.organization,
project: options.project
})
} else {
builder.withApiKey(apiKey)
}
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
}
/**
* Create an Anthropic configuration
*/
static createAnthropic(
apiKey: string,
options?: {
baseURL?: string
}
) {
return this.builder('anthropic')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
.build()
}
/**
* Create an Azure OpenAI configuration
*/
static createAzureOpenAI(
apiKey: string,
options: {
baseURL: string
apiVersion?: string
resourceName?: string
deploymentName?: string
}
) {
return this.builder('azure')
.withApiKey(apiKey)
.withBaseURL(options.baseURL)
.withAzureConfig({
apiVersion: options.apiVersion,
resourceName: options.resourceName
})
.build()
}
/**
* Create a Google (Gemini) configuration
*/
static createGoogle(
apiKey: string,
options?: {
baseURL?: string
projectId?: string
location?: string
}
) {
return this.builder('google')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
.build()
}
/**
* Create a Google Vertex AI configuration
*/
static createVertexAI(
credentials: {
clientEmail: string
privateKey: string
},
options?: {
project?: string
location?: string
}
) {
return this.builder('google-vertex')
.withGoogleCredentials(credentials)
.withGoogleVertexConfig({
project: options?.project,
location: options?.location
})
.build()
}
static createOpenAICompatible(baseURL: string, apiKey: string) {
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
}
}
/**
* Convenience exports
*/
export const createProviderConfig = ProviderConfigFactory.fromProvider
export const providerConfigBuilder = ProviderConfigFactory.builder
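
To exercise the builder and the type-safe withApiKey overload directly, a small sketch (keys, credentials, and project ids are placeholders):

import { ProviderConfigFactory } from '@cherrystudio/ai-core'

// OpenAI: organization/project options are only accepted when the provider id is 'openai'
const openaiSettings = ProviderConfigFactory.createOpenAI('sk-placeholder', {
  organization: 'org-placeholder',
  project: 'proj-placeholder'
})

// Google Vertex AI: credentials plus project/location via the dedicated builder methods
const vertexSettings = ProviderConfigFactory.builder('google-vertex')
  .withGoogleCredentials({ clientEmail: 'svc@example.iam.gserviceaccount.com', privateKey: '-----BEGIN PRIVATE KEY-----\n...' })
  .withGoogleVertexConfig({ project: 'my-project', location: 'us-central1' })
  .build()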

View File

@@ -14,7 +14,7 @@ import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type FalProviderSettings } from '@ai-sdk/fal'
import { type FireworksProviderSettings } from '@ai-sdk/fireworks'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex'
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
import { type GroqProviderSettings } from '@ai-sdk/groq'
import { type MistralProviderSettings } from '@ai-sdk/mistral'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
@@ -123,7 +123,7 @@ export class AiProviderRegistry {
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex'),
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true
},
@@ -270,7 +270,6 @@ export class AiProviderRegistry {
}
]
// Register all providers (24 in total)
providers.forEach((config) => {
this.registry.set(config.id, config)
})
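
The registry pairs a lazy import with the name of the creator function to call on the loaded module; the google-vertex entry now points at '@ai-sdk/google-vertex/edge'. A minimal sketch of how such an entry can be resolved (illustrative only; the actual AiProviderRegistry resolution code is outside this hunk):

// Load the module on demand and invoke its named creator with the provider settings
async function resolveProvider(
  entry: { import: () => Promise<Record<string, any>>; creatorFunctionName: string },
  settings: Record<string, unknown>
) {
  const mod = await entry.import() // e.g. import('@ai-sdk/google-vertex/edge')
  const create = mod[entry.creatorFunctionName] // e.g. mod.createVertex
  return create(settings) // the AI SDK provider instance
}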

View File

@@ -10,76 +10,50 @@
import {
AiClient,
AiCore,
createClient,
type OpenAICompatibleProviderSettings,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap,
smoothStream,
StreamTextParams
} from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { getDefaultModel } from '@renderer/services/AssistantService'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
// Import the adapter
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
// Import the original AiProvider as a fallback
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
// Import the parameter transformation module
/**
* Map a Cherry Studio provider type to an AI SDK provider ID
* Based on the providers registered in registry.ts
*/
function mapProviderTypeToAiSdkId(providerType: string): string {
// Cherry Studio provider type -> AI SDK provider ID mapping table
const typeMapping: Record<string, string> = {
// Mappings that need conversion
grok: 'xai', // grok -> xai
'azure-openai': 'azure', // azure-openai -> azure
gemini: 'google', // gemini -> google
vertexai: 'google-vertex' // vertexai -> google-vertex
}
return typeMapping[providerType]
}
import { getAiSdkProviderId } from './provider/factory'
import { getTimeout } from './transformParameters'
/**
* Convert a Provider into an AI SDK provider configuration
*/
function providerToAiSdkConfig(provider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: any
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
console.log('provider', provider)
// 1. Map the provider type to an AI SDK ID
const mappedProviderId = mapProviderTypeToAiSdkId(provider.id)
const aiSdkProviderId = getAiSdkProviderId(provider)
// 2. Check whether the mapped provider ID is in the AI SDK registry
const isSupported = AiCore.isSupported(mappedProviderId)
if (aiSdkProviderId !== 'openai-compatible') {
const defaultModel = getDefaultModel()
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, provider, {
timeout: getTimeout(defaultModel)
})
console.log(`Provider mapping: ${provider.type} -> ${mappedProviderId}, supported: ${isSupported}`)
// 3. Fall back to openai-compatible if the mapped provider is not supported
if (isSupported) {
return {
providerId: mappedProviderId as ProviderId,
options: {
apiKey: provider.apiKey
}
providerId: aiSdkProviderId as ProviderId,
options
}
} else {
console.log(`Using openai-compatible fallback for provider: ${provider.type}`)
const compatibleConfig: OpenAICompatibleProviderSettings = {
name: provider.name || provider.type,
apiKey: provider.apiKey,
baseURL: provider.apiHost
}
const options = ProviderConfigFactory.createOpenAICompatible(provider.apiHost, provider.apiKey)
return {
providerId: 'openai-compatible',
options: compatibleConfig
options
}
}
}
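
For the typed branch, the timeout travels through the factory's custom-parameter path; a sketch with placeholder values (in the code above they come from the Provider object and getTimeout(defaultModel)):

import { ProviderConfigFactory } from '@cherrystudio/ai-core'

const options = ProviderConfigFactory.fromProvider(
  'anthropic',
  { apiKey: 'sk-ant-placeholder', baseURL: 'https://api.anthropic.com' },
  { timeout: 300_000 }
)
// -> { apiKey: 'sk-ant-placeholder', baseURL: 'https://api.anthropic.com', timeout: 300000 }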

View File

@@ -1,4 +1,4 @@
import { AiPlugin, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
import { AiPlugin, extractReasoningMiddleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
@@ -140,7 +140,10 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
// Anthropic-specific middleware
break
case 'openai':
// OpenAI-specific middleware
builder.add({
name: 'thinking-tag-extraction',
aiSdkMiddlewares: [extractReasoningMiddleware({ tagName: 'think', separator: '\n', startWithReasoning: true })]
})
break
case 'gemini':
// Gemini-specific middleware
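
For context, extractReasoningMiddleware pulls <think>…</think> content out of the text stream and surfaces it as reasoning parts. A minimal sketch of the same middleware applied with the upstream AI SDK (assuming the 'ai' package is available alongside @ai-sdk/openai-compatible; the base URL and model id are placeholders):

import { extractReasoningMiddleware, wrapLanguageModel } from 'ai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'

const provider = createOpenAICompatible({ name: 'local', baseURL: 'http://localhost:11434/v1' })

// The wrapped model emits reasoning separately from the answer text
const model = wrapLanguageModel({
  model: provider('deepseek-r1'), // placeholder model id
  middleware: extractReasoningMiddleware({ tagName: 'think' })
})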

View File

@@ -1,5 +1,4 @@
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
import { ChunkType, ThinkingCompleteChunk } from '@renderer/types/chunk'
/**
* AI SDK middleware that measures the LLM "thinking time" (time to first token)
@@ -47,8 +46,8 @@ export default function thinkingTimeMiddleware(): LanguageModelV1Middleware {
// If the stream ends while still reasoning, a complete event must also be emitted
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
const thinkingCompleteChunk = {
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}

View File

@@ -0,0 +1,24 @@
import { AiCore, ProviderId } from '@cherrystudio/ai-core'
import { Provider } from '@renderer/types'
const PROVIDER_MAPPING: Record<string, ProviderId> = {
anthropic: 'anthropic',
gemini: 'google',
vertexai: 'google-vertex',
'azure-openai': 'azure',
'openai-response': 'openai'
}
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
const providerId = PROVIDER_MAPPING[provider.type]
if (providerId) {
return providerId
}
if (AiCore.isSupported(provider.id)) {
return provider.id as ProviderId
}
return 'openai-compatible'
}
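
A few concrete mappings this resolver produces (the stub helper is only for illustration; a real Provider carries more required fields, hence the cast, the import path is assumed, and the fallback case assumes the core registry does not recognize the id):

import { Provider } from '@renderer/types'
import { getAiSdkProviderId } from './factory' // illustrative import path

const stub = (id: string, type: string) => ({ id, type }) as unknown as Provider

getAiSdkProviderId(stub('vertex-1', 'vertexai')) // -> 'google-vertex' (explicit type mapping)
getAiSdkProviderId(stub('azure-1', 'azure-openai')) // -> 'azure'
getAiSdkProviderId(stub('openai', 'openai-response')) // -> 'openai'
getAiSdkProviderId(stub('my-local-llm', 'custom')) // -> 'openai-compatible' (fallback)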

View File

@@ -3,7 +3,7 @@
* Converts request parameters into the format expected by the apiClient
*/
import type { CoreMessage, StreamTextParams } from '@cherrystudio/ai-core'
import { type CoreMessage, type StreamTextParams } from '@cherrystudio/ai-core'
import {
isGenerateImageModel,
isNotSupportTemperatureAndTopP,

View File

@@ -14,6 +14,7 @@ import {
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import { Assistant, MCPTool, Model, Provider } from '@renderer/types'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'