feat: add shared AI SDK middlewares and refactor middleware handling
parent 77c1b77113
commit ce25001590

packages/shared/middleware/index.ts (new file, 15 lines)
@@ -0,0 +1,15 @@
/**
 * Shared AI SDK Middlewares
 *
 * Environment-agnostic middlewares that can be used in both
 * the renderer process and the main process (API server).
 */

export {
  buildSharedMiddlewares,
  getReasoningTagName,
  isGemini3ModelId,
  openrouterReasoningMiddleware,
  type SharedMiddlewareConfig,
  skipGeminiThoughtSignatureMiddleware
} from './middlewares'
packages/shared/middleware/middlewares.ts (new file, 205 lines)

@@ -0,0 +1,205 @@
/**
 * Shared AI SDK Middlewares
 *
 * These middlewares are environment-agnostic and can be used in both
 * the renderer process and the main process (API server).
 */
import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
import { extractReasoningMiddleware } from 'ai'

/**
 * Configuration for building shared middlewares
 */
export interface SharedMiddlewareConfig {
  /**
   * Whether to enable reasoning extraction
   */
  enableReasoning?: boolean

  /**
   * Tag name for reasoning extraction.
   * Defaults are derived from the model ID.
   */
  reasoningTagName?: string

  /**
   * Model ID, used to determine the default reasoning tag and for model detection
   */
  modelId?: string

  /**
   * Provider ID (Cherry Studio provider ID).
   * Used for provider-specific middlewares such as OpenRouter.
   */
  providerId?: string

  /**
   * AI SDK provider ID.
   * Used for the Gemini thought-signature middleware.
   * e.g. 'google', 'google-vertex'
   */
  aiSdkProviderId?: string
}

/**
 * Check whether a model ID represents a Gemini 3 (or 2.5) model
 * that requires thought-signature handling
 *
 * @param modelId - The model ID string (not a Model object)
 */
export function isGemini3ModelId(modelId?: string): boolean {
  if (!modelId) return false
  const lowerModelId = modelId.toLowerCase()
  return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3')
}
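
// Illustrative matches (annotation, not part of the commit), following directly
// from the substring checks above:
//   isGemini3ModelId('gemini-2.5-flash')  // true
//   isGemini3ModelId('gemini-3-pro')      // true
//   isGemini3ModelId('gemini-1.5-pro')    // false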

/**
 * Get the default reasoning tag name based on model ID
 *
 * Different models use different tags for reasoning content:
 * - Most models: 'think'
 * - GPT-OSS models: 'reasoning'
 * - Gemini models: 'thought'
 * - Seed models: 'seed:think'
 */
export function getReasoningTagName(modelId?: string): string {
  if (!modelId) return 'think'
  const lowerModelId = modelId.toLowerCase()
  if (lowerModelId.includes('gpt-oss')) return 'reasoning'
  if (lowerModelId.includes('gemini')) return 'thought'
  if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
  return 'think'
}
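
// Illustrative resolution (annotation, not part of the commit):
//   getReasoningTagName('gpt-oss-120b')    // 'reasoning'
//   getReasoningTagName('gemini-2.5-pro')  // 'thought'
//   getReasoningTagName('seed-oss-36b')    // 'seed:think'
//   getReasoningTagName('deepseek-r1')     // 'think' (default)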

/**
 * Skip Gemini Thought Signature Middleware
 *
 * Because multi-model client requests can switch to another model
 * mid-conversation, this middleware skips thought-signature validation
 * for all Gemini 3 requests.
 *
 * @param aiSdkId - AI SDK provider ID (e.g. 'google', 'google-vertex')
 * @returns LanguageModelV2Middleware
 */
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
  const MAGIC_STRING = 'skip_thought_signature_validator'
  return {
    middlewareVersion: 'v2',

    transformParams: async ({ params }) => {
      const transformedParams = { ...params }
      // Process messages in the prompt
      if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
        transformedParams.prompt = transformedParams.prompt.map((message) => {
          if (typeof message.content !== 'string') {
            for (const part of message.content) {
              const googleOptions = part?.providerOptions?.[aiSdkId]
              if (googleOptions?.thoughtSignature) {
                googleOptions.thoughtSignature = MAGIC_STRING
              }
            }
          }
          return message
        })
      }

      return transformedParams
    }
  }
}
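
// Illustrative effect (annotation, not part of the commit): a prompt part carrying
//   providerOptions: { google: { thoughtSignature: '<opaque signature>' } }
// leaves transformParams as
//   providerOptions: { google: { thoughtSignature: 'skip_thought_signature_validator' } }
// so that conversation history produced by a different model is not rejected
// by signature validation.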

/**
 * OpenRouter Reasoning Middleware
 *
 * Filters out [REDACTED] blocks from OpenRouter reasoning responses.
 * OpenRouter may include [REDACTED] markers in reasoning content that
 * should be removed for cleaner output.
 *
 * @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
 * @returns LanguageModelV2Middleware
 */
export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
  const REDACTED_BLOCK = '[REDACTED]'
  return {
    middlewareVersion: 'v2',
    wrapGenerate: async ({ doGenerate }) => {
      const { content, ...rest } = await doGenerate()
      const modifiedContent = content.map((part) => {
        if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
          return {
            ...part,
            text: part.text.replace(REDACTED_BLOCK, '')
          }
        }
        return part
      })
      return { content: modifiedContent, ...rest }
    },
    wrapStream: async ({ doStream }) => {
      const { stream, ...rest } = await doStream()
      return {
        stream: stream.pipeThrough(
          new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
            transform(
              chunk: LanguageModelV2StreamPart,
              controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
            ) {
              if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
                controller.enqueue({
                  ...chunk,
                  delta: chunk.delta.replace(REDACTED_BLOCK, '')
                })
              } else {
                controller.enqueue(chunk)
              }
            }
          })
        ),
        ...rest
      }
    }
  }
}
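
// Illustrative stream behavior (annotation, not part of the commit):
//   { type: 'reasoning-delta', delta: 'step one [REDACTED] step two' }
// is re-emitted as
//   { type: 'reasoning-delta', delta: 'step one  step two' }
// Note: String.prototype.replace with a string pattern removes only the first
// occurrence of [REDACTED] within a single chunk.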

/**
 * Build shared middlewares based on configuration
 *
 * This function builds the set of middlewares that is commonly needed
 * across different environments (renderer, API server).
 *
 * @param config - Configuration for middleware building
 * @returns Array of AI SDK middlewares
 *
 * @example
 * ```typescript
 * import { buildSharedMiddlewares } from '@shared/middleware'
 *
 * const middlewares = buildSharedMiddlewares({
 *   enableReasoning: true,
 *   modelId: 'gemini-2.5-pro',
 *   providerId: 'openrouter',
 *   aiSdkProviderId: 'google'
 * })
 * ```
 */
export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
  const middlewares: LanguageModelV2Middleware[] = []

  // 1. Reasoning extraction middleware
  if (config.enableReasoning) {
    const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
    middlewares.push(extractReasoningMiddleware({ tagName }))
  }

  // 2. OpenRouter-specific: filter [REDACTED] blocks
  if (config.providerId === 'openrouter' && config.enableReasoning) {
    middlewares.push(openrouterReasoningMiddleware())
  }

  // 3. Gemini 3 (2.5) specific: skip thought-signature validation
  if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
    middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
  }

  return middlewares
}
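
A minimal application sketch (not part of this commit): assuming the AI SDK's `wrapLanguageModel` helper accepts a middleware array, the returned list can be applied to a provider-created model as below; `baseModel` is a hypothetical model instance.

```typescript
import { wrapLanguageModel } from 'ai'

import { buildSharedMiddlewares } from '@shared/middleware'

// `baseModel` is assumed to come from a provider factory,
// e.g. aiSdkProvider.languageModel('gemini-2.5-pro')
const model = wrapLanguageModel({
  model: baseModel,
  middleware: buildSharedMiddlewares({
    enableReasoning: true,
    modelId: 'gemini-2.5-pro',
    providerId: 'openrouter',
    aiSdkProviderId: 'google'
  })
})
```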

@@ -1,4 +1,4 @@
import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
  ImageBlockParam,
@@ -6,7 +6,7 @@ import type {
  TextBlockParam,
  Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { reduxService } from '@main/services/ReduxService'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
@@ -21,8 +21,8 @@ import {
} from '@shared/provider'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
import { jsonSchema, stepCountIs, streamText, tool } from 'ai'
import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai'
import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'

@@ -33,6 +33,9 @@ initializeSharedProviders({
  error: (message, error) => logger.error(message, error)
})

/**
 * Configuration for unified message streaming
 */
export interface UnifiedStreamConfig {
  response: Response
  provider: Provider
@@ -40,12 +43,31 @@ export interface UnifiedStreamConfig {
  params: MessageCreateParams
  onError?: (error: unknown) => void
  onComplete?: () => void
  /**
   * Optional AI SDK middlewares to apply
   */
  middlewares?: LanguageModelV2Middleware[]
  /**
   * Optional AI Core plugins to use with the executor
   */
  plugins?: AiPlugin[]
}

/**
 * Main process format context for formatProviderApiHost
 * Unlike the renderer, the main process has no direct access to store getters, so it uses the reduxService cache
 * Configuration for non-streaming message generation
 */
export interface GenerateUnifiedMessageConfig {
  provider: Provider
  modelId: string
  params: MessageCreateParams
  middlewares?: LanguageModelV2Middleware[]
  plugins?: AiPlugin[]
}

// ============================================================================
// Internal Utilities
// ============================================================================

function getMainProcessFormatContext(): ProviderFormatContext {
  const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
  return {
@@ -56,12 +78,7 @@ function getMainProcessFormatContext(): ProviderFormatContext {
  }
}

/**
 * Main process context for providerToAiSdkConfig
 * The main process doesn't have access to browser APIs like window.keyv
 */
const mainProcessSdkContext: AiSdkConfigContext = {
  // Simple key rotation - just return the first key (no persistent rotation in the main process)
  getRotatedApiKey: (provider) => {
    const keys = provider.apiKey.split(',').map((k) => k.trim())
    return keys[0] || provider.apiKey
@@ -69,199 +86,82 @@ const mainProcessSdkContext: AiSdkConfigContext = {
  fetch: net.fetch as typeof globalThis.fetch
}

/**
 * Get actual provider configuration for a model
 *
 * For aggregated providers (new-api, aihubmix, vertexai, azure-openai),
 * this resolves the actual provider type based on the model's characteristics.
 */
function getActualProvider(provider: Provider, modelId: string): Provider {
  // Find the model in provider's models list
  const model = provider.models?.find((m) => m.id === modelId)
  if (!model) {
    // If model not found, return provider as-is
    return provider
  }

  // Resolve actual provider based on model
  if (!model) return provider
  return resolveActualProvider(provider, model)
}
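
// Annotation (not part of the commit): for an aggregated provider such as
// 'new-api', resolveActualProvider is expected to derive the underlying
// provider type from the model's characteristics; if the model is not in the
// provider's list, the provider is returned unchanged.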

/**
 * Convert Cherry Studio Provider to AI SDK config
 * Uses shared implementation with main process context
 */
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
  // First resolve actual provider for aggregated providers
  const actualProvider = getActualProvider(provider, modelId)

  // Format the provider's apiHost for AI SDK
  const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())

  // Use shared implementation
  return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}

/**
 * Create an AI SDK provider from Cherry Studio provider configuration
 */
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
  try {
    const provider = await createProviderCore(config.providerId, config.options)
    logger.debug('AI SDK provider created', {
      providerId: config.providerId,
      hasOptions: !!config.options
    })
    return provider
  } catch (error) {
    logger.error('Failed to create AI SDK provider', error as Error, {
      providerId: config.providerId
    })
    throw error
  }
}

/**
 * Create an AI SDK language model from a Cherry Studio provider configuration
 * Uses shared provider utilities for consistent behavior with renderer
 */
async function createLanguageModel(provider: Provider, modelId: string): Promise<LanguageModel> {
  logger.debug('Creating language model', {
    providerId: provider.id,
    providerType: provider.type,
    modelId,
    apiHost: provider.apiHost
  })

  // Convert provider config to AI SDK config
  const config = providerToAiSdkConfig(provider, modelId)

  // Create the AI SDK provider
  const aiSdkProvider = await createAiSdkProvider(config)
  if (!aiSdkProvider) {
    throw new Error(`Failed to create AI SDK provider for ${provider.id}`)
  }

  // Get the language model
  return aiSdkProvider.languageModel(modelId)
}

function convertAnthropicToolResultToAiSdk(
  content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
  if (typeof content === 'string') {
    return {
      type: 'text',
      value: content
    }
  } else {
    const values: Array<
      | { type: 'text'; text: string }
      | {
          type: 'media'
          /**
           Base-64 encoded media data.
           */
          data: string
          /**
           IANA media type.
           @see https://www.iana.org/assignments/media-types/media-types.xhtml
           */
          mediaType: string
        }
    > = []
    for (const block of content) {
      if (block.type === 'text') {
        values.push({
          type: 'text',
          text: block.text
        })
      } else if (block.type === 'image') {
        values.push({
          type: 'media',
          data: block.source.type === 'base64' ? block.source.data : block.source.url,
          mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
        })
      }
    }
    return {
      type: 'content',
      value: values
    return { type: 'text', value: content }
  }
  const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
  for (const block of content) {
    if (block.type === 'text') {
      values.push({ type: 'text', text: block.text })
    } else if (block.type === 'image') {
      values.push({
        type: 'media',
        data: block.source.type === 'base64' ? block.source.data : block.source.url,
        mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
      })
    }
  }
  return { type: 'content', value: values }
}
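
// Illustrative conversions (annotation, not part of the commit):
//   convertAnthropicToolResultToAiSdk('done')
//     // => { type: 'text', value: 'done' }
//   convertAnthropicToolResultToAiSdk([{ type: 'text', text: 'hi' }])
//     // => { type: 'content', value: [{ type: 'text', text: 'hi' }] }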

/**
 * Convert Anthropic tools format to AI SDK tools format
 */
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
  if (!tools || tools.length === 0) {
    return undefined
  }
  if (!tools || tools.length === 0) return undefined

  const aiSdkTools: Record<string, Tool> = {}

  for (const anthropicTool of tools) {
    // Handle different tool types
    if (anthropicTool.type === 'bash_20250124') {
      // Skip computer use and bash tools - these are Anthropic-specific
      continue
    }

    // Regular tool (type === 'custom' or no type)
    if (anthropicTool.type === 'bash_20250124') continue
    const toolDef = anthropicTool as AnthropicTool
    const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]

    aiSdkTools[toolDef.name] = tool({
      description: toolDef.description || '',
      inputSchema: jsonSchema(parameters),
      execute: async (input: Record<string, unknown>) => {
        return input
      }
      execute: async (input: Record<string, unknown>) => input
    })
  }

  return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
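
// Illustrative conversion (annotation, not part of the commit): an Anthropic tool
//   { name: 'get_weather', description: 'Get the weather', input_schema: { type: 'object', properties: {...} } }
// becomes an AI SDK tool keyed by name, with input_schema wrapped via jsonSchema()
// and an execute that simply echoes its input back.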

/**
 * Convert Anthropic MessageCreateParams to AI SDK message format
 */
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
  const messages: ModelMessage[] = []

  // Add system message if present
  // System message
  if (params.system) {
    if (typeof params.system === 'string') {
      messages.push({
        role: 'system',
        content: params.system
      })
      messages.push({ role: 'system', content: params.system })
    } else if (Array.isArray(params.system)) {
      // Handle TextBlockParam array
      const systemText = params.system
        .filter((block) => block.type === 'text')
        .map((block) => block.text)
        .join('\n')
      if (systemText) {
        messages.push({
          role: 'system',
          content: systemText
        })
        messages.push({ role: 'system', content: systemText })
      }
    }
  }

  // Convert user/assistant messages
  // User/assistant messages
  for (const msg of params.messages) {
    if (typeof msg.content === 'string') {
      if (msg.role === 'user') {
        messages.push({ role: 'user', content: msg.content })
      } else {
        messages.push({ role: 'assistant', content: msg.content })
      }
      messages.push({
        role: msg.role === 'user' ? 'user' : 'assistant',
        content: msg.content
      })
    } else if (Array.isArray(msg.content)) {
      // Handle content blocks
      const textParts: TextPart[] = []
      const imageParts: ImagePart[] = []
      const reasoningParts: ReasoningPart[] = []
@@ -278,15 +178,9 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
        } else if (block.type === 'image') {
          const source = block.source
          if (source.type === 'base64') {
            imageParts.push({
              type: 'image',
              image: `data:${source.media_type};base64,${source.data}`
            })
            imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
          } else if (source.type === 'url') {
            imageParts.push({
              type: 'image',
              image: source.url
            })
            imageParts.push({ type: 'image', image: source.url })
          }
        } else if (block.type === 'tool_use') {
          toolCallParts.push({
@@ -306,30 +200,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
      }

      if (toolResultParts.length > 0) {
        messages.push({
          role: 'tool',
          content: [...toolResultParts]
        })
        messages.push({ role: 'tool', content: [...toolResultParts] })
      }

      // Build the message based on role
      // Only push user/assistant message if there's actual content (avoid empty messages)
      if (msg.role === 'user') {
        const userContent = [...textParts, ...imageParts]
        if (userContent.length > 0) {
          messages.push({
            role: 'user',
            content: userContent
          })
          messages.push({ role: 'user', content: userContent })
        }
      } else {
        // Assistant messages contain tool calls, not tool results
        const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
        if (assistantContent.length > 0) {
          messages.push({
            role: 'assistant',
            content: assistantContent
          })
          messages.push({ role: 'assistant', content: assistantContent })
        }
      }
    }
@@ -338,67 +220,54 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
  return messages
}
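
// Illustrative mapping (annotation, not part of the commit): a message like
//   { role: 'user', content: [{ type: 'tool_result', ... }, { type: 'text', text: 'hi' }] }
// yields two AI SDK messages: first a 'tool' message carrying the tool results,
// then a 'user' message carrying the text/image parts; empty user/assistant
// messages are dropped.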

/**
 * Stream a message request using AI SDK and convert to Anthropic SSE format
 */
// TODO: integrate middlewares and TransformStream via the ai-core executor
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
  const { response, provider, modelId, params, onError, onComplete } = config
interface ExecuteStreamConfig {
  provider: Provider
  modelId: string
  params: MessageCreateParams
  middlewares?: LanguageModelV2Middleware[]
  plugins?: AiPlugin[]
  onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
}

  logger.info('Starting unified message stream', {
    providerId: provider.id,
    providerType: provider.type,
    modelId,
    stream: params.stream
/**
 * Core stream execution function - single source of truth for AI SDK calls
 */
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
  const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config

  // Convert provider config to AI SDK config
  const sdkConfig = providerToAiSdkConfig(provider, modelId)

  logger.debug('Created AI SDK config', {
    providerId: sdkConfig.providerId,
    hasOptions: !!sdkConfig.options
  })

  try {
    response.setHeader('Content-Type', 'text/event-stream')
    response.setHeader('Cache-Control', 'no-cache')
    response.setHeader('Connection', 'keep-alive')
    response.setHeader('X-Accel-Buffering', 'no')
  // Create executor with plugins
  const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)

    const model = await createLanguageModel(provider, modelId)
  // Convert messages and tools
  const coreMessages = convertAnthropicToAiMessages(params)
  const tools = convertAnthropicToolsToAiSdk(params.tools)

    const coreMessages = convertAnthropicToAiMessages(params)
  logger.debug('Converted messages', {
    originalCount: params.messages.length,
    convertedCount: coreMessages.length,
    hasSystem: !!params.system,
    hasTools: !!tools,
    toolCount: tools ? Object.keys(tools).length : 0
  })

    // Convert tools if present
    const tools = convertAnthropicToolsToAiSdk(params.tools)
  // Create the adapter
  const adapter = new AiSdkToAnthropicSSE({
    model: `${provider.id}:${modelId}`,
    onEvent: onEvent || (() => {})
  })

    logger.debug('Converted messages', {
      originalCount: params.messages.length,
      convertedCount: coreMessages.length,
      hasSystem: !!params.system,
      hasTools: !!tools,
      toolCount: tools ? Object.keys(tools).length : 0,
      toolNames: tools ? Object.keys(tools).slice(0, 10) : [],
      paramsToolCount: params.tools?.length || 0
    })

    // Debug: Log message structure to understand tool_result handling
    logger.silly('Message structure for debugging', {
      messages: coreMessages.map((m) => ({
        role: m.role,
        contentTypes: Array.isArray(m.content)
          ? m.content.map((c: { type: string }) => c.type)
          : typeof m.content === 'string'
            ? ['string']
            : ['unknown']
      }))
    })

    // Create the adapter
    const adapter = new AiSdkToAnthropicSSE({
      model: `${provider.id}:${modelId}`,
      onEvent: (event) => {
        const sseData = formatSSEEvent(event)
        response.write(sseData)
      }
    })

    // Start streaming
    const result = streamText({
      model,
  // Execute stream
  const result = await executor.streamText(
    {
      model: modelId,
      messages: coreMessages,
      maxOutputTokens: params.max_tokens,
      temperature: params.temperature,
@@ -408,38 +277,65 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
      headers: defaultAppHeaders(),
      tools,
      providerOptions: {}
    })
    },
    { middlewares }
  )

    // Process the stream through the adapter
    await adapter.processStream(result.fullStream)
  // Process the stream through the adapter
  await adapter.processStream(result.fullStream)

  return adapter
}
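
// Annotation (not part of the commit): both streamUnifiedMessages and
// generateUnifiedMessage below funnel through executeStream, so middlewares and
// plugins are wired in exactly one place; the two entry points differ only in
// how they consume the returned adapter (SSE events vs. buildNonStreamingResponse()).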

/**
 * Stream a message request using AI SDK executor and convert to Anthropic SSE format
 */
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
  const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config

  logger.info('Starting unified message stream', {
    providerId: provider.id,
    providerType: provider.type,
    modelId,
    stream: params.stream,
    middlewareCount: middlewares.length,
    pluginCount: plugins.length
  })

  try {
    response.setHeader('Content-Type', 'text/event-stream')
    response.setHeader('Cache-Control', 'no-cache')
    response.setHeader('Connection', 'keep-alive')
    response.setHeader('X-Accel-Buffering', 'no')

    await executeStream({
      provider,
      modelId,
      params,
      middlewares,
      plugins,
      onEvent: (event) => {
        const sseData = formatSSEEvent(event)
        response.write(sseData)
      }
    })

    // Send done marker
    response.write(formatSSEDone())
    response.end()

    logger.info('Unified message stream completed', {
      providerId: provider.id,
      modelId
    })

    logger.info('Unified message stream completed', { providerId: provider.id, modelId })
    onComplete?.()
  } catch (error) {
    logger.error('Error in unified message stream', error as Error, {
      providerId: provider.id,
      modelId
    })
    logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })

    // Try to send error event if response is still writable
    if (!response.writableEnded) {
      try {
        const errorMessage = error instanceof Error ? error.message : 'Unknown error'
        response.write(
          `event: error\ndata: ${JSON.stringify({
            type: 'error',
            error: {
              type: 'api_error',
              message: errorMessage
            }
            error: { type: 'api_error', message: errorMessage }
          })}\n\n`
        )
        response.end()
@@ -455,64 +351,61 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis

/**
 * Generate a non-streaming message response
 *
 * Uses simulateStreamingMiddleware to reuse the same streaming logic,
 * similar to renderer's ModernAiProvider pattern.
 */
export async function generateUnifiedMessage(
  provider: Provider,
  modelId: string,
  params: MessageCreateParams
  providerOrConfig: Provider | GenerateUnifiedMessageConfig,
  modelId?: string,
  params?: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
  // Support both old signature and new config-based signature
  let config: GenerateUnifiedMessageConfig
  if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
    config = providerOrConfig
  } else {
    config = {
      provider: providerOrConfig as Provider,
      modelId: modelId!,
      params: params!
    }
  }

  const { provider, middlewares = [], plugins = [] } = config

  logger.info('Starting unified message generation', {
    providerId: provider.id,
    providerType: provider.type,
    modelId
    modelId: config.modelId,
    middlewareCount: middlewares.length,
    pluginCount: plugins.length
  })

  try {
    // Create language model (async - uses @cherrystudio/ai-core)
    const model = await createLanguageModel(provider, modelId)
    // Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]

    // Convert messages and tools
    const coreMessages = convertAnthropicToAiMessages(params)
    const tools = convertAnthropicToolsToAiSdk(params.tools)

    // Create adapter to collect the response
    let finalResponse: ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse> | null = null
    const adapter = new AiSdkToAnthropicSSE({
      model: `${provider.id}:${modelId}`,
      onEvent: () => {
        // We don't need to emit events for non-streaming
      }
    const adapter = await executeStream({
      provider,
      modelId: config.modelId,
      params: config.params,
      middlewares: allMiddlewares,
      plugins
    })

    // Generate text
    const result = streamText({
      model,
      messages: coreMessages,
      maxOutputTokens: params.max_tokens,
      temperature: params.temperature,
      topP: params.top_p,
      stopSequences: params.stop_sequences,
      headers: defaultAppHeaders(),
      tools,
      stopWhen: stepCountIs(100)
    })

    // Process the stream to build the response
    await adapter.processStream(result.fullStream)

    // Get the final response
    finalResponse = adapter.buildNonStreamingResponse()
    const finalResponse = adapter.buildNonStreamingResponse()

    logger.info('Unified message generation completed', {
      providerId: provider.id,
      modelId
      modelId: config.modelId
    })

    return finalResponse
  } catch (error) {
    logger.error('Error in unified message generation', error as Error, {
      providerId: provider.id,
      modelId
      modelId: config.modelId
    })
    throw error
  }
}
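
A usage sketch (not part of this commit) of the two call shapes that the signature-normalization above supports; the model ID and `params` values are hypothetical:

```typescript
// Old positional signature, kept for backward compatibility
const legacy = await generateUnifiedMessage(provider, 'claude-sonnet-4', params)

// New config-based signature, threading shared middlewares through executeStream
const result = await generateUnifiedMessage({
  provider,
  modelId: 'claude-sonnet-4',
  params,
  middlewares: buildSharedMiddlewares({ enableReasoning: true, modelId: 'claude-sonnet-4' })
})
```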

@@ -5,6 +5,7 @@ import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
@@ -13,9 +14,7 @@ import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'

const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@@ -1,50 +0,0 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'

/**
 * https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
 *
 * @returns LanguageModelMiddleware - a middleware that filters out redacted blocks
 */
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
  const REDACTED_BLOCK = '[REDACTED]'
  return {
    middlewareVersion: 'v2',
    wrapGenerate: async ({ doGenerate }) => {
      const { content, ...rest } = await doGenerate()
      const modifiedContent = content.map((part) => {
        if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
          return {
            ...part,
            text: part.text.replace(REDACTED_BLOCK, '')
          }
        }
        return part
      })
      return { content: modifiedContent, ...rest }
    },
    wrapStream: async ({ doStream }) => {
      const { stream, ...rest } = await doStream()
      return {
        stream: stream.pipeThrough(
          new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
            transform(
              chunk: LanguageModelV2StreamPart,
              controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
            ) {
              if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
                controller.enqueue({
                  ...chunk,
                  delta: chunk.delta.replace(REDACTED_BLOCK, '')
                })
              } else {
                controller.enqueue(chunk)
              }
            }
          })
        ),
        ...rest
      }
    }
  }
}

@@ -1,36 +0,0 @@
import type { LanguageModelMiddleware } from 'ai'

/**
 * Skip Gemini Thought Signature Middleware
 * Due to the complexity of multi-model client requests (which can switch to other models mid-process),
 * it was decided to skip all Gemini 3 thought-signature validation via this middleware.
 * @param aiSdkId AI SDK Provider ID
 * @returns LanguageModelMiddleware
 */
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
  const MAGIC_STRING = 'skip_thought_signature_validator'
  return {
    middlewareVersion: 'v2',

    transformParams: async ({ params }) => {
      const transformedParams = { ...params }
      // Process messages in prompt
      if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
        transformedParams.prompt = transformedParams.prompt.map((message) => {
          if (typeof message.content !== 'string') {
            for (const part of message.content) {
              const googleOptions = part?.providerOptions?.[aiSdkId]
              if (googleOptions?.thoughtSignature) {
                googleOptions.thoughtSignature = MAGIC_STRING
              }
            }
          }
          return message
        })
      }

      return transformedParams
    }
  }
}
@@ -7,9 +7,11 @@
    "src/main/env.d.ts",
    "src/renderer/src/types/*",
    "packages/shared/**/*",
    "packages/aiCore/src/**/*",
    "scripts",
    "packages/mcp-trace/**/*",
    "src/renderer/src/services/traceApi.ts"
    "src/renderer/src/services/traceApi.ts",
    "packages/ai-sdk-provider/**/*"
  ],
  "compilerOptions": {
    "composite": true,