mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 05:39:05 +08:00
526 lines
17 KiB
TypeScript
526 lines
17 KiB
TypeScript
import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
|
|
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
|
|
import type {
|
|
ImageBlockParam,
|
|
MessageCreateParams,
|
|
TextBlockParam,
|
|
Tool as AnthropicTool
|
|
} from '@anthropic-ai/sdk/resources/messages'
|
|
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
|
|
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
|
import { loggerService } from '@logger'
|
|
import anthropicService from '@main/services/AnthropicService'
|
|
import copilotService from '@main/services/CopilotService'
|
|
import { reduxService } from '@main/services/ReduxService'
|
|
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
|
|
import { isGemini3ModelId } from '@shared/middleware'
|
|
import {
|
|
type AiSdkConfig,
|
|
type AiSdkConfigContext,
|
|
formatProviderApiHost,
|
|
initializeSharedProviders,
|
|
type ProviderFormatContext,
|
|
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
|
|
resolveActualProvider
|
|
} from '@shared/provider'
|
|
import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
|
|
import { defaultAppHeaders } from '@shared/utils'
|
|
import type { Provider } from '@types'
|
|
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
|
|
import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
|
|
import { net } from 'electron'
|
|
import type { Response } from 'express'
|
|
|
|
import { reasoningCache } from './cache'
|
|
|
|
const logger = loggerService.withContext('UnifiedMessagesService')
|
|
|
|
const MAGIC_STRING = 'skip_thought_signature_validator'
|
|
|
|
initializeSharedProviders({
|
|
warn: (message) => logger.warn(message),
|
|
error: (message, error) => logger.error(message, error)
|
|
})
|
|
|
|
/**
|
|
* Configuration for unified message streaming
|
|
*/
|
|
export interface UnifiedStreamConfig {
|
|
response: Response
|
|
provider: Provider
|
|
modelId: string
|
|
params: MessageCreateParams
|
|
onError?: (error: unknown) => void
|
|
onComplete?: () => void
|
|
/**
|
|
* Optional AI SDK middlewares to apply
|
|
*/
|
|
middlewares?: LanguageModelV2Middleware[]
|
|
/**
|
|
* Optional AI Core plugins to use with the executor
|
|
*/
|
|
plugins?: AiPlugin[]
|
|
}
|
|
|
|
/**
|
|
* Configuration for non-streaming message generation
|
|
*/
|
|
export interface GenerateUnifiedMessageConfig {
|
|
provider: Provider
|
|
modelId: string
|
|
params: MessageCreateParams
|
|
middlewares?: LanguageModelV2Middleware[]
|
|
plugins?: AiPlugin[]
|
|
}
|
|
|
|
function getMainProcessFormatContext(): ProviderFormatContext {
|
|
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
|
|
return {
|
|
vertex: {
|
|
project: vertexSettings?.projectId || 'default-project',
|
|
location: vertexSettings?.location || 'us-central1'
|
|
}
|
|
}
|
|
}
|
|
|
|
const mainProcessSdkContext: AiSdkConfigContext = {
|
|
getRotatedApiKey: (provider) => {
|
|
const keys = provider.apiKey.split(',').map((k) => k.trim())
|
|
return keys[0] || provider.apiKey
|
|
},
|
|
fetch: net.fetch as typeof globalThis.fetch
|
|
}
|
|
|
|
function getActualProvider(provider: Provider, modelId: string): Provider {
|
|
const model = provider.models?.find((m) => m.id === modelId)
|
|
if (!model) return provider
|
|
return resolveActualProvider(provider, model)
|
|
}
|
|
|
|
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
|
|
const actualProvider = getActualProvider(provider, modelId)
|
|
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
|
|
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
|
|
}
|
|
|
|
function convertAnthropicToolResultToAiSdk(
|
|
content: string | Array<TextBlockParam | ImageBlockParam>
|
|
): LanguageModelV2ToolResultOutput {
|
|
if (typeof content === 'string') {
|
|
return { type: 'text', value: content }
|
|
}
|
|
const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
|
|
for (const block of content) {
|
|
if (block.type === 'text') {
|
|
values.push({ type: 'text', text: block.text })
|
|
} else if (block.type === 'image') {
|
|
values.push({
|
|
type: 'media',
|
|
data: block.source.type === 'base64' ? block.source.data : block.source.url,
|
|
mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
|
|
})
|
|
}
|
|
}
|
|
return { type: 'content', value: values }
|
|
}
|
|
|
|
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
|
|
if (!tools || tools.length === 0) return undefined
|
|
|
|
const aiSdkTools: Record<string, Tool> = {}
|
|
for (const anthropicTool of tools) {
|
|
if (anthropicTool.type === 'bash_20250124') continue
|
|
const toolDef = anthropicTool as AnthropicTool
|
|
const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]
|
|
aiSdkTools[toolDef.name] = tool({
|
|
description: toolDef.description || '',
|
|
inputSchema: jsonSchema(parameters),
|
|
execute: async (input: Record<string, unknown>) => input
|
|
})
|
|
}
|
|
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
|
|
}
|
|
|
|
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
|
|
const messages: ModelMessage[] = []
|
|
|
|
// System message
|
|
if (params.system) {
|
|
if (typeof params.system === 'string') {
|
|
messages.push({ role: 'system', content: params.system })
|
|
} else if (Array.isArray(params.system)) {
|
|
const systemText = params.system
|
|
.filter((block) => block.type === 'text')
|
|
.map((block) => block.text)
|
|
.join('\n')
|
|
if (systemText) {
|
|
messages.push({ role: 'system', content: systemText })
|
|
}
|
|
}
|
|
}
|
|
|
|
const toolCallIdToName = new Map<string, string>()
|
|
for (const msg of params.messages) {
|
|
if (Array.isArray(msg.content)) {
|
|
for (const block of msg.content) {
|
|
if (block.type === 'tool_use') {
|
|
toolCallIdToName.set(block.id, block.name)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// User/assistant messages
|
|
for (const msg of params.messages) {
|
|
if (typeof msg.content === 'string') {
|
|
messages.push({
|
|
role: msg.role === 'user' ? 'user' : 'assistant',
|
|
content: msg.content
|
|
})
|
|
} else if (Array.isArray(msg.content)) {
|
|
const textParts: TextPart[] = []
|
|
const imageParts: ImagePart[] = []
|
|
const reasoningParts: ReasoningPart[] = []
|
|
const toolCallParts: ToolCallPart[] = []
|
|
const toolResultParts: ToolResultPart[] = []
|
|
|
|
for (const block of msg.content) {
|
|
if (block.type === 'text') {
|
|
textParts.push({ type: 'text', text: block.text })
|
|
} else if (block.type === 'thinking') {
|
|
reasoningParts.push({ type: 'reasoning', text: block.thinking })
|
|
} else if (block.type === 'redacted_thinking') {
|
|
reasoningParts.push({ type: 'reasoning', text: block.data })
|
|
} else if (block.type === 'image') {
|
|
const source = block.source
|
|
if (source.type === 'base64') {
|
|
imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
|
|
} else if (source.type === 'url') {
|
|
imageParts.push({ type: 'image', image: source.url })
|
|
}
|
|
} else if (block.type === 'tool_use') {
|
|
toolCallParts.push({
|
|
type: 'tool-call',
|
|
toolName: block.name,
|
|
toolCallId: block.id,
|
|
input: block.input
|
|
})
|
|
} else if (block.type === 'tool_result') {
|
|
// Look up toolName from the pre-built map (covers cross-message references)
|
|
const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
|
|
toolResultParts.push({
|
|
type: 'tool-result',
|
|
toolCallId: block.tool_use_id,
|
|
toolName,
|
|
output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
|
|
})
|
|
}
|
|
}
|
|
|
|
if (toolResultParts.length > 0) {
|
|
messages.push({ role: 'tool', content: [...toolResultParts] })
|
|
}
|
|
|
|
if (msg.role === 'user') {
|
|
const userContent = [...textParts, ...imageParts]
|
|
if (userContent.length > 0) {
|
|
messages.push({ role: 'user', content: userContent })
|
|
}
|
|
} else {
|
|
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
|
|
if (assistantContent.length > 0) {
|
|
let providerOptions: ProviderOptions | undefined = undefined
|
|
if (reasoningCache.get('openrouter')) {
|
|
providerOptions = {
|
|
openrouter: {
|
|
reasoning_details: (reasoningCache.get('openrouter') as JSONValue[]) || []
|
|
}
|
|
}
|
|
} else if (isGemini3ModelId(params.model)) {
|
|
providerOptions = {
|
|
google: {
|
|
thoughtSignature: MAGIC_STRING
|
|
}
|
|
}
|
|
}
|
|
messages.push({ role: 'assistant', content: assistantContent, providerOptions })
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
interface ExecuteStreamConfig {
|
|
provider: Provider
|
|
modelId: string
|
|
params: MessageCreateParams
|
|
middlewares?: LanguageModelV2Middleware[]
|
|
plugins?: AiPlugin[]
|
|
onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
|
|
}
|
|
|
|
/**
|
|
* Create AI SDK provider instance from config
|
|
* Similar to renderer's createAiSdkProvider
|
|
*/
|
|
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
|
|
let providerId = config.providerId
|
|
|
|
// Handle special provider modes (same as renderer)
|
|
if (providerId === 'openai' && config.options?.mode === 'chat') {
|
|
providerId = 'openai-chat'
|
|
} else if (providerId === 'azure' && config.options?.mode === 'responses') {
|
|
providerId = 'azure-responses'
|
|
} else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
|
|
providerId = 'cherryin-chat'
|
|
}
|
|
|
|
const provider = await createProviderCore(providerId, config.options)
|
|
|
|
logger.debug('AI SDK provider created', {
|
|
providerId,
|
|
hasOptions: !!config.options
|
|
})
|
|
|
|
return provider
|
|
}
|
|
|
|
/**
|
|
* Prepare special provider configuration for providers that need dynamic tokens
|
|
* Similar to renderer's prepareSpecialProviderConfig
|
|
*/
|
|
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
|
|
switch (provider.id) {
|
|
case 'copilot': {
|
|
const storedHeaders =
|
|
((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
|
|
const headers: Record<string, string> = {
|
|
...COPILOT_DEFAULT_HEADERS,
|
|
...storedHeaders
|
|
}
|
|
|
|
try {
|
|
const { token } = await copilotService.getToken(null as any, headers)
|
|
config.options.apiKey = token
|
|
const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
|
|
config.options.headers = {
|
|
...headers,
|
|
...existingHeaders
|
|
}
|
|
logger.debug('Copilot token retrieved successfully')
|
|
} catch (error) {
|
|
logger.error('Failed to get Copilot token', error as Error)
|
|
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
|
|
}
|
|
break
|
|
}
|
|
case 'anthropic': {
|
|
if (provider.authType === 'oauth') {
|
|
try {
|
|
const oauthToken = await anthropicService.getValidAccessToken()
|
|
if (!oauthToken) {
|
|
throw new Error('Anthropic OAuth token not available. Please re-authorize.')
|
|
}
|
|
config.options = {
|
|
...config.options,
|
|
headers: {
|
|
...(config.options.headers ? config.options.headers : {}),
|
|
'Content-Type': 'application/json',
|
|
'anthropic-version': '2023-06-01',
|
|
'anthropic-beta': 'oauth-2025-04-20',
|
|
Authorization: `Bearer ${oauthToken}`
|
|
},
|
|
baseURL: 'https://api.anthropic.com/v1',
|
|
apiKey: ''
|
|
}
|
|
logger.debug('Anthropic OAuth token retrieved successfully')
|
|
} catch (error) {
|
|
logger.error('Failed to get Anthropic OAuth token', error as Error)
|
|
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
|
|
}
|
|
}
|
|
break
|
|
}
|
|
// Note: cherryai requires request-level signing which is not easily supported here
|
|
// It would need custom fetch implementation similar to renderer
|
|
}
|
|
return config
|
|
}
|
|
|
|
/**
|
|
* Core stream execution function - single source of truth for AI SDK calls
|
|
*/
|
|
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
|
|
const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
|
|
|
|
// Convert provider config to AI SDK config
|
|
let sdkConfig = providerToAiSdkConfig(provider, modelId)
|
|
|
|
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
|
|
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
|
|
|
|
logger.debug('Created AI SDK config', {
|
|
providerId: sdkConfig.providerId,
|
|
hasOptions: !!sdkConfig.options,
|
|
message: params.messages
|
|
})
|
|
|
|
// Create provider instance and get language model
|
|
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
|
|
const baseModel = aiSdkProvider.languageModel(modelId)
|
|
|
|
// Apply middlewares if present
|
|
const model =
|
|
middlewares.length > 0 && typeof baseModel === 'object'
|
|
? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
|
|
: baseModel
|
|
|
|
// Create executor with plugins
|
|
const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
|
|
|
|
// Convert messages and tools
|
|
const coreMessages = convertAnthropicToAiMessages(params)
|
|
const tools = convertAnthropicToolsToAiSdk(params.tools)
|
|
|
|
// Create the adapter
|
|
const adapter = new AiSdkToAnthropicSSE({
|
|
model: `${provider.id}:${modelId}`,
|
|
onEvent: onEvent || (() => {})
|
|
})
|
|
|
|
// Execute stream - pass model object instead of string
|
|
const result = await executor.streamText({
|
|
model, // Now passing LanguageModel object, not string
|
|
messages: coreMessages,
|
|
maxOutputTokens: params.max_tokens,
|
|
temperature: params.temperature,
|
|
topP: params.top_p,
|
|
stopSequences: params.stop_sequences,
|
|
stopWhen: stepCountIs(100),
|
|
headers: defaultAppHeaders(),
|
|
tools,
|
|
providerOptions: {}
|
|
})
|
|
|
|
// Process the stream through the adapter
|
|
await adapter.processStream(result.fullStream)
|
|
|
|
return adapter
|
|
}
|
|
|
|
/**
|
|
* Stream a message request using AI SDK executor and convert to Anthropic SSE format
|
|
*/
|
|
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
|
|
const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
|
|
|
|
logger.info('Starting unified message stream', {
|
|
providerId: provider.id,
|
|
providerType: provider.type,
|
|
modelId,
|
|
stream: params.stream,
|
|
middlewareCount: middlewares.length,
|
|
pluginCount: plugins.length
|
|
})
|
|
|
|
try {
|
|
response.setHeader('Content-Type', 'text/event-stream')
|
|
response.setHeader('Cache-Control', 'no-cache')
|
|
response.setHeader('Connection', 'keep-alive')
|
|
response.setHeader('X-Accel-Buffering', 'no')
|
|
|
|
await executeStream({
|
|
provider,
|
|
modelId,
|
|
params,
|
|
middlewares,
|
|
plugins,
|
|
onEvent: (event) => {
|
|
logger.silly('Streaming event', { eventType: event.type })
|
|
const sseData = formatSSEEvent(event)
|
|
response.write(sseData)
|
|
}
|
|
})
|
|
|
|
// Send done marker
|
|
response.write(formatSSEDone())
|
|
response.end()
|
|
|
|
logger.info('Unified message stream completed', { providerId: provider.id, modelId })
|
|
onComplete?.()
|
|
} catch (error) {
|
|
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
|
|
onError?.(error)
|
|
throw error
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Generate a non-streaming message response
|
|
*
|
|
* Uses simulateStreamingMiddleware to reuse the same streaming logic,
|
|
* similar to renderer's ModernAiProvider pattern.
|
|
*/
|
|
export async function generateUnifiedMessage(
|
|
providerOrConfig: Provider | GenerateUnifiedMessageConfig,
|
|
modelId?: string,
|
|
params?: MessageCreateParams
|
|
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
|
|
// Support both old signature and new config-based signature
|
|
let config: GenerateUnifiedMessageConfig
|
|
if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
|
|
config = providerOrConfig
|
|
} else {
|
|
config = {
|
|
provider: providerOrConfig as Provider,
|
|
modelId: modelId!,
|
|
params: params!
|
|
}
|
|
}
|
|
|
|
const { provider, middlewares = [], plugins = [] } = config
|
|
|
|
logger.info('Starting unified message generation', {
|
|
providerId: provider.id,
|
|
providerType: provider.type,
|
|
modelId: config.modelId,
|
|
middlewareCount: middlewares.length,
|
|
pluginCount: plugins.length
|
|
})
|
|
|
|
try {
|
|
// Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
|
|
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
|
|
|
|
const adapter = await executeStream({
|
|
provider,
|
|
modelId: config.modelId,
|
|
params: config.params,
|
|
middlewares: allMiddlewares,
|
|
plugins
|
|
})
|
|
|
|
const finalResponse = adapter.buildNonStreamingResponse()
|
|
|
|
logger.info('Unified message generation completed', {
|
|
providerId: provider.id,
|
|
modelId: config.modelId
|
|
})
|
|
|
|
return finalResponse
|
|
} catch (error) {
|
|
logger.error('Error in unified message generation', error as Error, {
|
|
providerId: provider.id,
|
|
modelId: config.modelId
|
|
})
|
|
throw error
|
|
}
|
|
}
|
|
|
|
export default {
|
|
streamUnifiedMessages,
|
|
generateUnifiedMessage
|
|
}
|