mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package
This commit is contained in:
commit
c769e3aa41
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.10",
|
||||
"version": "1.0.0-alpha.11",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
|
||||
|
||||
@ -28,14 +27,9 @@ export class ModelResolver {
|
||||
): Promise<LanguageModelV2> {
|
||||
let finalProviderId = fallbackProviderId
|
||||
let model: LanguageModelV2
|
||||
|
||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||
if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') {
|
||||
// 检查是否支持 chat 模式且不是只支持 chat 的模型
|
||||
if (!isOpenAIChatCompletionOnlyModel(modelId)) {
|
||||
finalProviderId = 'openai-chat'
|
||||
}
|
||||
// 否则使用默认的 openai (responses 模式)
|
||||
finalProviderId = 'openai-chat'
|
||||
}
|
||||
|
||||
// 检查是否是命名空间格式
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
import { type ProviderId } from '../core/providers/types'
|
||||
|
||||
export function isOpenAIChatCompletionOnlyModel(modelId: string): boolean {
|
||||
if (!modelId) {
|
||||
return false
|
||||
}
|
||||
|
||||
return (
|
||||
modelId.includes('gpt-4o-search-preview') ||
|
||||
modelId.includes('gpt-4o-mini-search-preview') ||
|
||||
modelId.includes('o1-mini') ||
|
||||
modelId.includes('o1-preview')
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(modelId: string): boolean {
|
||||
return modelId.includes('o1') || modelId.includes('o3') || modelId.includes('o4')
|
||||
}
|
||||
|
||||
export function isOpenAILLMModel(modelId: string): boolean {
|
||||
if (modelId.includes('gpt-4o-image')) {
|
||||
return false
|
||||
}
|
||||
if (isOpenAIReasoningModel(modelId)) {
|
||||
return true
|
||||
}
|
||||
if (modelId.includes('gpt')) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export function getModelToProviderId(modelId: string): ProviderId | 'openai-compatible' {
|
||||
const id = modelId.toLowerCase()
|
||||
|
||||
if (id.startsWith('claude')) {
|
||||
return 'anthropic'
|
||||
}
|
||||
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
if (isOpenAILLMModel(modelId)) {
|
||||
return 'openai'
|
||||
}
|
||||
|
||||
return 'openai-compatible'
|
||||
}
|
||||
@ -11,6 +11,7 @@ import { createExecutor, generateImage } from '@cherrystudio/ai-core'
|
||||
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
@ -36,7 +37,7 @@ export default class ModernAiProvider {
|
||||
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
|
||||
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(this.actualProvider)
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, model)
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
@ -52,6 +53,28 @@ export default class ModernAiProvider {
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
) {
|
||||
if (config.topicId && getEnableDeveloperMode()) {
|
||||
// TypeScript类型窄化:确保topicId是string类型
|
||||
const traceConfig = {
|
||||
...config,
|
||||
topicId: config.topicId
|
||||
}
|
||||
return await this._completionsForTrace(modelId, params, traceConfig)
|
||||
} else {
|
||||
return await this._completions(modelId, params, config)
|
||||
}
|
||||
}
|
||||
|
||||
private async _completions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
config: AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
): Promise<CompletionsResult> {
|
||||
// 初始化 provider 到全局管理器
|
||||
try {
|
||||
@ -79,7 +102,7 @@ export default class ModernAiProvider {
|
||||
* 带trace支持的completions方法
|
||||
* 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中
|
||||
*/
|
||||
public async completionsForTrace(
|
||||
private async _completionsForTrace(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
config: AiSdkMiddlewareConfig & {
|
||||
@ -114,7 +137,7 @@ export default class ModernAiProvider {
|
||||
modelId,
|
||||
traceName
|
||||
})
|
||||
return await this.completions(modelId, params, config)
|
||||
return await this._completions(modelId, params, config)
|
||||
}
|
||||
|
||||
try {
|
||||
@ -126,7 +149,7 @@ export default class ModernAiProvider {
|
||||
parentSpanCreated: true
|
||||
})
|
||||
|
||||
const result = await this.completions(modelId, params, config)
|
||||
const result = await this._completions(modelId, params, config)
|
||||
|
||||
logger.info('Completions finished, ending parent span', {
|
||||
spanId: span.spanContext().spanId,
|
||||
@ -172,7 +195,6 @@ export default class ModernAiProvider {
|
||||
params: StreamTextParams,
|
||||
config: AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import {
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap
|
||||
} from '@cherrystudio/ai-core/provider'
|
||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { loggerService } from '@renderer/services/LoggerService'
|
||||
@ -68,7 +69,10 @@ export function getActualProvider(model: Model): Provider {
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
* 简化版:利用新的别名映射系统
|
||||
*/
|
||||
export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
export function providerToAiSdkConfig(
|
||||
actualProvider: Provider,
|
||||
model: Model
|
||||
): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
@ -80,10 +84,9 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
baseURL: actualProvider.apiHost,
|
||||
apiKey: actualProvider.apiKey
|
||||
}
|
||||
|
||||
// 处理OpenAI模式(简化逻辑)
|
||||
const extraOptions: any = {}
|
||||
if (actualProvider.type === 'openai-response') {
|
||||
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
|
||||
extraOptions.mode = 'responses'
|
||||
} else if (aiSdkProviderId === 'openai') {
|
||||
extraOptions.mode = 'chat'
|
||||
|
||||
@ -395,105 +395,3 @@ export async function buildGenerateTextParams(
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, provider, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取外部工具搜索关键词和问题
|
||||
* 从用户消息中提取用于网络搜索和知识库搜索的关键词
|
||||
* @deprecated
|
||||
*/
|
||||
// export async function extractSearchKeywords(
|
||||
// lastUserMessage: Message,
|
||||
// assistant: Assistant,
|
||||
// options: {
|
||||
// shouldWebSearch?: boolean
|
||||
// shouldKnowledgeSearch?: boolean
|
||||
// lastAnswer?: Message
|
||||
// } = {}
|
||||
// ): Promise<ExtractResults | undefined> {
|
||||
// const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer } = options
|
||||
|
||||
// if (!lastUserMessage) return undefined
|
||||
|
||||
// // 根据配置决定是否需要提取
|
||||
// const needWebExtract = shouldWebSearch
|
||||
// const needKnowledgeExtract = shouldKnowledgeSearch
|
||||
|
||||
// if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||
|
||||
// // 选择合适的提示词
|
||||
// let prompt: string
|
||||
// if (needWebExtract && !needKnowledgeExtract) {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
// } else if (!needWebExtract && needKnowledgeExtract) {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||
// } else {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT
|
||||
// }
|
||||
|
||||
// // 构建用于提取的助手配置
|
||||
// const summaryAssistant = getDefaultAssistant()
|
||||
// summaryAssistant.model = assistant.model || getDefaultModel()
|
||||
// summaryAssistant.prompt = prompt
|
||||
|
||||
// try {
|
||||
// const result = await fetchSearchSummary({
|
||||
// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
|
||||
// assistant: summaryAssistant
|
||||
// })
|
||||
|
||||
// if (!result) return getFallbackResult()
|
||||
|
||||
// const extracted = extractInfoFromXML(result.getText())
|
||||
// // 根据需求过滤结果
|
||||
// return {
|
||||
// websearch: needWebExtract ? extracted?.websearch : undefined,
|
||||
// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
|
||||
// }
|
||||
// } catch (e: any) {
|
||||
// console.error('extract error', e)
|
||||
// return getFallbackResult()
|
||||
// }
|
||||
|
||||
// function getFallbackResult(): ExtractResults {
|
||||
// const fallbackContent = getMainTextContent(lastUserMessage)
|
||||
// return {
|
||||
// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||
// knowledge: shouldKnowledgeSearch
|
||||
// ? {
|
||||
// question: [fallbackContent || 'search'],
|
||||
// rewrite: fallbackContent || 'search'
|
||||
// }
|
||||
// : undefined
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* 获取搜索摘要 - 内部辅助函数
|
||||
* @deprecated
|
||||
*/
|
||||
// async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
|
||||
// const model = assistant.model || getDefaultModel()
|
||||
// const provider = getProviderByModel(model)
|
||||
|
||||
// if (!hasApiKey(provider)) {
|
||||
// return null
|
||||
// }
|
||||
|
||||
// const AI = new AiProvider(provider)
|
||||
|
||||
// const params: CompletionsParams = {
|
||||
// callType: 'search',
|
||||
// messages: messages,
|
||||
// assistant,
|
||||
// streamOutput: false
|
||||
// }
|
||||
|
||||
// return await AI.completions(params)
|
||||
// }
|
||||
|
||||
// function hasApiKey(provider: Provider) {
|
||||
// if (!provider) return false
|
||||
// if (provider.id === 'ollama' || provider.id === 'lmstudio' || provider.type === 'vertexai') return true
|
||||
// return !isEmpty(provider.apiKey)
|
||||
// }
|
||||
|
||||
@ -1,4 +1,15 @@
|
||||
import { Assistant, Model, Provider } from '@renderer/types'
|
||||
import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
|
||||
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
||||
import {
|
||||
Assistant,
|
||||
GroqServiceTiers,
|
||||
isGroqServiceTier,
|
||||
isOpenAIServiceTier,
|
||||
Model,
|
||||
OpenAIServiceTiers,
|
||||
Provider,
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { buildGeminiGenerateImageParams } from './image'
|
||||
@ -12,6 +23,35 @@ import {
|
||||
} from './reasoning'
|
||||
import { getWebSearchParams } from './websearch'
|
||||
|
||||
// copy from BaseApiClient.ts
|
||||
const getServiceTier = (model: Model, provider: Provider) => {
|
||||
const serviceTierSetting = provider.serviceTier
|
||||
|
||||
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 处理不同供应商需要 fallback 到默认值的情况
|
||||
if (provider.id === SystemProviderIds.groq) {
|
||||
if (
|
||||
!isGroqServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
} else {
|
||||
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
|
||||
if (
|
||||
!isOpenAIServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
return serviceTierSetting
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
* 按 provider 类型分离,保持类型安全
|
||||
@ -28,6 +68,7 @@ export function buildProviderOptions(
|
||||
}
|
||||
): Record<string, any> {
|
||||
const providerId = getAiSdkProviderId(actualProvider)
|
||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
@ -62,6 +103,7 @@ export function buildProviderOptions(
|
||||
// 合并自定义参数到 provider 特定的选项中
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
serviceTier: serviceTierSetting,
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
|
||||
@ -13,7 +13,9 @@ import {
|
||||
isSupportedThinkingTokenModel,
|
||||
isSupportedThinkingTokenQwenModel
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { SettingsState } from '@renderer/store/settings'
|
||||
import { Assistant, EFFORT_RATIO, Model } from '@renderer/types'
|
||||
import { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
||||
|
||||
@ -205,6 +207,16 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||
const summaryText = openAI?.summaryText || 'off'
|
||||
|
||||
let reasoningSummary: string | undefined = undefined
|
||||
|
||||
if (summaryText === 'off' || model.id.includes('o1-pro')) {
|
||||
reasoningSummary = undefined
|
||||
} else {
|
||||
reasoningSummary = summaryText
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
@ -215,7 +227,8 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
||||
// OpenAI 推理参数
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
reasoningEffort
|
||||
reasoningEffort,
|
||||
reasoningSummary
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -181,6 +181,7 @@ ${t('error.stack')}: ${error.stack || 'N/A'}
|
||||
|
||||
return (
|
||||
<Modal
|
||||
centered
|
||||
title={t('error.detail')}
|
||||
open={open}
|
||||
onCancel={onClose}
|
||||
|
||||
@ -10,7 +10,6 @@ import type { StreamTextParams } from '@renderer/aiCore/types'
|
||||
import { isDedicatedImageGenerationModel, isEmbeddingModel, isQwenMTModel } from '@renderer/config/models'
|
||||
import { LANG_DETECT_PROMPT } from '@renderer/config/prompts'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
||||
@ -138,23 +137,12 @@ export async function fetchChatCompletion({
|
||||
|
||||
// --- Call AI Completions ---
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
const enableDeveloperMode = getEnableDeveloperMode()
|
||||
// 在 AI SDK 调用时设置正确的 OpenTelemetry 上下文
|
||||
if (topicId && enableDeveloperMode) {
|
||||
// 使用带trace支持的completions方法,它会自动创建子span并关联到父span
|
||||
await AI.completionsForTrace(modelId, aiSdkParams, {
|
||||
...middlewareConfig,
|
||||
assistant,
|
||||
topicId,
|
||||
callType: 'chat'
|
||||
})
|
||||
} else {
|
||||
await AI.completions(modelId, aiSdkParams, {
|
||||
...middlewareConfig,
|
||||
assistant,
|
||||
callType: 'chat'
|
||||
})
|
||||
}
|
||||
await AI.completions(modelId, aiSdkParams, {
|
||||
...middlewareConfig,
|
||||
assistant,
|
||||
topicId,
|
||||
callType: 'chat'
|
||||
})
|
||||
}
|
||||
|
||||
interface FetchLanguageDetectionProps {
|
||||
@ -311,7 +299,7 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
|
||||
await appendTrace({ topicId, traceId: messageWithTrace.traceId, model })
|
||||
}
|
||||
|
||||
const { getText } = await AI.completionsForTrace(model.id, llmMessages, {
|
||||
const { getText } = await AI.completions(model.id, llmMessages, {
|
||||
...middlewareConfig,
|
||||
assistant: summaryAssistant,
|
||||
topicId,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user