Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package

This commit is contained in:
suyao 2025-08-29 13:08:30 +08:00
commit c769e3aa41
No known key found for this signature in database
10 changed files with 100 additions and 188 deletions

View File

@@ -1,6 +1,6 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.0-alpha.10",
"version": "1.0.0-alpha.11",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"module": "dist/index.mjs",

View File

@@ -7,7 +7,6 @@
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
@@ -28,14 +27,9 @@ export class ModelResolver {
): Promise<LanguageModelV2> {
let finalProviderId = fallbackProviderId
let model: LanguageModelV2
// 🎯 Handle the OpenAI mode selection logic (migrated from ModelCreator)
if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') {
// Check that chat mode is supported and the model is not chat-completion-only
if (!isOpenAIChatCompletionOnlyModel(modelId)) {
finalProviderId = 'openai-chat'
}
// Otherwise use the default openai (responses mode)
finalProviderId = 'openai-chat'
}
// Check whether the id is in namespace format
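
For reference, the simplified resolution rule reduces to the sketch below (the standalone function name is hypothetical; the provider ids come from the diff above):

function resolveOpenAIProviderId(fallbackProviderId: string, mode?: string): string {
  // 'openai' with mode 'chat' now always resolves to 'openai-chat';
  // the chat-completion-only check moved into providerToAiSdkConfig.
  if (fallbackProviderId === 'openai' && mode === 'chat') {
    return 'openai-chat'
  }
  return fallbackProviderId
}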

View File

@@ -1,49 +0,0 @@
import { type ProviderId } from '../core/providers/types'
export function isOpenAIChatCompletionOnlyModel(modelId: string): boolean {
if (!modelId) {
return false
}
return (
modelId.includes('gpt-4o-search-preview') ||
modelId.includes('gpt-4o-mini-search-preview') ||
modelId.includes('o1-mini') ||
modelId.includes('o1-preview')
)
}
export function isOpenAIReasoningModel(modelId: string): boolean {
return modelId.includes('o1') || modelId.includes('o3') || modelId.includes('o4')
}
export function isOpenAILLMModel(modelId: string): boolean {
if (modelId.includes('gpt-4o-image')) {
return false
}
if (isOpenAIReasoningModel(modelId)) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
export function getModelToProviderId(modelId: string): ProviderId | 'openai-compatible' {
const id = modelId.toLowerCase()
if (id.startsWith('claude')) {
return 'anthropic'
}
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
return 'google'
}
if (isOpenAILLMModel(modelId)) {
return 'openai'
}
return 'openai-compatible'
}
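
For context, the deleted mapper behaved like this on typical ids (the expected values follow directly from the prefix checks above; the calls themselves are illustrative):

getModelToProviderId('claude-3-5-sonnet') // 'anthropic'
getModelToProviderId('gemini-2.0-flash')  // 'google'
getModelToProviderId('gpt-4o')            // 'openai'
getModelToProviderId('qwen-max')          // 'openai-compatible' (fallback)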

View File

@@ -11,6 +11,7 @@ import { createExecutor, generateImage } from '@cherrystudio/ai-core'
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
@@ -36,7 +37,7 @@ export default class ModernAiProvider {
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
// Only store the config; do not create the executor upfront
this.config = providerToAiSdkConfig(this.actualProvider)
this.config = providerToAiSdkConfig(this.actualProvider, model)
}
public getActualProvider() {
@@ -52,6 +53,28 @@
topicId?: string
callType: string
}
) {
if (config.topicId && getEnableDeveloperMode()) {
// TypeScript type narrowing to ensure topicId is a string
const traceConfig = {
...config,
topicId: config.topicId
}
return await this._completionsForTrace(modelId, params, traceConfig)
} else {
return await this._completions(modelId, params, config)
}
}
private async _completions(
modelId: string,
params: StreamTextParams,
config: AiSdkMiddlewareConfig & {
assistant: Assistant
// topicId for tracing
topicId?: string
callType: string
}
): Promise<CompletionsResult> {
// Initialize the provider into the global registry
try {
@@ -79,7 +102,7 @@
* Completions method with trace support
* Like the legacy completionsForTrace, it ensures AI SDK spans land in the correct trace context
*/
public async completionsForTrace(
private async _completionsForTrace(
modelId: string,
params: StreamTextParams,
config: AiSdkMiddlewareConfig & {
@@ -114,7 +137,7 @@
modelId,
traceName
})
return await this.completions(modelId, params, config)
return await this._completions(modelId, params, config)
}
try {
@@ -126,7 +149,7 @@
parentSpanCreated: true
})
const result = await this.completions(modelId, params, config)
const result = await this._completions(modelId, params, config)
logger.info('Completions finished, ending parent span', {
spanId: span.spanContext().spanId,
@@ -172,7 +195,6 @@
params: StreamTextParams,
config: AiSdkMiddlewareConfig & {
assistant: Assistant
// topicId for tracing
topicId?: string
callType: string
}
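
For callers, tracing is now an internal concern of the provider class. A minimal usage sketch, assuming a constructed ModernAiProvider instance and placeholder arguments:

// completions() routes to _completionsForTrace when topicId is set
// and developer mode is enabled; otherwise it calls _completions directly.
await provider.completions(modelId, params, {
  assistant,
  topicId, // optional; enables the traced path in developer mode
  callType: 'chat'
})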

View File

@@ -4,6 +4,7 @@ import {
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { loggerService } from '@renderer/services/LoggerService'
@@ -68,7 +69,10 @@ export function getActualProvider(model: Model): Provider {
* Converts a Provider into an AI SDK provider config
*
*/
export function providerToAiSdkConfig(actualProvider: Provider): {
export function providerToAiSdkConfig(
actualProvider: Provider,
model: Model
): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
@@ -80,10 +84,9 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
}
// Simplified handling of the OpenAI mode selection
const extraOptions: any = {}
if (actualProvider.type === 'openai-response') {
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai') {
extraOptions.mode = 'chat'
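
Condensed, the outcomes of the new mode selection (derived from the branches above; assumes an openai-response provider resolves to the 'openai' aiSdkProviderId, and the chat-completion-only ids come from the deleted util):

// openai-response provider + gpt-4o   -> extraOptions.mode = 'responses'
// openai-response provider + o1-mini  -> extraOptions.mode = 'chat' (chat-completion-only)
// plain 'openai' provider             -> extraOptions.mode = 'chat'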

View File

@@ -395,105 +395,3 @@ export async function buildGenerateTextParams(
// Reuse the stream-parameter building logic
return await buildStreamTextParams(messages, assistant, provider, options)
}
/**
* Extracts web/knowledge search keywords from the last user message
* (implementation kept below, commented out)
* @deprecated
*/
// export async function extractSearchKeywords(
// lastUserMessage: Message,
// assistant: Assistant,
// options: {
// shouldWebSearch?: boolean
// shouldKnowledgeSearch?: boolean
// lastAnswer?: Message
// } = {}
// ): Promise<ExtractResults | undefined> {
// const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer } = options
// if (!lastUserMessage) return undefined
// // Decide whether extraction is needed based on the config
// const needWebExtract = shouldWebSearch
// const needKnowledgeExtract = shouldKnowledgeSearch
// if (!needWebExtract && !needKnowledgeExtract) return undefined
// // Pick the appropriate prompt
// let prompt: string
// if (needWebExtract && !needKnowledgeExtract) {
// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
// } else if (!needWebExtract && needKnowledgeExtract) {
// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
// } else {
// prompt = SEARCH_SUMMARY_PROMPT
// }
// // Build the assistant config used for extraction
// const summaryAssistant = getDefaultAssistant()
// summaryAssistant.model = assistant.model || getDefaultModel()
// summaryAssistant.prompt = prompt
// try {
// const result = await fetchSearchSummary({
// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
// assistant: summaryAssistant
// })
// if (!result) return getFallbackResult()
// const extracted = extractInfoFromXML(result.getText())
// // Filter the results as requested
// return {
// websearch: needWebExtract ? extracted?.websearch : undefined,
// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
// }
// } catch (e: any) {
// console.error('extract error', e)
// return getFallbackResult()
// }
// function getFallbackResult(): ExtractResults {
// const fallbackContent = getMainTextContent(lastUserMessage)
// return {
// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
// knowledge: shouldKnowledgeSearch
// ? {
// question: [fallbackContent || 'search'],
// rewrite: fallbackContent || 'search'
// }
// : undefined
// }
// }
// }
/**
* Fetches a search summary for keyword extraction (implementation kept below, commented out)
* @deprecated
*/
// async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
// const model = assistant.model || getDefaultModel()
// const provider = getProviderByModel(model)
// if (!hasApiKey(provider)) {
// return null
// }
// const AI = new AiProvider(provider)
// const params: CompletionsParams = {
// callType: 'search',
// messages: messages,
// assistant,
// streamOutput: false
// }
// return await AI.completions(params)
// }
// function hasApiKey(provider: Provider) {
// if (!provider) return false
// if (provider.id === 'ollama' || provider.id === 'lmstudio' || provider.type === 'vertexai') return true
// return !isEmpty(provider.apiKey)
// }

View File

@@ -1,4 +1,15 @@
import { Assistant, Model, Provider } from '@renderer/types'
import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import {
Assistant,
GroqServiceTiers,
isGroqServiceTier,
isOpenAIServiceTier,
Model,
OpenAIServiceTiers,
Provider,
SystemProviderIds
} from '@renderer/types'
import { getAiSdkProviderId } from '../provider/factory'
import { buildGeminiGenerateImageParams } from './image'
@@ -12,6 +23,35 @@
} from './reasoning'
import { getWebSearchParams } from './websearch'
// Copied from BaseApiClient.ts
const getServiceTier = (model: Model, provider: Provider) => {
const serviceTierSetting = provider.serviceTier
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
return undefined
}
// Handle providers that need to fall back to the default value
if (provider.id === SystemProviderIds.groq) {
if (
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
} else {
// For other OpenAI providers, assume their service tier settings match OpenAI's exactly
if (
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
}
return serviceTierSetting
}
/**
* Builds the AI SDK providerOptions
* for the resolved provider
@@ -28,6 +68,7 @@
}
): Record<string, any> {
const providerId = getAiSdkProviderId(actualProvider)
const serviceTierSetting = getServiceTier(model, actualProvider)
// Build provider-specific options
let providerSpecificOptions: Record<string, any> = {}
@@ -62,6 +103,7 @@
// Merge custom parameters into the provider-specific options
providerSpecificOptions = {
...providerSpecificOptions,
serviceTier: serviceTierSetting,
...getCustomParameters(assistant)
}
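
The net effect on the merged options, sketched with hypothetical inputs:

// Groq provider with a tier that is invalid for the model:
//   getServiceTier(...) -> undefined, so no usable serviceTier reaches the SDK.
// OpenAI provider with 'flex' on a flex-capable model:
//   providerSpecificOptions = { ...base, serviceTier: 'flex', ...customParams }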

View File

@@ -13,7 +13,9 @@ import {
isSupportedThinkingTokenModel,
isSupportedThinkingTokenQwenModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import { SettingsState } from '@renderer/store/settings'
import { Assistant, EFFORT_RATIO, Model } from '@renderer/types'
import { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
@@ -205,6 +207,16 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
if (!isReasoningModel(model)) {
return {}
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
const summaryText = openAI?.summaryText || 'off'
let reasoningSummary: string | undefined = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) {
reasoningSummary = undefined
} else {
reasoningSummary = summaryText
}
const reasoningEffort = assistant?.settings?.reasoning_effort
@@ -215,7 +227,8 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
// OpenAI reasoning parameters
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
reasoningEffort
reasoningEffort,
reasoningSummary
}
}
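
Equivalently, the new summary gating collapses to a ternary (same behavior as the lines above):

const summaryText = openAI?.summaryText || 'off'
const reasoningSummary =
  summaryText === 'off' || model.id.includes('o1-pro') ? undefined : summaryText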

View File

@@ -181,6 +181,7 @@ ${t('error.stack')}: ${error.stack || 'N/A'}
return (
<Modal
centered
title={t('error.detail')}
open={open}
onCancel={onClose}

View File

@@ -10,7 +10,6 @@ import type { StreamTextParams } from '@renderer/aiCore/types'
import { isDedicatedImageGenerationModel, isEmbeddingModel, isQwenMTModel } from '@renderer/config/models'
import { LANG_DETECT_PROMPT } from '@renderer/config/prompts'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
@@ -138,23 +137,12 @@ export async function fetchChatCompletion({
// --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
const enableDeveloperMode = getEnableDeveloperMode()
// Set the correct OpenTelemetry context for the AI SDK call
if (topicId && enableDeveloperMode) {
// Use the trace-enabled completions method; it automatically creates a child span and attaches it to the parent span
await AI.completionsForTrace(modelId, aiSdkParams, {
...middlewareConfig,
assistant,
topicId,
callType: 'chat'
})
} else {
await AI.completions(modelId, aiSdkParams, {
...middlewareConfig,
assistant,
callType: 'chat'
})
}
await AI.completions(modelId, aiSdkParams, {
...middlewareConfig,
assistant,
topicId,
callType: 'chat'
})
}
interface FetchLanguageDetectionProps {
@@ -311,7 +299,7 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
await appendTrace({ topicId, traceId: messageWithTrace.traceId, model })
}
const { getText } = await AI.completionsForTrace(model.id, llmMessages, {
const { getText } = await AI.completions(model.id, llmMessages, {
...middlewareConfig,
assistant: summaryAssistant,
topicId,