From 1755fd9bcb122f7e3db159db43db336fbe743b24 Mon Sep 17 00:00:00 2001 From: suyao Date: Sat, 3 Jan 2026 16:11:54 +0800 Subject: [PATCH] refactor: streaming adapter --- .../__tests__/AiSdkToAnthropicSSE.test.ts | 186 ++--- .../converters/AnthropicMessageConverter.ts | 256 ++++++ .../converters/OpenAIMessageConverter.ts | 281 +++++++ .../apiServer/adapters/converters/index.ts | 2 + .../adapters/converters/json-schema-to-zod.ts | 141 ++++ .../converters/provider-options-mapper.ts | 194 +++++ .../factory/MessageConverterFactory.ts | 82 ++ .../adapters/factory/StreamAdapterFactory.ts | 127 +++ src/main/apiServer/adapters/factory/index.ts | 1 + .../formatters/AnthropicSSEFormatter.ts | 36 + .../adapters/formatters/OpenAISSEFormatter.ts | 42 + .../apiServer/adapters/formatters/index.ts | 2 + src/main/apiServer/adapters/index.ts | 51 +- src/main/apiServer/adapters/interfaces.ts | 182 +++++ .../{ => stream}/AiSdkToAnthropicSSE.ts | 247 ++---- .../adapters/stream/AiSdkToOpenAISSE.ts | 416 ++++++++++ .../adapters/stream/BaseStreamAdapter.ts | 161 ++++ src/main/apiServer/adapters/stream/index.ts | 3 + src/main/apiServer/routes/chat.ts | 131 +-- src/main/apiServer/routes/messages.ts | 6 +- .../apiServer/services/ProxyStreamService.ts | 465 +++++++++++ .../__tests__/jsonSchemaToZod.test.ts | 2 +- .../__tests__/unified-messages.test.ts | 16 +- .../apiServer/services/chat-completion.ts | 260 ------ .../apiServer/services/unified-messages.ts | 762 ------------------ src/main/apiServer/utils/index.ts | 2 +- 26 files changed, 2666 insertions(+), 1388 deletions(-) create mode 100644 src/main/apiServer/adapters/converters/AnthropicMessageConverter.ts create mode 100644 src/main/apiServer/adapters/converters/OpenAIMessageConverter.ts create mode 100644 src/main/apiServer/adapters/converters/index.ts create mode 100644 src/main/apiServer/adapters/converters/json-schema-to-zod.ts create mode 100644 src/main/apiServer/adapters/converters/provider-options-mapper.ts create mode 100644 src/main/apiServer/adapters/factory/MessageConverterFactory.ts create mode 100644 src/main/apiServer/adapters/factory/StreamAdapterFactory.ts create mode 100644 src/main/apiServer/adapters/factory/index.ts create mode 100644 src/main/apiServer/adapters/formatters/AnthropicSSEFormatter.ts create mode 100644 src/main/apiServer/adapters/formatters/OpenAISSEFormatter.ts create mode 100644 src/main/apiServer/adapters/formatters/index.ts create mode 100644 src/main/apiServer/adapters/interfaces.ts rename src/main/apiServer/adapters/{ => stream}/AiSdkToAnthropicSSE.ts (71%) create mode 100644 src/main/apiServer/adapters/stream/AiSdkToOpenAISSE.ts create mode 100644 src/main/apiServer/adapters/stream/BaseStreamAdapter.ts create mode 100644 src/main/apiServer/adapters/stream/index.ts create mode 100644 src/main/apiServer/services/ProxyStreamService.ts delete mode 100644 src/main/apiServer/services/chat-completion.ts delete mode 100644 src/main/apiServer/services/unified-messages.ts diff --git a/src/main/apiServer/adapters/__tests__/AiSdkToAnthropicSSE.test.ts b/src/main/apiServer/adapters/__tests__/AiSdkToAnthropicSSE.test.ts index bbeed2563c..c8ba7e7fed 100644 --- a/src/main/apiServer/adapters/__tests__/AiSdkToAnthropicSSE.test.ts +++ b/src/main/apiServer/adapters/__tests__/AiSdkToAnthropicSSE.test.ts @@ -1,8 +1,9 @@ import type { RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages' import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai' -import { describe, expect, it, vi } from 'vitest' +import { 
describe, expect, it } from 'vitest' -import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '../AiSdkToAnthropicSSE' +import { AnthropicSSEFormatter } from '../formatters/AnthropicSSEFormatter' +import { AiSdkToAnthropicSSE } from '../stream/AiSdkToAnthropicSSE' const createTextDelta = (text: string, id = 'text_0'): TextStreamPart => ({ type: 'text-delta', @@ -24,17 +25,17 @@ const createFinish = ( finishReason: FinishReason | undefined = 'stop', totalUsage?: Partial ): TextStreamPart => { - const defaultUsage: LanguageModelUsage = { + const defaultUsage = { inputTokens: 0, outputTokens: 0, totalTokens: 0 } - const event: TextStreamPart = { + // Cast to TextStreamPart to avoid strict type checking on optional fields + return { type: 'finish', finishReason: finishReason || 'stop', totalUsage: { ...defaultUsage, ...totalUsage } - } - return event + } as TextStreamPart } // Helper to create stream @@ -49,19 +50,32 @@ function createMockStream(events: readonly TextStreamPart[]) { }) } +// Helper to collect all events from output stream +async function collectEvents(stream: ReadableStream): Promise { + const events: RawMessageStreamEvent[] = [] + const reader = stream.getReader() + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + events.push(value) + } + } finally { + reader.releaseLock() + } + return events +} + describe('AiSdkToAnthropicSSE', () => { describe('Text Processing', () => { it('should emit message_start and process text-delta events', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) // Create a mock stream with text events const stream = createMockStream([createTextDelta('Hello'), createTextDelta(' world'), createFinish('stop')]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Verify message_start expect(events[0]).toMatchObject({ @@ -106,11 +120,7 @@ describe('AiSdkToAnthropicSSE', () => { }) it('should handle text-start and text-end events', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([ createTextStart(), @@ -119,7 +129,8 @@ describe('AiSdkToAnthropicSSE', () => { createFinish('stop') ]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Should have content_block_start, delta, and content_block_stop const blockEvents = events.filter((e) => e.type.startsWith('content_block')) @@ -127,15 +138,12 @@ describe('AiSdkToAnthropicSSE', () => { }) it('should auto-start text block if not explicitly started', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([createTextDelta('Auto-started'), createFinish('stop')]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Should automatically emit content_block_start expect(events.some((e) => e.type 
=== 'content_block_start')).toBe(true) @@ -144,11 +152,7 @@ describe('AiSdkToAnthropicSSE', () => { describe('Tool Call Processing', () => { it('should emit tool_use block for tool-call events', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([ { @@ -160,7 +164,8 @@ describe('AiSdkToAnthropicSSE', () => { createFinish('tool-calls') ]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Find tool_use block events const blockStart = events.find((e) => { @@ -195,11 +200,7 @@ describe('AiSdkToAnthropicSSE', () => { }) it('should not create duplicate tool blocks', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const toolCallEvent: TextStreamPart = { type: 'tool-call', @@ -209,7 +210,8 @@ describe('AiSdkToAnthropicSSE', () => { } const stream = createMockStream([toolCallEvent, toolCallEvent, createFinish()]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Should only have one tool_use block const toolBlocks = events.filter((e) => { @@ -224,11 +226,7 @@ describe('AiSdkToAnthropicSSE', () => { describe('Reasoning/Thinking Processing', () => { it('should emit thinking block for reasoning events', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([ { type: 'reasoning-start', id: 'reason_1' }, @@ -237,7 +235,8 @@ describe('AiSdkToAnthropicSSE', () => { createFinish() ]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Find thinking block events const blockStart = events.find((e) => { @@ -262,11 +261,7 @@ describe('AiSdkToAnthropicSSE', () => { }) it('should handle multiple thinking blocks', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([ { type: 'reasoning-start', id: 'reason_1' }, @@ -278,7 +273,8 @@ describe('AiSdkToAnthropicSSE', () => { createFinish() ]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Should have two thinking blocks const thinkingBlocks = events.filter((e) => { @@ -304,15 +300,12 @@ describe('AiSdkToAnthropicSSE', () => { ] for (const { aiSdkReason, expectedReason } of testCases) { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([createFinish(aiSdkReason)]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await 
collectEvents(outputStream) const messageDelta = events.find((e) => e.type === 'message_delta') if (messageDelta && messageDelta.type === 'message_delta') { @@ -324,11 +317,9 @@ describe('AiSdkToAnthropicSSE', () => { describe('Usage Tracking', () => { it('should track token usage', async () => { - const events: RawMessageStreamEvent[] = [] const adapter = new AiSdkToAnthropicSSE({ model: 'test:model', - inputTokens: 100, - onEvent: (event) => events.push(event) + inputTokens: 100 }) const stream = createMockStream([ @@ -340,7 +331,8 @@ describe('AiSdkToAnthropicSSE', () => { }) ]) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) const messageDelta = events.find((e) => e.type === 'message_delta') if (messageDelta && messageDelta.type === 'message_delta') { @@ -355,10 +347,7 @@ describe('AiSdkToAnthropicSSE', () => { describe('Non-Streaming Response', () => { it('should build complete message for non-streaming', async () => { - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: vi.fn() - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([ createTextDelta('Hello world'), @@ -371,7 +360,14 @@ describe('AiSdkToAnthropicSSE', () => { createFinish('tool-calls', { inputTokens: 10, outputTokens: 20 }) ]) - await adapter.processStream(stream) + // Consume the stream to populate adapter state + const outputStream = adapter.transform(stream) + const reader = outputStream.getReader() + while (true) { + const { done } = await reader.read() + if (done) break + } + reader.releaseLock() const response = adapter.buildNonStreamingResponse() @@ -403,25 +399,20 @@ describe('AiSdkToAnthropicSSE', () => { describe('Error Handling', () => { it('should throw on error events', async () => { - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: vi.fn() - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const testError = new Error('Test error') const stream = createMockStream([{ type: 'error', error: testError }]) - await expect(adapter.processStream(stream)).rejects.toThrow('Test error') + const outputStream = adapter.transform(stream) + + await expect(collectEvents(outputStream)).rejects.toThrow('Test error') }) }) describe('Edge Cases', () => { it('should handle empty stream', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = new ReadableStream>({ start(controller) { @@ -429,7 +420,8 @@ describe('AiSdkToAnthropicSSE', () => { } }) - await adapter.processStream(stream) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) // Should still emit message_start, message_delta, and message_stop expect(events.some((e) => e.type === 'message_start')).toBe(true) @@ -438,15 +430,12 @@ describe('AiSdkToAnthropicSSE', () => { }) it('should handle empty text deltas', async () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const stream = createMockStream([createTextDelta(''), createTextDelta(''), createFinish()]) - await adapter.processStream(stream) + const outputStream = 
adapter.transform(stream) + const events = await collectEvents(outputStream) // Should not emit deltas for empty text const deltas = events.filter((e) => e.type === 'content_block_delta') @@ -454,8 +443,9 @@ describe('AiSdkToAnthropicSSE', () => { }) }) - describe('Utility Functions', () => { + describe('AnthropicSSEFormatter', () => { it('should format SSE events correctly', () => { + const formatter = new AnthropicSSEFormatter() const event: RawMessageStreamEvent = { type: 'message_start', message: { @@ -476,7 +466,7 @@ describe('AiSdkToAnthropicSSE', () => { } } - const formatted = formatSSEEvent(event) + const formatted = formatter.formatEvent(event) expect(formatted).toContain('event: message_start') expect(formatted).toContain('data: ') @@ -485,7 +475,8 @@ describe('AiSdkToAnthropicSSE', () => { }) it('should format SSE done marker correctly', () => { - const done = formatSSEDone() + const formatter = new AnthropicSSEFormatter() + const done = formatter.formatDone() expect(done).toBe('data: [DONE]\n\n') }) @@ -495,18 +486,14 @@ describe('AiSdkToAnthropicSSE', () => { it('should use provided message ID', () => { const adapter = new AiSdkToAnthropicSSE({ model: 'test:model', - messageId: 'custom_msg_123', - onEvent: vi.fn() + messageId: 'custom_msg_123' }) expect(adapter.getMessageId()).toBe('custom_msg_123') }) it('should generate message ID if not provided', () => { - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: vi.fn() - }) + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) const messageId = adapter.getMessageId() expect(messageId).toMatch(/^msg_/) @@ -514,23 +501,20 @@ describe('AiSdkToAnthropicSSE', () => { }) describe('Input Tokens', () => { - it('should allow setting input tokens', () => { - const events: RawMessageStreamEvent[] = [] - const adapter = new AiSdkToAnthropicSSE({ - model: 'test:model', - onEvent: (event) => events.push(event) - }) + it('should allow setting input tokens', async () => { + const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' }) adapter.setInputTokens(500) const stream = createMockStream([createFinish()]) - return adapter.processStream(stream).then(() => { - const messageStart = events.find((e) => e.type === 'message_start') - if (messageStart && messageStart.type === 'message_start') { - expect(messageStart.message.usage.input_tokens).toBe(500) - } - }) + const outputStream = adapter.transform(stream) + const events = await collectEvents(outputStream) + + const messageStart = events.find((e) => e.type === 'message_start') + if (messageStart && messageStart.type === 'message_start') { + expect(messageStart.message.usage.input_tokens).toBe(500) + } }) }) }) diff --git a/src/main/apiServer/adapters/converters/AnthropicMessageConverter.ts b/src/main/apiServer/adapters/converters/AnthropicMessageConverter.ts new file mode 100644 index 0000000000..c28d8fa605 --- /dev/null +++ b/src/main/apiServer/adapters/converters/AnthropicMessageConverter.ts @@ -0,0 +1,256 @@ +/** + * Anthropic Message Converter + * + * Converts Anthropic Messages API format to AI SDK format. + * Handles messages, tools, and special content types (images, thinking, tool results). 
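+ *
+ * A minimal conversion sketch (the request literal below is illustrative, not from this codebase):
+ *
+ * @example
+ * ```typescript
+ * const converter = new AnthropicMessageConverter()
+ * const messages = converter.toAiSdkMessages({
+ *   model: 'claude-3-5-sonnet-latest',
+ *   max_tokens: 1024,
+ *   messages: [{ role: 'user', content: 'Hello' }]
+ * })
+ * // => [{ role: 'user', content: 'Hello' }]
+ * ```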
+ */ + +import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider' +import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' +import type { + ImageBlockParam, + MessageCreateParams, + TextBlockParam, + Tool as AnthropicTool +} from '@anthropic-ai/sdk/resources/messages' +import { isGemini3ModelId } from '@shared/aiCore/middlewares' +import type { Provider } from '@types' +import type { ImagePart, JSONValue, ModelMessage, TextPart, Tool as AiSdkTool } from 'ai' +import { tool, zodSchema } from 'ai' + +import type { IMessageConverter, StreamTextOptions } from '../interfaces' +import { type JsonSchemaLike, jsonSchemaToZod } from './json-schema-to-zod' +import { mapAnthropicThinkingToProviderOptions } from './provider-options-mapper' + +const MAGIC_STRING = 'skip_thought_signature_validator' + +/** + * Sanitize value for JSON serialization + */ +function sanitizeJson(value: unknown): JSONValue { + return JSON.parse(JSON.stringify(value)) +} + +/** + * Convert Anthropic tool result content to AI SDK format + */ +function convertToolResultToAiSdk( + content: string | Array +): LanguageModelV2ToolResultOutput { + if (typeof content === 'string') { + return { type: 'text', value: content } + } + const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = [] + for (const block of content) { + if (block.type === 'text') { + values.push({ type: 'text', text: block.text }) + } else if (block.type === 'image') { + values.push({ + type: 'media', + data: block.source.type === 'base64' ? block.source.data : block.source.url, + mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png' + }) + } + } + return { type: 'content', value: values } +} + +/** + * Reasoning cache interface for storing provider-specific reasoning state + */ +export interface ReasoningCache { + get(key: string): unknown + set(key: string, value: unknown): void +} + +/** + * Anthropic Message Converter + * + * Converts Anthropic MessageCreateParams to AI SDK format for unified processing. 
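+ *
+ * A usage sketch; `params` is an assumed MessageCreateParams value, `provider` an assumed
+ * Provider, and the caches are the ones exported by services/reasoning-cache:
+ *
+ * @example
+ * ```typescript
+ * const converter = new AnthropicMessageConverter({ googleReasoningCache, openRouterReasoningCache })
+ * const messages = converter.toAiSdkMessages(params)
+ * const tools = converter.toAiSdkTools(params)
+ * const streamOptions = converter.extractStreamOptions(params)
+ * const providerOptions = converter.extractProviderOptions(provider, params)
+ * ```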
+ */ +export class AnthropicMessageConverter implements IMessageConverter { + private googleReasoningCache?: ReasoningCache + private openRouterReasoningCache?: ReasoningCache + + constructor(options?: { googleReasoningCache?: ReasoningCache; openRouterReasoningCache?: ReasoningCache }) { + this.googleReasoningCache = options?.googleReasoningCache + this.openRouterReasoningCache = options?.openRouterReasoningCache + } + + /** + * Convert Anthropic MessageCreateParams to AI SDK ModelMessage[] + */ + toAiSdkMessages(params: MessageCreateParams): ModelMessage[] { + const messages: ModelMessage[] = [] + + // System message + if (params.system) { + if (typeof params.system === 'string') { + messages.push({ role: 'system', content: params.system }) + } else if (Array.isArray(params.system)) { + const systemText = params.system + .filter((block) => block.type === 'text') + .map((block) => block.text) + .join('\n') + if (systemText) { + messages.push({ role: 'system', content: systemText }) + } + } + } + + // Build tool call ID to name mapping for tool results + const toolCallIdToName = new Map() + for (const msg of params.messages) { + if (Array.isArray(msg.content)) { + for (const block of msg.content) { + if (block.type === 'tool_use') { + toolCallIdToName.set(block.id, block.name) + } + } + } + } + + // User/assistant messages + for (const msg of params.messages) { + if (typeof msg.content === 'string') { + messages.push({ + role: msg.role === 'user' ? 'user' : 'assistant', + content: msg.content + }) + } else if (Array.isArray(msg.content)) { + const textParts: TextPart[] = [] + const imageParts: ImagePart[] = [] + const reasoningParts: ReasoningPart[] = [] + const toolCallParts: ToolCallPart[] = [] + const toolResultParts: ToolResultPart[] = [] + + for (const block of msg.content) { + if (block.type === 'text') { + textParts.push({ type: 'text', text: block.text }) + } else if (block.type === 'thinking') { + reasoningParts.push({ type: 'reasoning', text: block.thinking }) + } else if (block.type === 'redacted_thinking') { + reasoningParts.push({ type: 'reasoning', text: block.data }) + } else if (block.type === 'image') { + const source = block.source + if (source.type === 'base64') { + imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` }) + } else if (source.type === 'url') { + imageParts.push({ type: 'image', image: source.url }) + } + } else if (block.type === 'tool_use') { + const options: ProviderOptions = {} + if (isGemini3ModelId(params.model)) { + if (this.googleReasoningCache?.get(`google-${block.name}`)) { + options.google = { + thoughtSignature: MAGIC_STRING + } + } + } + if (this.openRouterReasoningCache?.get(`openrouter-${block.id}`)) { + options.openrouter = { + reasoning_details: + (sanitizeJson(this.openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || [] + } + } + toolCallParts.push({ + type: 'tool-call', + toolName: block.name, + toolCallId: block.id, + input: block.input, + providerOptions: options + }) + } else if (block.type === 'tool_result') { + const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown' + toolResultParts.push({ + type: 'tool-result', + toolCallId: block.tool_use_id, + toolName, + output: block.content ? 
convertToolResultToAiSdk(block.content) : { type: 'text', value: '' } + }) + } + } + + if (toolResultParts.length > 0) { + messages.push({ role: 'tool', content: [...toolResultParts] }) + } + + if (msg.role === 'user') { + const userContent = [...textParts, ...imageParts] + if (userContent.length > 0) { + messages.push({ role: 'user', content: userContent }) + } + } else { + const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts] + if (assistantContent.length > 0) { + let providerOptions: ProviderOptions | undefined = undefined + if (this.openRouterReasoningCache?.get('openrouter')) { + providerOptions = { + openrouter: { + reasoning_details: + (sanitizeJson(this.openRouterReasoningCache.get('openrouter')) as JSONValue[]) || [] + } + } + } else if (isGemini3ModelId(params.model)) { + providerOptions = { + google: { + thoughtSignature: MAGIC_STRING + } + } + } + messages.push({ role: 'assistant', content: assistantContent, providerOptions }) + } + } + } + } + + return messages + } + + /** + * Convert Anthropic tools to AI SDK tools + */ + toAiSdkTools(params: MessageCreateParams): Record | undefined { + const tools = params.tools + if (!tools || tools.length === 0) return undefined + + const aiSdkTools: Record = {} + for (const anthropicTool of tools) { + if (anthropicTool.type === 'bash_20250124') continue + const toolDef = anthropicTool as AnthropicTool + const rawSchema = toolDef.input_schema + const schema = jsonSchemaToZod(rawSchema as JsonSchemaLike) + + const aiTool = tool({ + description: toolDef.description || '', + inputSchema: zodSchema(schema) + }) + + aiSdkTools[toolDef.name] = aiTool + } + return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined + } + + /** + * Extract stream/generation options from Anthropic params + */ + extractStreamOptions(params: MessageCreateParams): StreamTextOptions { + return { + maxOutputTokens: params.max_tokens, + temperature: params.temperature, + topP: params.top_p, + topK: params.top_k, + stopSequences: params.stop_sequences + } + } + + /** + * Extract provider-specific options from Anthropic params + * Maps thinking configuration to provider-specific parameters + */ + extractProviderOptions(provider: Provider, params: MessageCreateParams): ProviderOptions | undefined { + return mapAnthropicThinkingToProviderOptions(provider, params.thinking) + } +} + +export default AnthropicMessageConverter diff --git a/src/main/apiServer/adapters/converters/OpenAIMessageConverter.ts b/src/main/apiServer/adapters/converters/OpenAIMessageConverter.ts new file mode 100644 index 0000000000..de860380ea --- /dev/null +++ b/src/main/apiServer/adapters/converters/OpenAIMessageConverter.ts @@ -0,0 +1,281 @@ +/** + * OpenAI Message Converter + * + * Converts OpenAI Chat Completions API format to AI SDK format. + * Handles messages, tools, and extended features like reasoning_content. 
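+ *
+ * A minimal conversion sketch (request values are illustrative):
+ *
+ * @example
+ * ```typescript
+ * const converter = new OpenAIMessageConverter()
+ * const messages = converter.toAiSdkMessages({
+ *   model: 'gpt-4o',
+ *   messages: [
+ *     { role: 'system', content: 'Be brief.' },
+ *     { role: 'user', content: 'Hello' }
+ *   ]
+ * })
+ * // => [{ role: 'system', content: 'Be brief.' }, { role: 'user', content: 'Hello' }]
+ * ```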
+ */ + +import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' +import type { + ChatCompletionAssistantMessageParam, + ChatCompletionContentPart, + ChatCompletionMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam +} from '@cherrystudio/openai/resources' +import type { ChatCompletionCreateParamsBase } from '@cherrystudio/openai/resources/chat/completions' +import type { Provider } from '@types' +import type { ImagePart, ModelMessage, TextPart, Tool as AiSdkTool } from 'ai' +import { tool, zodSchema } from 'ai' + +import type { IMessageConverter, StreamTextOptions } from '../interfaces' +import { type JsonSchemaLike, jsonSchemaToZod } from './json-schema-to-zod' +import { mapReasoningEffortToProviderOptions } from './provider-options-mapper' + +/** + * Extended ChatCompletionCreateParams with reasoning_effort support + * Extends the base OpenAI params to inherit all standard parameters + */ +export interface ExtendedChatCompletionCreateParams extends ChatCompletionCreateParamsBase { + /** + * Allow additional provider-specific parameters + */ + [key: string]: unknown +} + +/** + * Extended assistant message with reasoning_content support (DeepSeek-style) + */ +interface ExtendedAssistantMessage extends ChatCompletionAssistantMessageParam { + reasoning_content?: string | null +} + +/** + * OpenAI Message Converter + * + * Converts OpenAI Chat Completions API format to AI SDK format. + * Supports standard OpenAI messages plus extended features: + * - reasoning_content (DeepSeek-style thinking) + * - reasoning_effort parameter + */ +export class OpenAIMessageConverter implements IMessageConverter { + /** + * Convert OpenAI ChatCompletionCreateParams to AI SDK ModelMessage[] + */ + toAiSdkMessages(params: ExtendedChatCompletionCreateParams): ModelMessage[] { + const messages: ModelMessage[] = [] + + // Build tool call ID to name mapping for tool results + const toolCallIdToName = new Map() + for (const msg of params.messages) { + if (msg.role === 'assistant') { + const assistantMsg = msg as ChatCompletionAssistantMessageParam + if (assistantMsg.tool_calls) { + for (const toolCall of assistantMsg.tool_calls) { + // Only handle function tool calls + if (toolCall.type === 'function') { + toolCallIdToName.set(toolCall.id, toolCall.function.name) + } + } + } + } + } + + for (const msg of params.messages) { + const converted = this.convertMessage(msg, toolCallIdToName) + if (converted) { + messages.push(...converted) + } + } + + return messages + } + + /** + * Convert a single OpenAI message to AI SDK message(s) + */ + private convertMessage( + msg: ChatCompletionMessageParam, + toolCallIdToName: Map + ): ModelMessage[] | null { + switch (msg.role) { + case 'system': + return this.convertSystemMessage(msg) + case 'user': + return this.convertUserMessage(msg as ChatCompletionUserMessageParam) + case 'assistant': + return this.convertAssistantMessage(msg as ExtendedAssistantMessage) + case 'tool': + return this.convertToolMessage(msg as ChatCompletionToolMessageParam, toolCallIdToName) + case 'function': + // Legacy function messages - skip or handle as needed + return null + default: + return null + } + } + + /** + * Convert system message + */ + private convertSystemMessage(msg: ChatCompletionMessageParam): ModelMessage[] { + if (msg.role !== 'system') return [] + + // Handle string content + if (typeof msg.content === 'string') { + return [{ role: 'system', content: msg.content }] + } + + // Handle array content (system 
messages can have text parts) + if (Array.isArray(msg.content)) { + const textContent = msg.content + .filter((part): part is { type: 'text'; text: string } => part.type === 'text') + .map((part) => part.text) + .join('\n') + if (textContent) { + return [{ role: 'system', content: textContent }] + } + } + + return [] + } + + /** + * Convert user message + */ + private convertUserMessage(msg: ChatCompletionUserMessageParam): ModelMessage[] { + // Handle string content + if (typeof msg.content === 'string') { + return [{ role: 'user', content: msg.content }] + } + + // Handle array content (text + images) + if (Array.isArray(msg.content)) { + const parts: (TextPart | ImagePart)[] = [] + + for (const part of msg.content as ChatCompletionContentPart[]) { + if (part.type === 'text') { + parts.push({ type: 'text', text: part.text }) + } else if (part.type === 'image_url') { + parts.push({ type: 'image', image: part.image_url.url }) + } + } + + if (parts.length > 0) { + return [{ role: 'user', content: parts }] + } + } + + return [] + } + + /** + * Convert assistant message + */ + private convertAssistantMessage(msg: ExtendedAssistantMessage): ModelMessage[] { + const parts: (TextPart | ReasoningPart | ToolCallPart)[] = [] + + // Handle reasoning_content (DeepSeek-style thinking) + if (msg.reasoning_content) { + parts.push({ type: 'reasoning', text: msg.reasoning_content }) + } + + // Handle text content + if (msg.content) { + if (typeof msg.content === 'string') { + parts.push({ type: 'text', text: msg.content }) + } else if (Array.isArray(msg.content)) { + for (const part of msg.content) { + if (part.type === 'text') { + parts.push({ type: 'text', text: part.text }) + } + } + } + } + + // Handle tool calls + if (msg.tool_calls && msg.tool_calls.length > 0) { + for (const toolCall of msg.tool_calls) { + // Only handle function tool calls + if (toolCall.type !== 'function') continue + + let input: unknown + try { + input = JSON.parse(toolCall.function.arguments) + } catch { + input = { raw: toolCall.function.arguments } + } + + parts.push({ + type: 'tool-call', + toolCallId: toolCall.id, + toolName: toolCall.function.name, + input + }) + } + } + + if (parts.length > 0) { + return [{ role: 'assistant', content: parts }] + } + + return [] + } + + /** + * Convert tool result message + */ + private convertToolMessage( + msg: ChatCompletionToolMessageParam, + toolCallIdToName: Map + ): ModelMessage[] { + const toolName = toolCallIdToName.get(msg.tool_call_id) || 'unknown' + + const toolResultPart: ToolResultPart = { + type: 'tool-result', + toolCallId: msg.tool_call_id, + toolName, + output: { type: 'text', value: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) } + } + + return [{ role: 'tool', content: [toolResultPart] }] + } + + /** + * Convert OpenAI tools to AI SDK tools + */ + toAiSdkTools(params: ExtendedChatCompletionCreateParams): Record | undefined { + const tools = params.tools + if (!tools || tools.length === 0) return undefined + + const aiSdkTools: Record = {} + + for (const toolDef of tools) { + if (toolDef.type !== 'function') continue + + const rawSchema = toolDef.function.parameters + const schema = rawSchema ? jsonSchemaToZod(rawSchema as JsonSchemaLike) : jsonSchemaToZod({ type: 'object' }) + + const aiTool = tool({ + description: toolDef.function.description || '', + inputSchema: zodSchema(schema) + }) + + aiSdkTools[toolDef.function.name] = aiTool + } + + return Object.keys(aiSdkTools).length > 0 ? 
aiSdkTools : undefined + } + + /** + * Extract stream/generation options from OpenAI params + */ + extractStreamOptions(params: ExtendedChatCompletionCreateParams): StreamTextOptions { + return { + maxOutputTokens: params.max_tokens as number | undefined, + temperature: params.temperature as number | undefined, + topP: params.top_p as number | undefined, + stopSequences: params.stop as string[] | undefined + } + } + + /** + * Extract provider-specific options from OpenAI params + * Maps reasoning_effort to provider-specific thinking/reasoning parameters + */ + extractProviderOptions(provider: Provider, params: ExtendedChatCompletionCreateParams): ProviderOptions | undefined { + return mapReasoningEffortToProviderOptions(provider, params.reasoning_effort) + } +} + +export default OpenAIMessageConverter diff --git a/src/main/apiServer/adapters/converters/index.ts b/src/main/apiServer/adapters/converters/index.ts new file mode 100644 index 0000000000..84cfbb6982 --- /dev/null +++ b/src/main/apiServer/adapters/converters/index.ts @@ -0,0 +1,2 @@ +export { AnthropicMessageConverter } from './AnthropicMessageConverter' +export { type JsonSchemaLike, jsonSchemaToZod } from './json-schema-to-zod' diff --git a/src/main/apiServer/adapters/converters/json-schema-to-zod.ts b/src/main/apiServer/adapters/converters/json-schema-to-zod.ts new file mode 100644 index 0000000000..3aa483a033 --- /dev/null +++ b/src/main/apiServer/adapters/converters/json-schema-to-zod.ts @@ -0,0 +1,141 @@ +/** + * JSON Schema to Zod Converter + * + * Converts JSON Schema definitions to Zod schemas for runtime validation. + * This is used to convert tool input schemas from Anthropic format to AI SDK format. + */ + +import type { JSONSchema7 } from '@ai-sdk/provider' +import * as z from 'zod' + +/** + * JSON Schema type alias + */ +export type JsonSchemaLike = JSONSchema7 + +/** + * Convert JSON Schema to Zod schema + * + * Handles: + * - Primitive types (string, number, integer, boolean, null) + * - Complex types (object, array) + * - Enums + * - Union types (type: ["string", "null"]) + * - Required/optional fields + * - Validation constraints (min/max, pattern, etc.) + * + * @example + * ```typescript + * const zodSchema = jsonSchemaToZod({ + * type: 'object', + * properties: { + * name: { type: 'string' }, + * age: { type: 'integer', minimum: 0 } + * }, + * required: ['name'] + * }) + * ``` + */ +export function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny { + const schemaType = schema.type + const enumValues = schema.enum + const description = schema.description + + // Handle enum first + if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) { + if (enumValues.every((v) => typeof v === 'string')) { + const zodEnum = z.enum(enumValues as [string, ...string[]]) + return description ? zodEnum.describe(description) : zodEnum + } + // For non-string enums, use union of literals + const literals = enumValues.map((v) => z.literal(v as string | number | boolean)) + if (literals.length === 1) { + return description ? literals[0].describe(description) : literals[0] + } + const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]) + return description ? 
zodUnion.describe(description) : zodUnion + } + + // Handle union types (type: ["string", "null"]) + if (Array.isArray(schemaType)) { + const schemas = schemaType.map((t) => + jsonSchemaToZod({ + ...schema, + type: t, + enum: undefined + }) + ) + if (schemas.length === 1) { + return schemas[0] + } + return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]) + } + + // Handle by type + switch (schemaType) { + case 'string': { + let zodString = z.string() + if (typeof schema.minLength === 'number') zodString = zodString.min(schema.minLength) + if (typeof schema.maxLength === 'number') zodString = zodString.max(schema.maxLength) + if (typeof schema.pattern === 'string') zodString = zodString.regex(new RegExp(schema.pattern)) + return description ? zodString.describe(description) : zodString + } + + case 'number': + case 'integer': { + let zodNumber = schemaType === 'integer' ? z.number().int() : z.number() + if (typeof schema.minimum === 'number') zodNumber = zodNumber.min(schema.minimum) + if (typeof schema.maximum === 'number') zodNumber = zodNumber.max(schema.maximum) + return description ? zodNumber.describe(description) : zodNumber + } + + case 'boolean': { + const zodBoolean = z.boolean() + return description ? zodBoolean.describe(description) : zodBoolean + } + + case 'null': + return z.null() + + case 'array': { + const items = schema.items + let zodArray: z.ZodArray + if (items && typeof items === 'object' && !Array.isArray(items)) { + zodArray = z.array(jsonSchemaToZod(items as JsonSchemaLike)) + } else { + zodArray = z.array(z.unknown()) + } + if (typeof schema.minItems === 'number') zodArray = zodArray.min(schema.minItems) + if (typeof schema.maxItems === 'number') zodArray = zodArray.max(schema.maxItems) + return description ? zodArray.describe(description) : zodArray + } + + case 'object': { + const properties = schema.properties + const required = schema.required || [] + + // Always use z.object() to ensure "properties" field is present in output schema + // OpenAI requires explicit properties field even for empty objects + const shape: Record = {} + if (properties && typeof properties === 'object') { + for (const [key, propSchema] of Object.entries(properties)) { + if (typeof propSchema === 'boolean') { + shape[key] = propSchema ? z.unknown() : z.never() + } else { + const zodProp = jsonSchemaToZod(propSchema as JsonSchemaLike) + shape[key] = required.includes(key) ? zodProp : zodProp.optional() + } + } + } + + const zodObject = z.object(shape) + return description ? zodObject.describe(description) : zodObject + } + + default: + // Unknown type, use z.unknown() + return z.unknown() + } +} + +export default jsonSchemaToZod diff --git a/src/main/apiServer/adapters/converters/provider-options-mapper.ts b/src/main/apiServer/adapters/converters/provider-options-mapper.ts new file mode 100644 index 0000000000..6f7024dc19 --- /dev/null +++ b/src/main/apiServer/adapters/converters/provider-options-mapper.ts @@ -0,0 +1,194 @@ +/** + * Provider Options Mapper + * + * Maps input format-specific thinking/reasoning configuration to + * AI SDK provider-specific options. + * + * TODO: Refactor this module: + * 1. Move shared reasoning config from src/renderer/src/config/models/reasoning.ts to @shared + * 2. Reuse MODEL_SUPPORTED_REASONING_EFFORT for budgetMap instead of hardcoding + * 3. For unsupported providers, pass through reasoning params in OpenAI-compatible format + * instead of returning undefined (all requests should transparently forward reasoning config) + * 4. 
Both Anthropic and OpenAI converters should handle OpenAI-compatible mapping + */ + +import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock' +import type { AnthropicProviderOptions } from '@ai-sdk/anthropic' +import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' +import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai' +import type { ProviderOptions } from '@ai-sdk/provider-utils' +import type { XaiProviderOptions } from '@ai-sdk/xai' +import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages' +import { ReasoningEffort } from '@cherrystudio/openai/resources' +import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider' +import { SystemProviderIds } from '@shared/types' +import { isAnthropicProvider, isAwsBedrockProvider, isGeminiProvider, isOpenAIProvider } from '@shared/utils/provider' +import type { Provider } from '@types' + +/** + * Map Anthropic thinking configuration to AI SDK provider options + * + * Converts Anthropic's thinking.type and budget_tokens to provider-specific + * parameters for various AI providers. + */ +export function mapAnthropicThinkingToProviderOptions( + provider: Provider, + config: MessageCreateParams['thinking'] +): ProviderOptions | undefined { + if (!config) return undefined + + // Anthropic provider + if (isAnthropicProvider(provider)) { + return { + anthropic: { + thinking: { + type: config.type, + budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined + } + } as AnthropicProviderOptions + } + } + + // Google/Gemini provider + if (isGeminiProvider(provider)) { + return { + google: { + thinkingConfig: { + thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1, + includeThoughts: config.type === 'enabled' + } + } as GoogleGenerativeAIProviderOptions + } + } + + // OpenAI provider (Responses API) + if (isOpenAIProvider(provider)) { + return { + openai: { + reasoningEffort: config.type === 'enabled' ? 'high' : 'none' + } as OpenAIResponsesProviderOptions + } + } + + // OpenRouter provider + if (provider.id === SystemProviderIds.openrouter) { + return { + openrouter: { + reasoning: { + enabled: config.type === 'enabled', + effort: 'high' + } + } as OpenRouterProviderOptions + } + } + + // XAI/Grok provider + if (provider.id === SystemProviderIds.grok) { + return { + xai: { + reasoningEffort: config.type === 'enabled' ? 'high' : undefined + } as XaiProviderOptions + } + } + + // AWS Bedrock provider + if (isAwsBedrockProvider(provider)) { + return { + bedrock: { + reasoningConfig: { + type: config.type, + budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined + } + } as BedrockProviderOptions + } + } + + // For other providers, thinking options are not automatically mapped + return undefined +} + +/** + * Map OpenAI-style reasoning_effort to AI SDK provider options + * + * Converts reasoning_effort (low/medium/high) to provider-specific + * thinking/reasoning parameters. 
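+ *
+ * A sketch of the resulting mapping; `anthropicProvider` stands for any Provider accepted by
+ * isAnthropicProvider(), and the budget values come from the hardcoded budgetMap below:
+ *
+ * @example
+ * ```typescript
+ * mapReasoningEffortToProviderOptions(anthropicProvider, 'high')
+ * // => { anthropic: { thinking: { type: 'enabled', budgetTokens: 20000 } } }
+ * mapReasoningEffortToProviderOptions(anthropicProvider, undefined)
+ * // => undefined
+ * ```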
+ */ +export function mapReasoningEffortToProviderOptions( + provider: Provider, + reasoningEffort?: ReasoningEffort +): ProviderOptions | undefined { + if (!reasoningEffort) return undefined + + // TODO: Import from @shared/config/reasoning instead of hardcoding + // Should reuse MODEL_SUPPORTED_REASONING_EFFORT from reasoning.ts + const budgetMap = { low: 5000, medium: 10000, high: 20000 } + + // Anthropic: Map to thinking.budgetTokens + if (isAnthropicProvider(provider)) { + return { + anthropic: { + thinking: { + type: 'enabled', + budgetTokens: budgetMap[reasoningEffort] + } + } as AnthropicProviderOptions + } + } + + // Google/Gemini: Map to thinkingConfig.thinkingBudget + if (isGeminiProvider(provider)) { + return { + google: { + thinkingConfig: { + thinkingBudget: budgetMap[reasoningEffort], + includeThoughts: true + } + } as GoogleGenerativeAIProviderOptions + } + } + + // OpenAI: Use reasoningEffort directly + if (isOpenAIProvider(provider)) { + return { + openai: { + reasoningEffort: reasoningEffort === 'low' ? 'none' : reasoningEffort + } as OpenAIResponsesProviderOptions + } + } + + // OpenRouter: Map to reasoning.effort + if (provider.id === SystemProviderIds.openrouter) { + return { + openrouter: { + reasoning: { + enabled: true, + effort: reasoningEffort + } + } as OpenRouterProviderOptions + } + } + + // XAI/Grok: Map to reasoningEffort + if (provider.id === SystemProviderIds.grok) { + return { + xai: { + reasoningEffort: reasoningEffort === 'low' ? undefined : reasoningEffort + } as XaiProviderOptions + } + } + + // AWS Bedrock: Map to reasoningConfig + if (isAwsBedrockProvider(provider)) { + return { + bedrock: { + reasoningConfig: { + type: 'enabled', + budgetTokens: budgetMap[reasoningEffort] + } + } as BedrockProviderOptions + } + } + + // For other providers, reasoning effort is not automatically mapped + return undefined +} diff --git a/src/main/apiServer/adapters/factory/MessageConverterFactory.ts b/src/main/apiServer/adapters/factory/MessageConverterFactory.ts new file mode 100644 index 0000000000..9513bc9462 --- /dev/null +++ b/src/main/apiServer/adapters/factory/MessageConverterFactory.ts @@ -0,0 +1,82 @@ +/** + * Message Converter Factory + * + * Factory for creating message converters based on input format. + * Uses generics for type-safe converter creation. + */ + +import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages' + +import { AnthropicMessageConverter, type ReasoningCache } from '../converters/AnthropicMessageConverter' +import { type ExtendedChatCompletionCreateParams, OpenAIMessageConverter } from '../converters/OpenAIMessageConverter' +import type { IMessageConverter, InputFormat } from '../interfaces' + +/** + * Type mapping from input format to parameter type + */ +export type InputParamsMap = { + openai: ExtendedChatCompletionCreateParams + anthropic: MessageCreateParams +} + +/** + * Options for creating converters + */ +export interface ConverterOptions { + googleReasoningCache?: ReasoningCache + openRouterReasoningCache?: ReasoningCache +} + +/** + * Message Converter Factory + * + * Creates message converters for different input formats with type safety. 
+ * + * @example + * ```typescript + * const converter = MessageConverterFactory.create('anthropic', { + * googleReasoningCache, + * openRouterReasoningCache + * }) + * // converter is typed as IMessageConverter + * const messages = converter.toAiSdkMessages(params) + * const options = converter.extractStreamOptions(params) + * ``` + */ +export class MessageConverterFactory { + /** + * Create a message converter for the specified input format + * + * @param format - The input format ('openai' | 'anthropic') + * @param options - Optional converter options + * @returns A typed message converter instance + */ + static create( + format: T, + options: ConverterOptions = {} + ): IMessageConverter { + if (format === 'openai') { + return new OpenAIMessageConverter() as IMessageConverter + } + return new AnthropicMessageConverter({ + googleReasoningCache: options.googleReasoningCache, + openRouterReasoningCache: options.openRouterReasoningCache + }) as IMessageConverter + } + + /** + * Check if a format is supported + */ + static supportsFormat(format: string): format is InputFormat { + return format === 'openai' || format === 'anthropic' + } + + /** + * Get list of all supported formats + */ + static getSupportedFormats(): InputFormat[] { + return ['openai', 'anthropic'] + } +} + +export default MessageConverterFactory diff --git a/src/main/apiServer/adapters/factory/StreamAdapterFactory.ts b/src/main/apiServer/adapters/factory/StreamAdapterFactory.ts new file mode 100644 index 0000000000..3f686de44d --- /dev/null +++ b/src/main/apiServer/adapters/factory/StreamAdapterFactory.ts @@ -0,0 +1,127 @@ +/** + * Stream Adapter Factory + * + * Factory for creating stream adapters based on output format. + * Uses a registry pattern for extensibility. + */ + +import { AnthropicSSEFormatter } from '../formatters/AnthropicSSEFormatter' +import { OpenAISSEFormatter } from '../formatters/OpenAISSEFormatter' +import type { ISSEFormatter, IStreamAdapter, OutputFormat, StreamAdapterOptions } from '../interfaces' +import { AiSdkToAnthropicSSE } from '../stream/AiSdkToAnthropicSSE' +import { AiSdkToOpenAISSE } from '../stream/AiSdkToOpenAISSE' + +/** + * Registry entry for adapter and formatter classes + */ +interface RegistryEntry { + adapterClass: new (options: StreamAdapterOptions) => IStreamAdapter + formatterClass: new () => ISSEFormatter +} + +/** + * Stream Adapter Factory + * + * Creates stream adapters and formatters for different output formats. + * + * @example + * ```typescript + * const adapter = StreamAdapterFactory.createAdapter('anthropic', { model: 'claude-3' }) + * const outputStream = adapter.transform(aiSdkStream) + * + * const formatter = StreamAdapterFactory.getFormatter('anthropic') + * for await (const event of outputStream) { + * response.write(formatter.formatEvent(event)) + * } + * response.write(formatter.formatDone()) + * ``` + */ +export class StreamAdapterFactory { + private static registry = new Map([ + [ + 'anthropic', + { + adapterClass: AiSdkToAnthropicSSE, + formatterClass: AnthropicSSEFormatter + } + ], + [ + 'openai', + { + adapterClass: AiSdkToOpenAISSE, + formatterClass: OpenAISSEFormatter + } + ] + ]) + + /** + * Create a stream adapter for the specified output format + * + * @param format - The target output format + * @param options - Adapter options (model, messageId, etc.) 
+ * @returns A stream adapter instance + * @throws Error if format is not supported + */ + static createAdapter(format: OutputFormat, options: StreamAdapterOptions): IStreamAdapter { + const entry = this.registry.get(format) + if (!entry) { + throw new Error( + `Unsupported output format: ${format}. Supported formats: ${this.getSupportedFormats().join(', ')}` + ) + } + return new entry.adapterClass(options) + } + + /** + * Get an SSE formatter for the specified output format + * + * @param format - The target output format + * @returns An SSE formatter instance + * @throws Error if format is not supported + */ + static getFormatter(format: OutputFormat): ISSEFormatter { + const entry = this.registry.get(format) + if (!entry) { + throw new Error( + `Unsupported output format: ${format}. Supported formats: ${this.getSupportedFormats().join(', ')}` + ) + } + return new entry.formatterClass() + } + + /** + * Check if a format is supported + * + * @param format - The format to check + * @returns true if the format is supported + */ + static supportsFormat(format: OutputFormat): boolean { + return this.registry.has(format) + } + + /** + * Get list of all supported formats + * + * @returns Array of supported format names + */ + static getSupportedFormats(): OutputFormat[] { + return Array.from(this.registry.keys()) + } + + /** + * Register a new adapter and formatter for a format + * + * @param format - The format name + * @param adapterClass - The adapter class constructor + * @param formatterClass - The formatter class constructor + */ + static registerAdapter( + format: OutputFormat, + adapterClass: new (options: StreamAdapterOptions) => IStreamAdapter, + formatterClass: new () => ISSEFormatter + ): void { + this.registry.set(format, { adapterClass, formatterClass }) + } +} + +export default StreamAdapterFactory diff --git a/src/main/apiServer/adapters/factory/index.ts b/src/main/apiServer/adapters/factory/index.ts new file mode 100644 index 0000000000..6e72bb1fcc --- /dev/null +++ b/src/main/apiServer/adapters/factory/index.ts @@ -0,0 +1 @@ +export { StreamAdapterFactory } from './StreamAdapterFactory' diff --git a/src/main/apiServer/adapters/formatters/AnthropicSSEFormatter.ts b/src/main/apiServer/adapters/formatters/AnthropicSSEFormatter.ts new file mode 100644 index 0000000000..e3ed239e22 --- /dev/null +++ b/src/main/apiServer/adapters/formatters/AnthropicSSEFormatter.ts @@ -0,0 +1,36 @@ +/** + * Anthropic SSE Formatter + * + * Formats Anthropic message stream events for Server-Sent Events. 
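+ *
+ * A usage sketch (`event` is an assumed RawMessageStreamEvent; payloads abbreviated):
+ *
+ * @example
+ * ```typescript
+ * const formatter = new AnthropicSSEFormatter()
+ * formatter.formatEvent(event) // "event: message_start\ndata: {...}\n\n"
+ * formatter.formatDone() // "data: [DONE]\n\n"
+ * ```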
+ */ + +import type { RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages' + +import type { ISSEFormatter } from '../interfaces' + +/** + * Anthropic SSE Formatter + * + * Formats events according to Anthropic's streaming API specification: + * - event: {type}\n + * - data: {json}\n\n + * + * @see https://docs.anthropic.com/en/api/messages-streaming + */ +export class AnthropicSSEFormatter implements ISSEFormatter { + /** + * Format an Anthropic event for SSE streaming + */ + formatEvent(event: RawMessageStreamEvent): string { + return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n` + } + + /** + * Format the stream termination marker + */ + formatDone(): string { + return 'data: [DONE]\n\n' + } +} + +export default AnthropicSSEFormatter diff --git a/src/main/apiServer/adapters/formatters/OpenAISSEFormatter.ts b/src/main/apiServer/adapters/formatters/OpenAISSEFormatter.ts new file mode 100644 index 0000000000..e425ce4f57 --- /dev/null +++ b/src/main/apiServer/adapters/formatters/OpenAISSEFormatter.ts @@ -0,0 +1,42 @@ +/** + * OpenAI Compatible SSE Formatter + * + * Formats OpenAI-compatible chat completion stream events for Server-Sent Events. + * Supports extended features like reasoning_content used by DeepSeek and other providers. + */ + +import type { ISSEFormatter } from '../interfaces' +import type { OpenAICompatibleChunk } from '../stream/AiSdkToOpenAISSE' + +/** + * Re-export the OpenAI-compatible chunk type for convenience + */ +export type { OpenAICompatibleChunk as ChatCompletionChunk } from '../stream/AiSdkToOpenAISSE' + +/** + * OpenAI Compatible SSE Formatter + * + * Formats events according to OpenAI's streaming API specification: + * - data: {json}\n\n + * + * Supports extended fields like reasoning_content for OpenAI-compatible providers. + * + * @see https://platform.openai.com/docs/api-reference/chat/streaming + */ +export class OpenAISSEFormatter implements ISSEFormatter { + /** + * Format an OpenAI-compatible event for SSE streaming + */ + formatEvent(event: OpenAICompatibleChunk): string { + return `data: ${JSON.stringify(event)}\n\n` + } + + /** + * Format the stream termination marker + */ + formatDone(): string { + return 'data: [DONE]\n\n' + } +} + +export default OpenAISSEFormatter diff --git a/src/main/apiServer/adapters/formatters/index.ts b/src/main/apiServer/adapters/formatters/index.ts new file mode 100644 index 0000000000..f7f635d9d5 --- /dev/null +++ b/src/main/apiServer/adapters/formatters/index.ts @@ -0,0 +1,2 @@ +export { AnthropicSSEFormatter } from './AnthropicSSEFormatter' +export { type ChatCompletionChunk, OpenAISSEFormatter } from './OpenAISSEFormatter' diff --git a/src/main/apiServer/adapters/index.ts b/src/main/apiServer/adapters/index.ts index a19db9594e..410aebc3b0 100644 --- a/src/main/apiServer/adapters/index.ts +++ b/src/main/apiServer/adapters/index.ts @@ -1,13 +1,48 @@ /** - * Shared Adapters + * API Server Adapters * - * This module exports adapters for converting between different AI API formats. + * This module provides adapters for converting between different AI API formats. 
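+ *
+ * A typical end-to-end flow (sketch; `params` is an assumed incoming Anthropic request body):
+ *
+ * @example
+ * ```typescript
+ * const converter = MessageConverterFactory.create('anthropic')
+ * const modelMessages = converter.toAiSdkMessages(params)
+ * const adapter = StreamAdapterFactory.createAdapter('anthropic', { model: params.model })
+ * const formatter = StreamAdapterFactory.getFormatter('anthropic')
+ * ```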
+ *
+ * Architecture:
+ * - Stream adapters: Convert AI SDK streams to various output formats (Anthropic, OpenAI)
+ * - Message converters: Convert input message formats to AI SDK format
+ * - SSE formatters: Format events for Server-Sent Events streaming
+ * - Factory: Creates adapters and formatters based on output format
  */
+// Stream Adapters
+export { AiSdkToAnthropicSSE } from './stream/AiSdkToAnthropicSSE'
+export { AiSdkToOpenAISSE } from './stream/AiSdkToOpenAISSE'
+export { BaseStreamAdapter } from './stream/BaseStreamAdapter'
+
+// Message Converters
+export { AnthropicMessageConverter, type ReasoningCache } from './converters/AnthropicMessageConverter'
+export { type JsonSchemaLike, jsonSchemaToZod } from './converters/json-schema-to-zod'
+export { type ExtendedChatCompletionCreateParams, OpenAIMessageConverter } from './converters/OpenAIMessageConverter'
+
+// SSE Formatters
+export { AnthropicSSEFormatter } from './formatters/AnthropicSSEFormatter'
+export { type ChatCompletionChunk, OpenAISSEFormatter } from './formatters/OpenAISSEFormatter'
+
+// Factory
 export {
-  AiSdkToAnthropicSSE,
-  type AiSdkToAnthropicSSEOptions,
-  formatSSEDone,
-  formatSSEEvent,
-  type SSEEventCallback
-} from './AiSdkToAnthropicSSE'
+  type ConverterOptions,
+  type InputParamsMap,
+  MessageConverterFactory
+} from './factory/MessageConverterFactory'
+export { StreamAdapterFactory } from './factory/StreamAdapterFactory'
+
+// Interfaces
+export type {
+  AdapterRegistryEntry,
+  AdapterState,
+  ContentBlockState,
+  IMessageConverter,
+  InputFormat,
+  ISSEFormatter,
+  IStreamAdapter,
+  OutputFormat,
+  StreamAdapterConstructor,
+  StreamAdapterOptions,
+  StreamTextOptions
+} from './interfaces'
diff --git a/src/main/apiServer/adapters/interfaces.ts b/src/main/apiServer/adapters/interfaces.ts
new file mode 100644
index 0000000000..b4b4f7b93c
--- /dev/null
+++ b/src/main/apiServer/adapters/interfaces.ts
@@ -0,0 +1,182 @@
+/**
+ * Core interfaces for the API Server adapter system
+ *
+ * This module defines the contracts for:
+ * - Stream adapters: Transform AI SDK streams to various output formats
+ * - Message converters: Convert between API message formats
+ * - SSE formatters: Format events for Server-Sent Events
+ */
+
+import type { ProviderOptions } from '@ai-sdk/provider-utils'
+import type { Provider } from '@types'
+import type { ModelMessage, TextStreamPart, ToolSet } from 'ai'
+
+/**
+ * Supported output formats for stream adapters
+ */
+export type OutputFormat = 'anthropic' | 'openai' | 'gemini' | 'openai-responses'
+
+/**
+ * Supported input formats for message converters
+ */
+export type InputFormat = 'anthropic' | 'openai'
+
+/**
+ * Stream text options extracted from input params
+ * These are the common parameters used by AI SDK's streamText/generateText
+ */
+export interface StreamTextOptions {
+  maxOutputTokens?: number
+  temperature?: number
+  topP?: number
+  topK?: number
+  stopSequences?: string[]
+}
+
+/**
+ * Stream Adapter Interface
+ *
+ * Uses TransformStream pattern for composability:
+ * ```
+ * input.pipeThrough(adapter1.getTransformStream()).pipeThrough(adapter2.getTransformStream())
+ * ```
+ */
+export interface IStreamAdapter<TOutputEvent = unknown> {
+  /**
+   * Transform AI SDK stream to target format stream
+   * @param input - ReadableStream from AI SDK's fullStream
+   * @returns ReadableStream of formatted output events
+   */
+  transform(input: ReadableStream<TextStreamPart<ToolSet>>): ReadableStream<TOutputEvent>
+
+  /**
+   * Get the internal TransformStream for advanced use cases
+   */
+  getTransformStream(): TransformStream<TextStreamPart<ToolSet>, TOutputEvent>
+
+  /**
+   * Build a non-streaming response from accumulated state
+   * Call after stream is fully consumed
+   */
+  buildNonStreamingResponse(): unknown
+
+  /**
+   * Get the message ID for this adapter instance
+   */
+  getMessageId(): string
+
+  /**
+   * Set input token count (for usage tracking)
+   */
+  setInputTokens(count: number): void
+}
+
+/**
+ * Options for creating stream adapters
+ */
+export interface StreamAdapterOptions {
+  /** Model identifier (e.g., "anthropic:claude-3-opus") */
+  model: string
+  /** Optional message ID, auto-generated if not provided */
+  messageId?: string
+  /** Initial input token count */
+  inputTokens?: number
+}
+
+/**
+ * Message Converter Interface
+ *
+ * Converts between different API message formats and AI SDK format.
+ * Each converter handles a specific input format (OpenAI, Anthropic, etc.)
+ */
+export interface IMessageConverter<TInputParams> {
+  /**
+   * Convert input params to AI SDK ModelMessage[]
+   */
+  toAiSdkMessages(params: TInputParams): ModelMessage[]
+
+  /**
+   * Convert input tools to AI SDK tools format
+   */
+  toAiSdkTools?(params: TInputParams): ToolSet | undefined
+
+  /**
+   * Extract stream/generation options from input params
+   * Maps format-specific parameters to AI SDK common options
+   */
+  extractStreamOptions(params: TInputParams): StreamTextOptions
+
+  /**
+   * Extract provider-specific options from input params
+   * Handles thinking/reasoning configuration based on provider type
+   */
+  extractProviderOptions(provider: Provider, params: TInputParams): ProviderOptions | undefined
+}
+
+/**
+ * SSE Formatter Interface
+ *
+ * Formats events for Server-Sent Events streaming
+ */
+export interface ISSEFormatter<TEvent = unknown> {
+  /**
+   * Format an event for SSE streaming
+   * @returns Formatted string like "event: type\ndata: {...}\n\n"
+   */
+  formatEvent(event: TEvent): string
+
+  /**
+   * Format the stream termination marker
+   * @returns Done marker like "data: [DONE]\n\n"
+   */
+  formatDone(): string
+}
+
+/**
+ * Content block state for tracking streaming content
+ */
+export interface ContentBlockState {
+  type: 'text' | 'tool_use' | 'thinking'
+  index: number
+  started: boolean
+  content: string
+  // For tool_use blocks
+  toolId?: string
+  toolName?: string
+  toolInput?: string
+}
+
+/**
+ * Adapter state for tracking stream processing
+ */
+export interface AdapterState {
+  messageId: string
+  model: string
+  inputTokens: number
+  outputTokens: number
+  cacheInputTokens: number
+  currentBlockIndex: number
+  blocks: Map<number, ContentBlockState>
+  textBlockIndex: number | null
+  thinkingBlocks: Map<string, number>
+  currentThinkingId: string | null
+  toolBlocks: Map<string, number>
+  stopReason: string | null
+  hasEmittedMessageStart: boolean
+}
+
+/**
+ * Constructor type for stream adapters
+ */
+export type StreamAdapterConstructor<TOutputEvent = unknown> = new (
+  options: StreamAdapterOptions
+) => IStreamAdapter<TOutputEvent>
+
+/**
+ * Registry entry for adapter factory
+ */
+export interface AdapterRegistryEntry {
+  format: OutputFormat
+  adapterClass: StreamAdapterConstructor
+  formatterClass: new () => ISSEFormatter
+}
diff --git a/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts b/src/main/apiServer/adapters/stream/AiSdkToAnthropicSSE.ts
similarity index 71%
rename from src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
rename to src/main/apiServer/adapters/stream/AiSdkToAnthropicSSE.ts
index 9ef19c0b9d..704f302dfc 100644
--- a/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
+++ b/src/main/apiServer/adapters/stream/AiSdkToAnthropicSSE.ts
@@ -36,109 +36,67 @@ import type { Usage } from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger' -import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai' +import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai' -import { googleReasoningCache, openRouterReasoningCache } from '../services/reasoning-cache' +import { googleReasoningCache, openRouterReasoningCache } from '../../services/reasoning-cache' +import type { StreamAdapterOptions } from '../interfaces' +import { BaseStreamAdapter } from './BaseStreamAdapter' const logger = loggerService.withContext('AiSdkToAnthropicSSE') -interface ContentBlockState { - type: 'text' | 'tool_use' | 'thinking' - index: number - started: boolean - content: string - // For tool_use blocks - toolId?: string - toolName?: string - toolInput?: string -} - -interface AdapterState { - messageId: string - model: string - inputTokens: number - outputTokens: number - cacheInputTokens: number - currentBlockIndex: number - blocks: Map - textBlockIndex: number | null - // Track multiple thinking blocks by their reasoning ID - thinkingBlocks: Map // reasoningId -> blockIndex - currentThinkingId: string | null // Currently active thinking block ID - toolBlocks: Map // toolCallId -> blockIndex - stopReason: StopReason | null - hasEmittedMessageStart: boolean -} - -export type SSEEventCallback = (event: RawMessageStreamEvent) => void - -export interface AiSdkToAnthropicSSEOptions { - model: string - messageId?: string - inputTokens?: number - onEvent: SSEEventCallback -} - /** * Adapter that converts AI SDK fullStream events to Anthropic SSE events + * + * Uses TransformStream for composable stream processing: + * ``` + * const adapter = new AiSdkToAnthropicSSE({ model: 'claude-3' }) + * const outputStream = adapter.transform(aiSdkStream) + * ``` */ -export class AiSdkToAnthropicSSE { - private state: AdapterState - private onEvent: SSEEventCallback - - constructor(options: AiSdkToAnthropicSSEOptions) { - this.onEvent = options.onEvent - this.state = { - messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`, - model: options.model, - inputTokens: options.inputTokens || 0, - outputTokens: 0, - cacheInputTokens: 0, - currentBlockIndex: 0, - blocks: new Map(), - textBlockIndex: null, - thinkingBlocks: new Map(), - currentThinkingId: null, - toolBlocks: new Map(), - stopReason: null, - hasEmittedMessageStart: false - } +export class AiSdkToAnthropicSSE extends BaseStreamAdapter { + constructor(options: StreamAdapterOptions) { + super(options) } /** - * Process the AI SDK stream and emit Anthropic SSE events + * Emit the initial message_start event */ - async processStream(fullStream: ReadableStream>): Promise { - const reader = fullStream.getReader() + protected emitMessageStart(): void { + if (this.state.hasEmittedMessageStart) return - try { - // Emit message_start at the beginning - this.emitMessageStart() + this.state.hasEmittedMessageStart = true - while (true) { - const { done, value } = await reader.read() - - if (done) { - break - } - - this.processChunk(value) - } - - // Ensure all blocks are closed and emit final events - this.finalize() - } catch (error) { - await reader.cancel() - throw error - } finally { - reader.releaseLock() + const usage: Usage = { + input_tokens: this.state.inputTokens, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + server_tool_use: null } + + const message: Message = { + id: this.state.messageId, + type: 'message', + role: 'assistant', + 
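// content starts empty in message_start; blocks arrive via later content_block_* events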
content: [], + model: this.state.model, + stop_reason: null, + stop_sequence: null, + usage + } + + const event: RawMessageStartEvent = { + type: 'message_start', + message + } + + this.emit(event) } /** * Process a single AI SDK chunk and emit corresponding Anthropic events */ - private processChunk(chunk: TextStreamPart): void { + protected processChunk(chunk: TextStreamPart): void { logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) }) switch (chunk.type) { // === Text Events === @@ -200,13 +158,7 @@ export class AiSdkToAnthropicSSE { break case 'tool-result': - // this.handleToolResult({ - // type: 'tool-result', - // toolCallId: chunk.toolCallId, - // toolName: chunk.toolName, - // args: chunk.input, - // result: chunk.output - // }) + // Tool results are handled differently in Anthropic format break case 'finish-step': @@ -222,49 +174,15 @@ export class AiSdkToAnthropicSSE { case 'error': throw chunk.error - // Ignore other event types default: break } } - private emitMessageStart(): void { - if (this.state.hasEmittedMessageStart) return - - this.state.hasEmittedMessageStart = true - - const usage: Usage = { - input_tokens: this.state.inputTokens, - output_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - server_tool_use: null - } - - const message: Message = { - id: this.state.messageId, - type: 'message', - role: 'assistant', - content: [], - model: this.state.model, - stop_reason: null, - stop_sequence: null, - usage - } - - const event: RawMessageStartEvent = { - type: 'message_start', - message - } - - this.onEvent(event) - } - private startTextBlock(): void { - // If we already have a text block, don't create another if (this.state.textBlockIndex !== null) return - const index = this.state.currentBlockIndex++ + const index = this.allocateBlockIndex() this.state.textBlockIndex = index this.state.blocks.set(index, { type: 'text', @@ -285,13 +203,12 @@ export class AiSdkToAnthropicSSE { content_block: contentBlock } - this.onEvent(event) + this.emit(event) } private emitTextDelta(text: string): void { if (!text) return - // Auto-start text block if not started if (this.state.textBlockIndex === null) { this.startTextBlock() } @@ -313,7 +230,7 @@ export class AiSdkToAnthropicSSE { delta } - this.onEvent(event) + this.emit(event) } private stopTextBlock(): void { @@ -326,15 +243,14 @@ export class AiSdkToAnthropicSSE { index } - this.onEvent(event) + this.emit(event) this.state.textBlockIndex = null } private startThinkingBlock(reasoningId: string): void { - // Check if this thinking block already exists if (this.state.thinkingBlocks.has(reasoningId)) return - const index = this.state.currentBlockIndex++ + const index = this.allocateBlockIndex() this.state.thinkingBlocks.set(reasoningId, index) this.state.currentThinkingId = reasoningId this.state.blocks.set(index, { @@ -356,16 +272,14 @@ export class AiSdkToAnthropicSSE { content_block: contentBlock } - this.onEvent(event) + this.emit(event) } private emitThinkingDelta(text: string, reasoningId?: string): void { if (!text) return - // Determine which thinking block to use const targetId = reasoningId || this.state.currentThinkingId if (!targetId) { - // Auto-start thinking block if not started const newId = `reasoning_${Date.now()}` this.startThinkingBlock(newId) return this.emitThinkingDelta(text, newId) @@ -373,7 +287,6 @@ export class AiSdkToAnthropicSSE { const index = this.state.thinkingBlocks.get(targetId) if (index === undefined) { - // If the block doesn't exist, 
create it this.startThinkingBlock(targetId) return this.emitThinkingDelta(text, targetId) } @@ -394,7 +307,7 @@ export class AiSdkToAnthropicSSE { delta } - this.onEvent(event) + this.emit(event) } private stopThinkingBlock(reasoningId?: string): void { @@ -409,12 +322,10 @@ export class AiSdkToAnthropicSSE { index } - this.onEvent(event) + this.emit(event) this.state.thinkingBlocks.delete(targetId) - // Update currentThinkingId if we just closed the current one if (this.state.currentThinkingId === targetId) { - // Set to the most recent remaining thinking block, or null if none const remaining = Array.from(this.state.thinkingBlocks.keys()) this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null } @@ -423,12 +334,11 @@ export class AiSdkToAnthropicSSE { private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void { const { toolCallId, toolName, args } = chunk - // Check if we already have this tool call if (this.state.toolBlocks.has(toolCallId)) { return } - const index = this.state.currentBlockIndex++ + const index = this.allocateBlockIndex() this.state.toolBlocks.set(toolCallId, index) const inputJson = JSON.stringify(args) @@ -457,9 +367,9 @@ export class AiSdkToAnthropicSSE { content_block: contentBlock } - this.onEvent(startEvent) + this.emit(startEvent) - // Emit the full input as a delta (Anthropic streams JSON incrementally) + // Emit the full input as a delta const delta: InputJSONDelta = { type: 'input_json_delta', partial_json: inputJson @@ -471,7 +381,7 @@ export class AiSdkToAnthropicSSE { delta } - this.onEvent(deltaEvent) + this.emit(deltaEvent) // Emit content_block_stop const stopEvent: RawContentBlockStopEvent = { @@ -479,21 +389,18 @@ export class AiSdkToAnthropicSSE { index } - this.onEvent(stopEvent) + this.emit(stopEvent) - // Mark that we have tool use this.state.stopReason = 'tool_use' } private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void { - // Update usage if (chunk.totalUsage) { this.state.inputTokens = chunk.totalUsage.inputTokens || 0 this.state.outputTokens = chunk.totalUsage.outputTokens || 0 this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0 } - // Determine finish reason if (!this.state.stopReason) { switch (chunk.finishReason) { case 'stop': @@ -514,7 +421,10 @@ export class AiSdkToAnthropicSSE { } } - private finalize(): void { + /** + * Finalize the stream and emit closing events + */ + protected finalize(): void { // Close any open blocks if (this.state.textBlockIndex !== null) { this.stopTextBlock() @@ -536,34 +446,20 @@ export class AiSdkToAnthropicSSE { const messageDeltaEvent: RawMessageDeltaEvent = { type: 'message_delta', delta: { - stop_reason: this.state.stopReason || 'end_turn', + stop_reason: (this.state.stopReason as StopReason) || 'end_turn', stop_sequence: null }, usage } - this.onEvent(messageDeltaEvent) + this.emit(messageDeltaEvent) // Emit message_stop const messageStopEvent: RawMessageStopEvent = { type: 'message_stop' } - this.onEvent(messageStopEvent) - } - - /** - * Set input token count (typically from prompt) - */ - setInputTokens(count: number): void { - this.state.inputTokens = count - } - - /** - * Get the current message ID - */ - getMessageId(): string { - return this.state.messageId + this.emit(messageStopEvent) } /** @@ -572,7 +468,6 @@ export class AiSdkToAnthropicSSE { buildNonStreamingResponse(): Message { const content: ContentBlock[] = [] - // Collect 
all content blocks in order const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index) for (const block of sortedBlocks) { @@ -607,7 +502,7 @@ export class AiSdkToAnthropicSSE { role: 'assistant', content, model: this.state.model, - stop_reason: this.state.stopReason || 'end_turn', + stop_reason: (this.state.stopReason as StopReason) || 'end_turn', stop_sequence: null, usage: { input_tokens: this.state.inputTokens, @@ -620,18 +515,4 @@ export class AiSdkToAnthropicSSE { } } -/** - * Format an Anthropic SSE event for HTTP streaming - */ -export function formatSSEEvent(event: RawMessageStreamEvent): string { - return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n` -} - -/** - * Create a done marker for SSE stream - */ -export function formatSSEDone(): string { - return 'data: [DONE]\n\n' -} - export default AiSdkToAnthropicSSE diff --git a/src/main/apiServer/adapters/stream/AiSdkToOpenAISSE.ts b/src/main/apiServer/adapters/stream/AiSdkToOpenAISSE.ts new file mode 100644 index 0000000000..885fecf960 --- /dev/null +++ b/src/main/apiServer/adapters/stream/AiSdkToOpenAISSE.ts @@ -0,0 +1,416 @@ +/** + * AI SDK to OpenAI Compatible SSE Adapter + * + * Converts AI SDK's fullStream (TextStreamPart) events to OpenAI-compatible Chat Completions API SSE format. + * This enables any AI provider supported by AI SDK to be exposed via OpenAI-compatible API. + * + * Supports extended features used by OpenAI-compatible providers: + * - reasoning_content: DeepSeek-style reasoning/thinking content + * - Standard OpenAI fields: content, tool_calls, finish_reason, usage + * + * OpenAI SSE Event Flow: + * 1. data: {chunk with role} - First chunk with assistant role + * 2. data: {chunk with content/reasoning_content delta} - Incremental content updates + * 3. data: {chunk with tool_calls} - Tool call deltas + * 4. data: {chunk with finish_reason} - Final chunk with finish reason + * 5. 
data: [DONE] - Stream complete
+ *
+ * @see https://platform.openai.com/docs/api-reference/chat/streaming
+ */
+
+import type OpenAI from '@cherrystudio/openai'
+import { loggerService } from '@logger'
+import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'
+
+import type { StreamAdapterOptions } from '../interfaces'
+import { BaseStreamAdapter } from './BaseStreamAdapter'
+
+const logger = loggerService.withContext('AiSdkToOpenAISSE')
+
+/**
+ * Use official OpenAI SDK types as base
+ */
+type ChatCompletionChunkBase = OpenAI.Chat.Completions.ChatCompletionChunk
+type ChatCompletion = OpenAI.Chat.Completions.ChatCompletion
+
+/**
+ * Extended delta type with reasoning_content support (DeepSeek-style)
+ */
+interface OpenAICompatibleDelta {
+  role?: 'assistant'
+  content?: string | null
+  reasoning_content?: string | null
+  tool_calls?: ChatCompletionChunkBase['choices'][0]['delta']['tool_calls']
+}
+
+/**
+ * Extended ChatCompletionChunk with reasoning_content support
+ */
+export interface OpenAICompatibleChunk extends Omit<ChatCompletionChunkBase, 'choices'> {
+  choices: Array<{
+    index: number
+    delta: OpenAICompatibleDelta
+    finish_reason: ChatCompletionChunkBase['choices'][0]['finish_reason']
+    logprobs?: ChatCompletionChunkBase['choices'][0]['logprobs']
+  }>
+}
+
+/**
+ * Extended ChatCompletion message with reasoning_content support
+ */
+interface OpenAICompatibleMessage extends OpenAI.Chat.Completions.ChatCompletionMessage {
+  reasoning_content?: string | null
+}
+
+/**
+ * Extended ChatCompletion with reasoning_content support
+ */
+export interface OpenAICompatibleCompletion extends Omit<ChatCompletion, 'choices'> {
+  choices: Array<{
+    index: number
+    message: OpenAICompatibleMessage
+    finish_reason: ChatCompletion['choices'][0]['finish_reason']
+    logprobs: ChatCompletion['choices'][0]['logprobs']
+  }>
+}
+
+/**
+ * OpenAI finish reasons
+ */
+type OpenAIFinishReason = 'stop' | 'length' | 'tool_calls' | 'content_filter' | null
+
+/**
+ * Tool call state for tracking incremental tool calls
+ */
+interface ToolCallState {
+  index: number
+  id: string
+  name: string
+  arguments: string
+}
+
+/**
+ * Adapter that converts AI SDK fullStream events to OpenAI-compatible SSE events
+ *
+ * Uses TransformStream for composable stream processing:
+ * ```
+ * const adapter = new AiSdkToOpenAISSE({ model: 'gpt-4' })
+ * const outputStream = adapter.transform(aiSdkStream)
+ * ```
+ */
+export class AiSdkToOpenAISSE extends BaseStreamAdapter<OpenAICompatibleChunk> {
+  private createdTimestamp: number
+  private toolCalls: Map<string, ToolCallState> = new Map()
+  private currentToolCallIndex = 0
+  private finishReason: OpenAIFinishReason = null
+  private reasoningContent = ''
+
+  constructor(options: StreamAdapterOptions) {
+    super(options)
+    this.createdTimestamp = Math.floor(Date.now() / 1000)
+  }
+
+  /**
+   * Create a base chunk structure
+   */
+  private createBaseChunk(delta: OpenAICompatibleDelta): OpenAICompatibleChunk {
+    return {
+      id: `chatcmpl-${this.state.messageId}`,
+      object: 'chat.completion.chunk',
+      created: this.createdTimestamp,
+      model: this.state.model,
+      choices: [
+        {
+          index: 0,
+          delta,
+          finish_reason: null
+        }
+      ]
+    }
+  }
+
+  /**
+   * Emit the initial message start event (with role)
+   */
+  protected emitMessageStart(): void {
+    if (this.state.hasEmittedMessageStart) return
+
+    this.state.hasEmittedMessageStart = true
+
+    // Emit initial chunk with role
+    const chunk = this.createBaseChunk({ role: 'assistant' })
+    this.emit(chunk)
+  }
+
+  /**
+   * Process a single AI SDK chunk and emit corresponding OpenAI events
+   */
+  protected processChunk(chunk: 
TextStreamPart): void { + logger.silly('AiSdkToOpenAISSE - Processing chunk:', { chunk: JSON.stringify(chunk) }) + switch (chunk.type) { + // === Text Events === + case 'text-start': + // OpenAI doesn't have a separate start event + break + + case 'text-delta': + this.emitContentDelta(chunk.text || '') + break + + case 'text-end': + // OpenAI doesn't have a separate end event + break + + // === Reasoning/Thinking Events === + // Support DeepSeek-style reasoning_content + case 'reasoning-start': + // No separate start event needed + break + + case 'reasoning-delta': + this.emitReasoningDelta(chunk.text || '') + break + + case 'reasoning-end': + // No separate end event needed + break + + // === Tool Events === + case 'tool-call': + this.handleToolCall({ + toolCallId: chunk.toolCallId, + toolName: chunk.toolName, + args: chunk.input + }) + break + + case 'tool-result': + // Tool results are not part of streaming output + break + + case 'finish-step': + if (chunk.finishReason === 'tool-calls') { + this.finishReason = 'tool_calls' + } + break + + case 'finish': + this.handleFinish(chunk) + break + + case 'error': + throw chunk.error + + default: + break + } + } + + private emitContentDelta(content: string): void { + if (!content) return + + // Track content in state + let textBlock = this.state.blocks.get(0) + if (!textBlock) { + textBlock = { + type: 'text', + index: 0, + started: true, + content: '' + } + this.state.blocks.set(0, textBlock) + } + textBlock.content += content + + const chunk = this.createBaseChunk({ content }) + this.emit(chunk) + } + + private emitReasoningDelta(reasoningContent: string): void { + if (!reasoningContent) return + + // Track reasoning content + this.reasoningContent += reasoningContent + + // Also track in state blocks for non-streaming response + let thinkingBlock = this.state.blocks.get(-1) // Use -1 for thinking block + if (!thinkingBlock) { + thinkingBlock = { + type: 'thinking', + index: -1, + started: true, + content: '' + } + this.state.blocks.set(-1, thinkingBlock) + } + thinkingBlock.content += reasoningContent + + // Emit chunk with reasoning_content (DeepSeek-style) + const chunk = this.createBaseChunk({ reasoning_content: reasoningContent }) + this.emit(chunk) + } + + private handleToolCall(params: { toolCallId: string; toolName: string; args: unknown }): void { + const { toolCallId, toolName, args } = params + + if (this.toolCalls.has(toolCallId)) { + return + } + + const index = this.currentToolCallIndex++ + const argsString = JSON.stringify(args) + + this.toolCalls.set(toolCallId, { + index, + id: toolCallId, + name: toolName, + arguments: argsString + }) + + // Track in state + const blockIndex = this.allocateBlockIndex() + this.state.blocks.set(blockIndex, { + type: 'tool_use', + index: blockIndex, + started: true, + content: argsString, + toolId: toolCallId, + toolName, + toolInput: argsString + }) + + // Emit tool call chunk + const chunk = this.createBaseChunk({ + tool_calls: [ + { + index, + id: toolCallId, + type: 'function', + function: { + name: toolName, + arguments: argsString + } + } + ] + }) + + this.emit(chunk) + this.finishReason = 'tool_calls' + } + + private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void { + if (chunk.totalUsage) { + this.state.inputTokens = chunk.totalUsage.inputTokens || 0 + this.state.outputTokens = chunk.totalUsage.outputTokens || 0 + } + + if (!this.finishReason) { + switch (chunk.finishReason) { + case 'stop': + this.finishReason = 'stop' + break 
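+        // the remaining AI SDK reasons ('length', 'tool-calls', 'content-filter') map onto OpenAI's enum; anything else falls back to 'stop'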
+ case 'length': + this.finishReason = 'length' + break + case 'tool-calls': + this.finishReason = 'tool_calls' + break + case 'content-filter': + this.finishReason = 'content_filter' + break + default: + this.finishReason = 'stop' + } + } + + this.state.stopReason = this.finishReason + } + + /** + * Finalize the stream and emit closing events + */ + protected finalize(): void { + // Emit final chunk with finish_reason and usage + const finalChunk: OpenAICompatibleChunk = { + id: `chatcmpl-${this.state.messageId}`, + object: 'chat.completion.chunk', + created: this.createdTimestamp, + model: this.state.model, + choices: [ + { + index: 0, + delta: {}, + finish_reason: this.finishReason || 'stop' + } + ], + usage: { + prompt_tokens: this.state.inputTokens, + completion_tokens: this.state.outputTokens, + total_tokens: this.state.inputTokens + this.state.outputTokens + } + } + + this.emit(finalChunk) + } + + /** + * Build a complete ChatCompletion object for non-streaming responses + */ + buildNonStreamingResponse(): OpenAICompatibleCompletion { + // Collect text content + let content: string | null = null + const textBlock = this.state.blocks.get(0) + if (textBlock && textBlock.type === 'text' && textBlock.content) { + content = textBlock.content + } + + // Collect reasoning content + let reasoningContent: string | null = null + const thinkingBlock = this.state.blocks.get(-1) + if (thinkingBlock && thinkingBlock.type === 'thinking' && thinkingBlock.content) { + reasoningContent = thinkingBlock.content + } + + // Collect tool calls + const toolCallsArray: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = Array.from( + this.toolCalls.values() + ).map((tc) => ({ + id: tc.id, + type: 'function' as const, + function: { + name: tc.name, + arguments: tc.arguments + } + })) + + const message: OpenAICompatibleMessage = { + role: 'assistant', + content, + refusal: null, + ...(reasoningContent ? { reasoning_content: reasoningContent } : {}), + ...(toolCallsArray.length > 0 ? { tool_calls: toolCallsArray } : {}) + } + + return { + id: `chatcmpl-${this.state.messageId}`, + object: 'chat.completion', + created: this.createdTimestamp, + model: this.state.model, + choices: [ + { + index: 0, + message, + finish_reason: this.finishReason || 'stop', + logprobs: null + } + ], + usage: { + prompt_tokens: this.state.inputTokens, + completion_tokens: this.state.outputTokens, + total_tokens: this.state.inputTokens + this.state.outputTokens + } + } + } +} + +export default AiSdkToOpenAISSE diff --git a/src/main/apiServer/adapters/stream/BaseStreamAdapter.ts b/src/main/apiServer/adapters/stream/BaseStreamAdapter.ts new file mode 100644 index 0000000000..8b498e4699 --- /dev/null +++ b/src/main/apiServer/adapters/stream/BaseStreamAdapter.ts @@ -0,0 +1,161 @@ +/** + * Base Stream Adapter + * + * Abstract base class for stream adapters that provides: + * - Shared state management (messageId, tokens, blocks, etc.) 
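+ * - Lazy message_start emission (deferred to the first chunk, or to flush for empty streams)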
+ * - TransformStream implementation + * - Common utility methods + */ + +import type { TextStreamPart, ToolSet } from 'ai' + +import type { AdapterState, ContentBlockState, IStreamAdapter, StreamAdapterOptions } from '../interfaces' + +/** + * Abstract base class for stream adapters + * + * Subclasses must implement: + * - processChunk(): Handle individual stream chunks + * - emitMessageStart(): Emit initial message event + * - finalize(): Clean up and emit final events + * - buildNonStreamingResponse(): Build complete response object + */ +export abstract class BaseStreamAdapter implements IStreamAdapter { + protected state: AdapterState + protected controller: TransformStreamDefaultController | null = null + private transformStream: TransformStream, TOutputEvent> + + constructor(options: StreamAdapterOptions) { + this.state = this.createInitialState(options) + this.transformStream = this.createTransformStream() + } + + /** + * Create initial adapter state + */ + protected createInitialState(options: StreamAdapterOptions): AdapterState { + return { + messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`, + model: options.model, + inputTokens: options.inputTokens || 0, + outputTokens: 0, + cacheInputTokens: 0, + currentBlockIndex: 0, + blocks: new Map(), + textBlockIndex: null, + thinkingBlocks: new Map(), + currentThinkingId: null, + toolBlocks: new Map(), + stopReason: null, + hasEmittedMessageStart: false + } + } + + /** + * Create the TransformStream for processing + */ + private createTransformStream(): TransformStream, TOutputEvent> { + return new TransformStream, TOutputEvent>({ + start: (controller) => { + this.controller = controller + // Note: emitMessageStart is called lazily in transform or finalize + // to allow configuration changes (like setInputTokens) after construction + }, + transform: (chunk, _controller) => { + // Ensure message_start is emitted before processing chunks + this.emitMessageStart() + this.processChunk(chunk) + }, + flush: (_controller) => { + // Ensure message_start is emitted even for empty streams + this.emitMessageStart() + this.finalize() + } + }) + } + + /** + * Transform input stream to output stream + */ + transform(input: ReadableStream>): ReadableStream { + return input.pipeThrough(this.transformStream) + } + + /** + * Get the internal TransformStream + */ + getTransformStream(): TransformStream, TOutputEvent> { + return this.transformStream + } + + /** + * Get message ID + */ + getMessageId(): string { + return this.state.messageId + } + + /** + * Set input token count + */ + setInputTokens(count: number): void { + this.state.inputTokens = count + } + + /** + * Emit an event to the output stream + */ + protected emit(event: TOutputEvent): void { + if (this.controller) { + this.controller.enqueue(event) + } + } + + /** + * Get or create a content block + */ + protected getOrCreateBlock(index: number, type: ContentBlockState['type']): ContentBlockState { + let block = this.state.blocks.get(index) + if (!block) { + block = { + type, + index, + started: false, + content: '' + } + this.state.blocks.set(index, block) + } + return block + } + + /** + * Allocate a new block index + */ + protected allocateBlockIndex(): number { + return this.state.currentBlockIndex++ + } + + // ===== Abstract methods to be implemented by subclasses ===== + + /** + * Process a single chunk from the AI SDK stream + */ + protected abstract processChunk(chunk: TextStreamPart): void + + /** + * Emit the initial message start event 
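+   * (called lazily by the base class on the first chunk, or on flush for empty streams, so setInputTokens() can still run after construction)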
+ */ + protected abstract emitMessageStart(): void + + /** + * Finalize the stream and emit closing events + */ + protected abstract finalize(): void + + /** + * Build a non-streaming response from accumulated state + */ + abstract buildNonStreamingResponse(): unknown +} + +export default BaseStreamAdapter diff --git a/src/main/apiServer/adapters/stream/index.ts b/src/main/apiServer/adapters/stream/index.ts new file mode 100644 index 0000000000..7562df5606 --- /dev/null +++ b/src/main/apiServer/adapters/stream/index.ts @@ -0,0 +1,3 @@ +export { AiSdkToAnthropicSSE } from './AiSdkToAnthropicSSE' +export { AiSdkToOpenAISSE } from './AiSdkToOpenAISSE' +export { BaseStreamAdapter } from './BaseStreamAdapter' diff --git a/src/main/apiServer/routes/chat.ts b/src/main/apiServer/routes/chat.ts index 3dd58b9654..999ad36312 100644 --- a/src/main/apiServer/routes/chat.ts +++ b/src/main/apiServer/routes/chat.ts @@ -1,13 +1,10 @@ -import type { ChatCompletionCreateParams } from '@cherrystudio/openai/resources' import type { Request, Response } from 'express' import express from 'express' import { loggerService } from '../../services/LoggerService' -import { - ChatCompletionModelError, - chatCompletionService, - ChatCompletionValidationError -} from '../services/chat-completion' +import type { ExtendedChatCompletionCreateParams } from '../adapters' +import { generateMessage, streamToResponse } from '../services/ProxyStreamService' +import { validateModelId } from '../utils' const logger = loggerService.withContext('ApiServerChatRoutes') @@ -22,44 +19,17 @@ interface ErrorResponseBody { } const mapChatCompletionError = (error: unknown): { status: number; body: ErrorResponseBody } => { - if (error instanceof ChatCompletionValidationError) { - logger.warn('Chat completion validation error', { - errors: error.errors - }) - - return { - status: 400, - body: { - error: { - message: error.errors.join('; '), - type: 'invalid_request_error', - code: 'validation_failed' - } - } - } - } - - if (error instanceof ChatCompletionModelError) { - logger.warn('Chat completion model error', error.error) - - return { - status: 400, - body: { - error: { - message: error.error.message, - type: 'invalid_request_error', - code: error.error.code - } - } - } - } - if (error instanceof Error) { let statusCode = 500 let errorType = 'server_error' let errorCode = 'internal_error' - if (error.message.includes('API key') || error.message.includes('authentication')) { + // Model validation errors + if (error.message.includes('Model') && error.message.includes('not found')) { + statusCode = 400 + errorType = 'invalid_request_error' + errorCode = 'model_not_found' + } else if (error.message.includes('API key') || error.message.includes('authentication')) { statusCode = 401 errorType = 'authentication_error' errorCode = 'invalid_api_key' @@ -182,7 +152,7 @@ const mapChatCompletionError = (error: unknown): { status: number; body: ErrorRe */ router.post('/completions', async (req: Request, res: Response) => { try { - const request: ChatCompletionCreateParams = req.body + const request = req.body as ExtendedChatCompletionCreateParams if (!request) { return res.status(400).json({ @@ -194,6 +164,26 @@ router.post('/completions', async (req: Request, res: Response) => { }) } + if (!request.model) { + return res.status(400).json({ + error: { + message: 'Model is required', + type: 'invalid_request_error', + code: 'missing_model' + } + }) + } + + if (!request.messages || request.messages.length === 0) { + return res.status(400).json({ + error: { 
+ message: 'Messages are required', + type: 'invalid_request_error', + code: 'missing_messages' + } + }) + } + logger.debug('Chat completion request', { model: request.model, messageCount: request.messages?.length || 0, @@ -201,40 +191,51 @@ router.post('/completions', async (req: Request, res: Response) => { temperature: request.temperature }) + // Validate model and get provider + const modelValidation = await validateModelId(request.model) + if (!modelValidation.valid) { + return res.status(400).json({ + error: { + message: modelValidation.error?.message || 'Model not found', + type: 'invalid_request_error', + code: modelValidation.error?.code || 'model_not_found' + } + }) + } + + const provider = modelValidation.provider! + const modelId = modelValidation.modelId! const isStreaming = !!request.stream if (isStreaming) { - const { stream } = await chatCompletionService.processStreamingCompletion(request) - - res.setHeader('Content-Type', 'text/event-stream; charset=utf-8') - res.setHeader('Cache-Control', 'no-cache, no-transform') - res.setHeader('Connection', 'keep-alive') - res.setHeader('X-Accel-Buffering', 'no') - res.flushHeaders() - try { - for await (const chunk of stream) { - res.write(`data: ${JSON.stringify(chunk)}\n\n`) - } - res.write('data: [DONE]\n\n') - } catch (streamError: any) { + await streamToResponse({ + response: res, + provider, + modelId, + params: request, + inputFormat: 'openai', + outputFormat: 'openai' + }) + } catch (streamError) { logger.error('Stream error', { error: streamError }) - res.write( - `data: ${JSON.stringify({ - error: { - message: 'Stream processing error', - type: 'server_error', - code: 'stream_error' - } - })}\n\n` - ) - } finally { - res.end() + // If headers weren't sent yet, return JSON error + if (!res.headersSent) { + const { status, body } = mapChatCompletionError(streamError) + return res.status(status).json(body) + } + // Otherwise the error is already handled by streamToResponse } return } - const { response } = await chatCompletionService.processCompletion(request) + const response = await generateMessage({ + provider, + modelId, + params: request, + inputFormat: 'openai', + outputFormat: 'openai' + }) return res.json(response) } catch (error: unknown) { const { status, body } = mapChatCompletionError(error) diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index efd8ae1ae1..652d4f46bb 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -8,7 +8,7 @@ import express from 'express' import { approximateTokenSize } from 'tokenx' import { messagesService } from '../services/messages' -import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages' +import { generateMessage, streamToResponse } from '../services/ProxyStreamService' import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils' /** @@ -322,7 +322,7 @@ async function handleUnifiedProcessing({ }) if (request.stream) { - await streamUnifiedMessages({ + await streamToResponse({ response: res, provider, modelId: actualModelId, @@ -336,7 +336,7 @@ async function handleUnifiedProcessing({ } }) } else { - const response = await generateUnifiedMessage({ + const response = await generateMessage({ provider, modelId: actualModelId, params: request, diff --git a/src/main/apiServer/services/ProxyStreamService.ts b/src/main/apiServer/services/ProxyStreamService.ts new file mode 100644 index 0000000000..5519dc49ef --- /dev/null +++ 
b/src/main/apiServer/services/ProxyStreamService.ts @@ -0,0 +1,465 @@ +/** + * Proxy Stream Service + * + * Handles proxying AI requests through the unified AI SDK pipeline, + * converting between different API formats using the adapter system. + */ + +import type { LanguageModelV2Middleware } from '@ai-sdk/provider' +import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core' +import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' +import { loggerService } from '@logger' +import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai' +import anthropicService from '@main/services/AnthropicService' +import copilotService from '@main/services/CopilotService' +import { reduxService } from '@main/services/ReduxService' +import { + type AiSdkConfig, + type AiSdkConfigContext, + formatProviderApiHost, + initializeSharedProviders, + type ProviderFormatContext, + providerToAiSdkConfig as sharedProviderToAiSdkConfig, + resolveActualProvider +} from '@shared/aiCore' +import { COPILOT_DEFAULT_HEADERS } from '@shared/aiCore/constant' +import type { MinimalProvider } from '@shared/types' +import { defaultAppHeaders } from '@shared/utils' +import type { Provider } from '@types' +import type { Provider as AiSdkProvider } from 'ai' +import { simulateStreamingMiddleware, stepCountIs, wrapLanguageModel } from 'ai' +import { net } from 'electron' +import type { Response } from 'express' + +import type { InputFormat, InputParamsMap, IStreamAdapter } from '../adapters' +import { MessageConverterFactory, type OutputFormat, StreamAdapterFactory } from '../adapters' +import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache' + +const logger = loggerService.withContext('ProxyStreamService') + +initializeSharedProviders({ + warn: (message) => logger.warn(message), + error: (message, error) => logger.error(message, error) +}) + +// ============================================================================ +// Configuration Interfaces +// ============================================================================ + +/** + * Middleware type alias + */ +type LanguageModelMiddleware = LanguageModelV2Middleware + +/** + * Union type for all supported input params + */ +type InputParams = InputParamsMap[InputFormat] + +/** + * Configuration for streaming message requests + */ +export interface StreamConfig { + response: Response + provider: Provider + modelId: string + params: InputParams + inputFormat?: InputFormat + outputFormat?: OutputFormat + onError?: (error: unknown) => void + onComplete?: () => void + middlewares?: LanguageModelMiddleware[] + plugins?: AiPlugin[] +} + +/** + * Configuration for non-streaming message generation + */ +export interface GenerateConfig { + provider: Provider + modelId: string + params: InputParams + inputFormat?: InputFormat + outputFormat?: OutputFormat + middlewares?: LanguageModelMiddleware[] + plugins?: AiPlugin[] +} + +/** + * Internal configuration for stream execution + */ +interface ExecuteStreamConfig { + provider: Provider + modelId: string + params: InputParams + inputFormat: InputFormat + outputFormat: OutputFormat + middlewares?: LanguageModelMiddleware[] + plugins?: AiPlugin[] +} + +// ============================================================================ +// Provider Configuration +// ============================================================================ + +function getMainProcessFormatContext(): ProviderFormatContext { + const vertexSettings = 
reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai') + return { + vertex: { + project: vertexSettings?.projectId || 'default-project', + location: vertexSettings?.location || 'us-central1' + } + } +} + +function isSupportStreamOptionsProvider(provider: MinimalProvider): boolean { + const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const + return !NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id) +} + +const mainProcessSdkContext: AiSdkConfigContext = { + isSupportStreamOptionsProvider, + getIncludeUsageSetting: () => + reduxService.selectSync('state.settings.openAI?.streamOptions?.includeUsage'), + fetch: net.fetch as typeof globalThis.fetch +} + +function getActualProvider(provider: Provider, modelId: string): Provider { + const model = provider.models?.find((m) => m.id === modelId) + if (!model) return provider + return resolveActualProvider(provider, model) +} + +function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig { + const actualProvider = getActualProvider(provider, modelId) + const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext()) + return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext) +} + +/** + * Create AI SDK provider instance from config + */ +async function createAiSdkProvider(config: AiSdkConfig): Promise { + let providerId = config.providerId + + // Handle special provider modes + if (providerId === 'openai' && config.options?.mode === 'chat') { + providerId = 'openai-chat' + } else if (providerId === 'azure' && config.options?.mode === 'responses') { + providerId = 'azure-responses' + } else if (providerId === 'cherryin' && config.options?.mode === 'chat') { + providerId = 'cherryin-chat' + } + + const provider = await createProviderCore(providerId, config.options) + return provider +} + +/** + * Prepare special provider configuration for providers that need dynamic tokens + */ +async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise { + switch (provider.id) { + case 'copilot': { + const storedHeaders = + ((await reduxService.select('state.copilot.defaultHeaders')) as Record | null) ?? {} + const headers: Record = { + ...COPILOT_DEFAULT_HEADERS, + ...storedHeaders + } + + try { + const { token } = await copilotService.getToken(null as never, headers) + config.options.apiKey = token + const existingHeaders = (config.options.headers as Record | undefined) ?? {} + config.options.headers = { + ...headers, + ...existingHeaders + } + } catch (error) { + logger.error('Failed to get Copilot token', error as Error) + throw new Error('Failed to get Copilot token. Please re-authorize Copilot.') + } + break + } + case 'anthropic': { + if (provider.authType === 'oauth') { + try { + const oauthToken = await anthropicService.getValidAccessToken() + if (!oauthToken) { + throw new Error('Anthropic OAuth token not available. Please re-authorize.') + } + config.options = { + ...config.options, + headers: { + ...(config.options.headers ? config.options.headers : {}), + 'Content-Type': 'application/json', + 'anthropic-version': '2023-06-01', + 'anthropic-beta': 'oauth-2025-04-20', + Authorization: `Bearer ${oauthToken}` + }, + baseURL: 'https://api.anthropic.com/v1', + apiKey: '' + } + } catch (error) { + logger.error('Failed to get Anthropic OAuth token', error as Error) + throw new Error('Failed to get Anthropic OAuth token. 
Please re-authorize.') + } + } + break + } + case 'cherryai': { + const baseFetch = net.fetch as typeof globalThis.fetch + config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => { + if (!options?.body) { + return baseFetch(url, options) + } + const signature = cherryaiGenerateSignature({ + method: 'POST', + path: '/chat/completions', + query: '', + body: JSON.parse(options.body as string) + }) + return baseFetch(url, { + ...options, + headers: { + ...(options.headers as Record), + ...signature + } + }) + } + break + } + } + return config +} + +// ============================================================================ +// Core Stream Execution +// ============================================================================ + +/** + * Execute stream and return adapter with output stream + * + * Uses MessageConverterFactory to create the appropriate converter + * based on input format, eliminating format-specific if-else logic. + */ +async function executeStream(config: ExecuteStreamConfig): Promise<{ + adapter: IStreamAdapter + outputStream: ReadableStream +}> { + const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [] } = config + + // Convert provider config to AI SDK config + let sdkConfig = providerToAiSdkConfig(provider, modelId) + sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig) + + // Create provider instance and get language model + const aiSdkProvider = await createAiSdkProvider(sdkConfig) + const baseModel = aiSdkProvider.languageModel(modelId) + + // Apply middlewares if present + const model = + middlewares.length > 0 && typeof baseModel === 'object' + ? (wrapLanguageModel({ model: baseModel, middleware: middlewares as never }) as typeof baseModel) + : baseModel + + // Create executor with plugins + const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins) + + const converter = MessageConverterFactory.create(inputFormat, { + googleReasoningCache, + openRouterReasoningCache + }) + + // Convert messages, tools, and extract options using unified interface + const coreMessages = converter.toAiSdkMessages(params) + const tools = converter.toAiSdkTools?.(params) + const streamOptions = converter.extractStreamOptions(params) + const providerOptions = converter.extractProviderOptions(provider, params) + + // Create adapter via factory + const adapter = StreamAdapterFactory.createAdapter(outputFormat, { + model: `${provider.id}:${modelId}` + }) + + // Execute AI SDK stream with extracted options + const result = await executor.streamText({ + model, + messages: coreMessages, + ...streamOptions, + stopWhen: stepCountIs(100), + headers: defaultAppHeaders(), + tools, + providerOptions + }) + + // Transform stream using adapter + const outputStream = adapter.transform(result.fullStream) + + return { adapter, outputStream } +} + +// ============================================================================ +// Public API +// ============================================================================ + +/** + * Stream a message request and write to HTTP response + * + * Uses TransformStream-based adapters for efficient streaming. 
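+ * If the upstream fails after SSE headers have been flushed, onError is invoked and the error is rethrown; callers should avoid writing a JSON error body once streaming has started.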
+ * + * @example + * ```typescript + * await streamToResponse({ + * response: res, + * provider, + * modelId: 'claude-3-opus', + * params: messageCreateParams, + * outputFormat: 'anthropic' + * }) + * ``` + */ +export async function streamToResponse(config: StreamConfig): Promise { + const { + response, + provider, + modelId, + params, + inputFormat = 'anthropic', + outputFormat = 'anthropic', + onError, + onComplete, + middlewares = [], + plugins = [] + } = config + + logger.info('Starting proxy stream', { + providerId: provider.id, + providerType: provider.type, + modelId, + inputFormat, + outputFormat, + middlewareCount: middlewares.length, + pluginCount: plugins.length + }) + + try { + // Set SSE headers + response.setHeader('Content-Type', 'text/event-stream') + response.setHeader('Cache-Control', 'no-cache') + response.setHeader('Connection', 'keep-alive') + response.setHeader('X-Accel-Buffering', 'no') + + const { outputStream } = await executeStream({ + provider, + modelId, + params, + inputFormat, + outputFormat, + middlewares, + plugins + }) + + // Get formatter for the output format + const formatter = StreamAdapterFactory.getFormatter(outputFormat) + + // Stream events to response + const reader = outputStream.getReader() + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + response.write(formatter.formatEvent(value)) + } + } finally { + reader.releaseLock() + } + + // Send done marker and end response + response.write(formatter.formatDone()) + response.end() + + logger.info('Proxy stream completed', { providerId: provider.id, modelId }) + onComplete?.() + } catch (error) { + logger.error('Error in proxy stream', error as Error, { providerId: provider.id, modelId }) + onError?.(error) + throw error + } +} + +/** + * Generate a non-streaming message response + * + * Uses simulateStreamingMiddleware to reuse the same streaming logic. 
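+ * The simulated stream is drained only to populate adapter state; the returned value comes from adapter.buildNonStreamingResponse().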
+ * + * @example + * ```typescript + * const message = await generateMessage({ + * provider, + * modelId: 'claude-3-opus', + * params: messageCreateParams, + * outputFormat: 'anthropic' + * }) + * ``` + */ +export async function generateMessage(config: GenerateConfig): Promise { + const { + provider, + modelId, + params, + inputFormat = 'anthropic', + outputFormat = 'anthropic', + middlewares = [], + plugins = [] + } = config + + logger.info('Starting message generation', { + providerId: provider.id, + providerType: provider.type, + modelId, + inputFormat, + outputFormat, + middlewareCount: middlewares.length, + pluginCount: plugins.length + }) + + try { + // Add simulateStreamingMiddleware to reuse streaming logic + const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares] + + const { adapter, outputStream } = await executeStream({ + provider, + modelId, + params, + inputFormat, + outputFormat, + middlewares: allMiddlewares, + plugins + }) + + // Consume the stream to populate adapter state + const reader = outputStream.getReader() + while (true) { + const { done } = await reader.read() + if (done) break + } + reader.releaseLock() + + // Build final response from adapter + const finalResponse = adapter.buildNonStreamingResponse() + + logger.info('Message generation completed', { providerId: provider.id, modelId }) + + return finalResponse + } catch (error) { + logger.error('Error in message generation', error as Error, { providerId: provider.id, modelId }) + throw error + } +} + +export default { + streamToResponse, + generateMessage +} diff --git a/src/main/apiServer/services/__tests__/jsonSchemaToZod.test.ts b/src/main/apiServer/services/__tests__/jsonSchemaToZod.test.ts index 804db0d357..92863f4ba2 100644 --- a/src/main/apiServer/services/__tests__/jsonSchemaToZod.test.ts +++ b/src/main/apiServer/services/__tests__/jsonSchemaToZod.test.ts @@ -1,7 +1,7 @@ import { describe, expect, it } from 'vitest' import * as z from 'zod' -import { type JsonSchemaLike, jsonSchemaToZod } from '../unified-messages' +import { type JsonSchemaLike, jsonSchemaToZod } from '../../adapters/converters/json-schema-to-zod' describe('jsonSchemaToZod', () => { describe('Basic Types', () => { diff --git a/src/main/apiServer/services/__tests__/unified-messages.test.ts b/src/main/apiServer/services/__tests__/unified-messages.test.ts index f8ee1a4952..9e33fda4c6 100644 --- a/src/main/apiServer/services/__tests__/unified-messages.test.ts +++ b/src/main/apiServer/services/__tests__/unified-messages.test.ts @@ -1,10 +1,18 @@ import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages' import { describe, expect, it } from 'vitest' -import { convertAnthropicToAiMessages, convertAnthropicToolsToAiSdk } from '../unified-messages' +import { AnthropicMessageConverter } from '../../adapters/converters/AnthropicMessageConverter' -describe('unified-messages', () => { - describe('convertAnthropicToolsToAiSdk', () => { +// Create a converter instance for testing +const converter = new AnthropicMessageConverter() + +// Helper functions that wrap the converter methods +const convertAnthropicToAiMessages = (params: MessageCreateParams) => converter.toAiSdkMessages(params) +const convertAnthropicToolsToAiSdk = (tools: MessageCreateParams['tools']) => + converter.toAiSdkTools({ model: 'test', max_tokens: 100, messages: [], tools }) + +describe('AnthropicMessageConverter', () => { + describe('toAiSdkTools', () => { it('should return undefined for empty tools array', () => { const result = 
convertAnthropicToolsToAiSdk([]) expect(result).toBeUndefined() @@ -135,7 +143,7 @@ describe('unified-messages', () => { }) }) - describe('convertAnthropicToAiMessages', () => { + describe('toAiSdkMessages', () => { describe('System Messages', () => { it('should convert string system message', () => { const params: MessageCreateParams = { diff --git a/src/main/apiServer/services/chat-completion.ts b/src/main/apiServer/services/chat-completion.ts deleted file mode 100644 index a7c6160e81..0000000000 --- a/src/main/apiServer/services/chat-completion.ts +++ /dev/null @@ -1,260 +0,0 @@ -import OpenAI from '@cherrystudio/openai' -import type { ChatCompletionCreateParams, ChatCompletionCreateParamsStreaming } from '@cherrystudio/openai/resources' -import type { Provider } from '@types' - -import { loggerService } from '../../services/LoggerService' -import type { ModelValidationError } from '../utils' -import { validateModelId } from '../utils' - -const logger = loggerService.withContext('ChatCompletionService') - -export interface ValidationResult { - isValid: boolean - errors: string[] -} - -export class ChatCompletionValidationError extends Error { - constructor(public readonly errors: string[]) { - super(`Request validation failed: ${errors.join('; ')}`) - this.name = 'ChatCompletionValidationError' - } -} - -export class ChatCompletionModelError extends Error { - constructor(public readonly error: ModelValidationError) { - super(`Model validation failed: ${error.message}`) - this.name = 'ChatCompletionModelError' - } -} - -export type PrepareRequestResult = - | { status: 'validation_error'; errors: string[] } - | { status: 'model_error'; error: ModelValidationError } - | { - status: 'ok' - provider: Provider - modelId: string - client: OpenAI - providerRequest: ChatCompletionCreateParams - } - -export class ChatCompletionService { - async resolveProviderContext( - model: string - ): Promise< - { ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI } - > { - const modelValidation = await validateModelId(model) - if (!modelValidation.valid) { - return { - ok: false, - error: modelValidation.error! - } - } - - const provider = modelValidation.provider! - - if (provider.type !== 'openai') { - return { - ok: false, - error: { - type: 'unsupported_provider_type', - message: `Provider '${provider.id}' of type '${provider.type}' is not supported for OpenAI chat completions`, - code: 'unsupported_provider_type' - } - } - } - - const modelId = modelValidation.modelId! - - const client = new OpenAI({ - baseURL: provider.apiHost, - apiKey: provider.apiKey - }) - - return { - ok: true, - provider, - modelId, - client - } - } - - async prepareRequest(request: ChatCompletionCreateParams, stream: boolean): Promise { - const requestValidation = this.validateRequest(request) - if (!requestValidation.isValid) { - return { - status: 'validation_error', - errors: requestValidation.errors - } - } - - const providerContext = await this.resolveProviderContext(request.model!) - if (!providerContext.ok) { - return { - status: 'model_error', - error: providerContext.error - } - } - - const { provider, modelId, client } = providerContext - - logger.debug('Model validation successful', { - provider: provider.id, - providerType: provider.type, - modelId, - fullModelId: request.model - }) - - return { - status: 'ok', - provider, - modelId, - client, - providerRequest: stream - ? 
{ - ...request, - model: modelId, - stream: true as const - } - : { - ...request, - model: modelId, - stream: false as const - } - } - } - - validateRequest(request: ChatCompletionCreateParams): ValidationResult { - const errors: string[] = [] - - // Validate messages - if (!request.messages) { - errors.push('Messages array is required') - } else if (!Array.isArray(request.messages)) { - errors.push('Messages must be an array') - } else if (request.messages.length === 0) { - errors.push('Messages array cannot be empty') - } else { - // Validate each message - request.messages.forEach((message, index) => { - if (!message.role) { - errors.push(`Message ${index}: role is required`) - } - if (!message.content) { - errors.push(`Message ${index}: content is required`) - } - }) - } - - // Validate optional parameters - - return { - isValid: errors.length === 0, - errors - } - } - - async processCompletion(request: ChatCompletionCreateParams): Promise<{ - provider: Provider - modelId: string - response: OpenAI.Chat.Completions.ChatCompletion - }> { - try { - logger.debug('Processing chat completion request', { - model: request.model, - messageCount: request.messages.length, - stream: request.stream - }) - - const preparation = await this.prepareRequest(request, false) - if (preparation.status === 'validation_error') { - throw new ChatCompletionValidationError(preparation.errors) - } - - if (preparation.status === 'model_error') { - throw new ChatCompletionModelError(preparation.error) - } - - const { provider, modelId, client, providerRequest } = preparation - - logger.debug('Sending request to provider', { - provider: provider.id, - model: modelId, - apiHost: provider.apiHost - }) - - const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion - - logger.info('Chat completion processed', { - modelId, - provider: provider.id - }) - return { - provider, - modelId, - response - } - } catch (error: any) { - logger.error('Error processing chat completion', { - error, - model: request.model - }) - throw error - } - } - - async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{ - provider: Provider - modelId: string - stream: AsyncIterable - }> { - try { - logger.debug('Processing streaming chat completion request', { - model: request.model, - messageCount: request.messages.length - }) - - const preparation = await this.prepareRequest(request, true) - if (preparation.status === 'validation_error') { - throw new ChatCompletionValidationError(preparation.errors) - } - - if (preparation.status === 'model_error') { - throw new ChatCompletionModelError(preparation.error) - } - - const { provider, modelId, client, providerRequest } = preparation - - logger.debug('Sending streaming request to provider', { - provider: provider.id, - model: modelId, - apiHost: provider.apiHost - }) - - const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming - const stream = (await client.chat.completions.create( - streamRequest - )) as AsyncIterable - - logger.info('Streaming chat completion started', { - modelId, - provider: provider.id - }) - return { - provider, - modelId, - stream - } - } catch (error: any) { - logger.error('Error processing streaming chat completion', { - error, - model: request.model - }) - throw error - } - } -} - -// Export singleton instance -export const chatCompletionService = new ChatCompletionService() diff --git a/src/main/apiServer/services/unified-messages.ts 
b/src/main/apiServer/services/unified-messages.ts deleted file mode 100644 index c3591a46d0..0000000000 --- a/src/main/apiServer/services/unified-messages.ts +++ /dev/null @@ -1,762 +0,0 @@ -import type { AnthropicProviderOptions } from '@ai-sdk/anthropic' -import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' -import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai' -import type { JSONSchema7, LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider' -import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' -import type { - ImageBlockParam, - MessageCreateParams, - TextBlockParam, - Tool as AnthropicTool -} from '@anthropic-ai/sdk/resources/messages' -import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core' -import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' -import { loggerService } from '@logger' -import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters' -import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai' -import anthropicService from '@main/services/AnthropicService' -import copilotService from '@main/services/CopilotService' -import { reduxService } from '@main/services/ReduxService' -import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider' -import { - type AiSdkConfig, - type AiSdkConfigContext, - formatProviderApiHost, - initializeSharedProviders, - type ProviderFormatContext, - providerToAiSdkConfig as sharedProviderToAiSdkConfig, - resolveActualProvider -} from '@shared/aiCore' -import { COPILOT_DEFAULT_HEADERS } from '@shared/aiCore/constant' -import { isGemini3ModelId } from '@shared/aiCore/middlewares' -import type { MinimalProvider } from '@shared/types' -import { SystemProviderIds } from '@shared/types' -import { defaultAppHeaders } from '@shared/utils' -import { isAnthropicProvider, isGeminiProvider, isOpenAIProvider } from '@shared/utils/provider' -import type { Provider } from '@types' -import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai' -import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai' -import { net } from 'electron' -import type { Response } from 'express' -import * as z from 'zod' - -import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache' - -const logger = loggerService.withContext('UnifiedMessagesService') - -const MAGIC_STRING = 'skip_thought_signature_validator' - -function sanitizeJson(value: unknown): JSONValue { - return JSON.parse(JSON.stringify(value)) -} - -initializeSharedProviders({ - warn: (message) => logger.warn(message), - error: (message, error) => logger.error(message, error) -}) - -/** - * Configuration for unified message streaming - */ -export interface UnifiedStreamConfig { - response: Response - provider: Provider - modelId: string - params: MessageCreateParams - onError?: (error: unknown) => void - onComplete?: () => void - /** - * Optional AI SDK middlewares to apply - */ - middlewares?: LanguageModelV2Middleware[] - /** - * Optional AI Core plugins to use with the executor - */ - plugins?: AiPlugin[] -} - -/** - * Configuration for non-streaming message generation - */ -export interface GenerateUnifiedMessageConfig { - provider: Provider - modelId: string - params: MessageCreateParams - middlewares?: LanguageModelV2Middleware[] - plugins?: AiPlugin[] -} - 
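For reference, a minimal sketch of how a route could have driven these two configs (the handler shape and variable wiring are illustrative, not part of the deleted module; streamUnifiedMessages and generateUnifiedMessage are the exports defined further down in this file):

  // Hypothetical caller: provider and modelId are assumed to be resolved upstream by model validation.
  async function handleMessages(
    res: Response,
    provider: Provider,
    modelId: string,
    params: MessageCreateParams
  ): Promise<void> {
    if (params.stream) {
      // Streaming path: Anthropic-style SSE events are written straight to the Express response.
      await streamUnifiedMessages({ response: res, provider, modelId, params })
    } else {
      // Non-streaming path: the same pipeline runs behind simulateStreamingMiddleware.
      res.json(await generateUnifiedMessage({ provider, modelId, params }))
    }
  }
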
-function getMainProcessFormatContext(): ProviderFormatContext {
-  const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
-  return {
-    vertex: {
-      project: vertexSettings?.projectId || 'default-project',
-      location: vertexSettings?.location || 'us-central1'
-    }
-  }
-}
-
-function isSupportStreamOptionsProvider(provider: MinimalProvider): boolean {
-  const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const
-  return !NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id)
-}
-
-const mainProcessSdkContext: AiSdkConfigContext = {
-  isSupportStreamOptionsProvider,
-  getIncludeUsageSetting: () =>
-    reduxService.selectSync('state.settings.openAI?.streamOptions?.includeUsage'),
-  fetch: net.fetch as typeof globalThis.fetch
-}
-
-function getActualProvider(provider: Provider, modelId: string): Provider {
-  const model = provider.models?.find((m) => m.id === modelId)
-  if (!model) return provider
-  return resolveActualProvider(provider, model)
-}
-
-function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
-  const actualProvider = getActualProvider(provider, modelId)
-  const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
-  return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
-}
-
-function convertAnthropicToolResultToAiSdk(
-  content: string | Array<TextBlockParam | ImageBlockParam>
-): LanguageModelV2ToolResultOutput {
-  if (typeof content === 'string') {
-    return { type: 'text', value: content }
-  }
-  const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
-  for (const block of content) {
-    if (block.type === 'text') {
-      values.push({ type: 'text', text: block.text })
-    } else if (block.type === 'image') {
-      values.push({
-        type: 'media',
-        data: block.source.type === 'base64' ? block.source.data : block.source.url,
-        mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
-      })
-    }
-  }
-  return { type: 'content', value: values }
-}
-
-/**
- * JSON Schema type for tool input schemas
- */
-export type JsonSchemaLike = JSONSchema7
-
-/**
- * Convert JSON Schema to Zod schema
- * This avoids non-standard fields like input_examples that Anthropic doesn't support
- * TODO: Anthropic/beta support input_examples
- */
-export function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
-  const schemaType = schema.type
-  const enumValues = schema.enum
-  const description = schema.description
-
-  // Handle enum first
-  if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
-    if (enumValues.every((v) => typeof v === 'string')) {
-      const zodEnum = z.enum(enumValues as [string, ...string[]])
-      return description ? zodEnum.describe(description) : zodEnum
-    }
-    // For non-string enums, use union of literals
-    const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
-    if (literals.length === 1) {
-      return description ? literals[0].describe(description) : literals[0]
-    }
-    const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
-    return description ? zodUnion.describe(description) : zodUnion
-  }
-
-  // Handle union types (type: ["string", "null"])
-  if (Array.isArray(schemaType)) {
-    const schemas = schemaType.map((t) =>
-      jsonSchemaToZod({
-        ...schema,
-        type: t,
-        enum: undefined
-      })
-    )
-    if (schemas.length === 1) {
-      return schemas[0]
-    }
-    return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
-  }
-
-  // Handle by type
-  switch (schemaType) {
-    case 'string': {
-      let zodString = z.string()
-      if (typeof schema.minLength === 'number') zodString = zodString.min(schema.minLength)
-      if (typeof schema.maxLength === 'number') zodString = zodString.max(schema.maxLength)
-      if (typeof schema.pattern === 'string') zodString = zodString.regex(new RegExp(schema.pattern))
-      return description ? zodString.describe(description) : zodString
-    }
-
-    case 'number':
-    case 'integer': {
-      let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
-      if (typeof schema.minimum === 'number') zodNumber = zodNumber.min(schema.minimum)
-      if (typeof schema.maximum === 'number') zodNumber = zodNumber.max(schema.maximum)
-      return description ? zodNumber.describe(description) : zodNumber
-    }
-
-    case 'boolean': {
-      const zodBoolean = z.boolean()
-      return description ? zodBoolean.describe(description) : zodBoolean
-    }
-
-    case 'null':
-      return z.null()
-
-    case 'array': {
-      const items = schema.items
-      let zodArray: z.ZodArray<z.ZodTypeAny>
-      if (items && typeof items === 'object' && !Array.isArray(items)) {
-        zodArray = z.array(jsonSchemaToZod(items as JsonSchemaLike))
-      } else {
-        zodArray = z.array(z.unknown())
-      }
-      if (typeof schema.minItems === 'number') zodArray = zodArray.min(schema.minItems)
-      if (typeof schema.maxItems === 'number') zodArray = zodArray.max(schema.maxItems)
-      return description ? zodArray.describe(description) : zodArray
-    }
-
-    case 'object': {
-      const properties = schema.properties
-      const required = schema.required || []
-
-      // Always use z.object() to ensure "properties" field is present in output schema
-      // OpenAI requires explicit properties field even for empty objects
-      const shape: Record<string, z.ZodTypeAny> = {}
-      if (properties && typeof properties === 'object') {
-        for (const [key, propSchema] of Object.entries(properties)) {
-          if (typeof propSchema === 'boolean') {
-            shape[key] = propSchema ? z.unknown() : z.never()
-          } else {
-            const zodProp = jsonSchemaToZod(propSchema as JsonSchemaLike)
-            shape[key] = required.includes(key) ? zodProp : zodProp.optional()
-          }
-        }
-      }
-
-      const zodObject = z.object(shape)
-      return description ? zodObject.describe(description) : zodObject
-    }
-
-    default:
-      // Unknown type, use z.unknown()
-      return z.unknown()
-  }
-}
-
-export function convertAnthropicToolsToAiSdk(
-  tools: MessageCreateParams['tools']
-): Record<string, AiSdkTool> | undefined {
-  if (!tools || tools.length === 0) return undefined
-
-  const aiSdkTools: Record<string, AiSdkTool> = {}
-  for (const anthropicTool of tools) {
-    if (anthropicTool.type === 'bash_20250124') continue
-    const toolDef = anthropicTool as AnthropicTool
-    const rawSchema = toolDef.input_schema
-    // Convert Anthropic's InputSchema to JSONSchema7-compatible format
-    const schema = jsonSchemaToZod(rawSchema as JsonSchemaLike)
-
-    // Use tool() with inputSchema (AI SDK v5 API)
-    const aiTool = tool({
-      description: toolDef.description || '',
-      inputSchema: zodSchema(schema)
-    })
-
-    aiSdkTools[toolDef.name] = aiTool
-  }
-  return Object.keys(aiSdkTools).length > 0 ?
aiSdkTools : undefined -} - -export function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] { - const messages: ModelMessage[] = [] - - // System message - if (params.system) { - if (typeof params.system === 'string') { - messages.push({ role: 'system', content: params.system }) - } else if (Array.isArray(params.system)) { - const systemText = params.system - .filter((block) => block.type === 'text') - .map((block) => block.text) - .join('\n') - if (systemText) { - messages.push({ role: 'system', content: systemText }) - } - } - } - - const toolCallIdToName = new Map() - for (const msg of params.messages) { - if (Array.isArray(msg.content)) { - for (const block of msg.content) { - if (block.type === 'tool_use') { - toolCallIdToName.set(block.id, block.name) - } - } - } - } - - // User/assistant messages - for (const msg of params.messages) { - if (typeof msg.content === 'string') { - messages.push({ - role: msg.role === 'user' ? 'user' : 'assistant', - content: msg.content - }) - } else if (Array.isArray(msg.content)) { - const textParts: TextPart[] = [] - const imageParts: ImagePart[] = [] - const reasoningParts: ReasoningPart[] = [] - const toolCallParts: ToolCallPart[] = [] - const toolResultParts: ToolResultPart[] = [] - - for (const block of msg.content) { - if (block.type === 'text') { - textParts.push({ type: 'text', text: block.text }) - } else if (block.type === 'thinking') { - reasoningParts.push({ type: 'reasoning', text: block.thinking }) - } else if (block.type === 'redacted_thinking') { - reasoningParts.push({ type: 'reasoning', text: block.data }) - } else if (block.type === 'image') { - const source = block.source - if (source.type === 'base64') { - imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` }) - } else if (source.type === 'url') { - imageParts.push({ type: 'image', image: source.url }) - } - } else if (block.type === 'tool_use') { - const options: ProviderOptions = {} - logger.debug('Processing tool call block', { block, msgRole: msg.role, model: params.model }) - if (isGemini3ModelId(params.model)) { - if (googleReasoningCache.get(`google-${block.name}`)) { - options.google = { - thoughtSignature: MAGIC_STRING - } - } - } - if (openRouterReasoningCache.get(`openrouter-${block.id}`)) { - options.openrouter = { - reasoning_details: - (sanitizeJson(openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || [] - } - } - toolCallParts.push({ - type: 'tool-call', - toolName: block.name, - toolCallId: block.id, - input: block.input, - providerOptions: options - }) - } else if (block.type === 'tool_result') { - // Look up toolName from the pre-built map (covers cross-message references) - const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown' - toolResultParts.push({ - type: 'tool-result', - toolCallId: block.tool_use_id, - toolName, - output: block.content ? 
convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
-          })
-        }
-      }
-
-      if (toolResultParts.length > 0) {
-        messages.push({ role: 'tool', content: [...toolResultParts] })
-      }
-
-      if (msg.role === 'user') {
-        const userContent = [...textParts, ...imageParts]
-        if (userContent.length > 0) {
-          messages.push({ role: 'user', content: userContent })
-        }
-      } else {
-        const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
-        if (assistantContent.length > 0) {
-          let providerOptions: ProviderOptions | undefined = undefined
-          if (openRouterReasoningCache.get('openrouter')) {
-            providerOptions = {
-              openrouter: {
-                reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
-              }
-            }
-          } else if (isGemini3ModelId(params.model)) {
-            providerOptions = {
-              google: {
-                thoughtSignature: MAGIC_STRING
-              }
-            }
-          }
-          messages.push({ role: 'assistant', content: assistantContent, providerOptions })
-        }
-      }
-    }
-  }
-
-  return messages
-}
-
-interface ExecuteStreamConfig {
-  provider: Provider
-  modelId: string
-  params: MessageCreateParams
-  middlewares?: LanguageModelV2Middleware[]
-  plugins?: AiPlugin[]
-  onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
-}
-
-/**
- * Create AI SDK provider instance from config
- * Similar to renderer's createAiSdkProvider
- */
-async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
-  let providerId = config.providerId
-
-  // Handle special provider modes (same as renderer)
-  if (providerId === 'openai' && config.options?.mode === 'chat') {
-    providerId = 'openai-chat'
-  } else if (providerId === 'azure' && config.options?.mode === 'responses') {
-    providerId = 'azure-responses'
-  } else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
-    providerId = 'cherryin-chat'
-  }
-
-  const provider = await createProviderCore(providerId, config.options)
-
-  return provider
-}
-
-/**
- * Prepare special provider configuration for providers that need dynamic tokens
- * Similar to renderer's prepareSpecialProviderConfig
- */
-async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
-  switch (provider.id) {
-    case 'copilot': {
-      const storedHeaders =
-        ((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
-      const headers: Record<string, string> = {
-        ...COPILOT_DEFAULT_HEADERS,
-        ...storedHeaders
-      }
-
-      try {
-        const { token } = await copilotService.getToken(null as any, headers)
-        config.options.apiKey = token
-        const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
-        config.options.headers = {
-          ...headers,
-          ...existingHeaders
-        }
-      } catch (error) {
-        logger.error('Failed to get Copilot token', error as Error)
-        throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
-      }
-      break
-    }
-    case 'anthropic': {
-      if (provider.authType === 'oauth') {
-        try {
-          const oauthToken = await anthropicService.getValidAccessToken()
-          if (!oauthToken) {
-            throw new Error('Anthropic OAuth token not available. Please re-authorize.')
-          }
-          config.options = {
-            ...config.options,
-            headers: {
-              ...(config.options.headers ? config.options.headers : {}),
-              'Content-Type': 'application/json',
-              'anthropic-version': '2023-06-01',
-              'anthropic-beta': 'oauth-2025-04-20',
-              Authorization: `Bearer ${oauthToken}`
-            },
-            baseURL: 'https://api.anthropic.com/v1',
-            apiKey: ''
-          }
-        } catch (error) {
-          logger.error('Failed to get Anthropic OAuth token', error as Error)
-          throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
-        }
-      }
-      break
-    }
-    case 'cherryai': {
-      // Create a signed fetch wrapper for cherryai
-      const baseFetch = net.fetch as typeof globalThis.fetch
-      config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
-        if (!options?.body) {
-          return baseFetch(url, options)
-        }
-        const signature = cherryaiGenerateSignature({
-          method: 'POST',
-          path: '/chat/completions',
-          query: '',
-          body: JSON.parse(options.body as string)
-        })
-        return baseFetch(url, {
-          ...options,
-          headers: {
-            ...(options.headers as Record<string, string>),
-            ...signature
-          }
-        })
-      }
-      break
-    }
-  }
-  return config
-}
-
-function mapAnthropicThinkToAISdkProviderOptions(
-  provider: Provider,
-  config: MessageCreateParams['thinking']
-): ProviderOptions | undefined {
-  if (!config) return undefined
-  if (isAnthropicProvider(provider)) {
-    return {
-      anthropic: {
-        ...mapToAnthropicProviderOptions(config)
-      }
-    }
-  }
-  if (isGeminiProvider(provider)) {
-    return {
-      google: {
-        ...mapToGeminiProviderOptions(config)
-      }
-    }
-  }
-  if (isOpenAIProvider(provider)) {
-    return {
-      openai: {
-        ...mapToOpenAIProviderOptions(config)
-      }
-    }
-  }
-  if (provider.id === SystemProviderIds.openrouter) {
-    return {
-      openrouter: {
-        ...mapToOpenRouterProviderOptions(config)
-      }
-    }
-  }
-  return undefined
-}
-
-function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
-  return {
-    thinking: {
-      type: config.type,
-      budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
-    }
-  }
-}
-
-function mapToGeminiProviderOptions(
-  config: NonNullable<MessageCreateParams['thinking']>
-): GoogleGenerativeAIProviderOptions {
-  return {
-    thinkingConfig: {
-      thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
-      includeThoughts: config.type === 'enabled'
-    }
-  }
-}
-
-function mapToOpenAIProviderOptions(
-  config: NonNullable<MessageCreateParams['thinking']>
-): OpenAIResponsesProviderOptions {
-  return {
-    reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
-  }
-}
-
-function mapToOpenRouterProviderOptions(
-  config: NonNullable<MessageCreateParams['thinking']>
-): OpenRouterProviderOptions {
-  return {
-    reasoning: {
-      enabled: config.type === 'enabled',
-      effort: 'high'
-    }
-  }
-}
-
-/**
- * Core stream execution function - single source of truth for AI SDK calls
- */
-async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
-  const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
-
-  // Convert provider config to AI SDK config
-  let sdkConfig = providerToAiSdkConfig(provider, modelId)
-
-  // Prepare special provider config (Copilot, Anthropic OAuth, etc.)
-  sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
-
-  // Create provider instance and get language model
-  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
-  const baseModel = aiSdkProvider.languageModel(modelId)
-
-  // Apply middlewares if present
-  const model =
-    middlewares.length > 0 && typeof baseModel === 'object'
-      ? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
-      : baseModel
-
-  // Create executor with plugins
-  const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
-
-  // Convert messages and tools
-  const coreMessages = convertAnthropicToAiMessages(params)
-  const tools = convertAnthropicToolsToAiSdk(params.tools)
-
-  // Create the adapter
-  const adapter = new AiSdkToAnthropicSSE({
-    model: `${provider.id}:${modelId}`,
-    onEvent: onEvent || (() => {})
-  })
-
-  const result = await executor.streamText({
-    model,
-    messages: coreMessages,
-    // FIXME: The max_tokens value passed in by Claude Code can exceed some models' limits and needs special handling;
-    // this may be easier to fix in v2, since the maintenance cost is high right now. Known affected model: Doubao.
-    maxOutputTokens: params.max_tokens,
-    temperature: params.temperature,
-    topP: params.top_p,
-    topK: params.top_k,
-    stopSequences: params.stop_sequences,
-    stopWhen: stepCountIs(100),
-    headers: defaultAppHeaders(),
-    tools,
-    providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking)
-  })
-
-  // Process the stream through the adapter
-  await adapter.processStream(result.fullStream)
-
-  return adapter
-}
-
-/**
- * Stream a message request using AI SDK executor and convert to Anthropic SSE format
- */
-export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
-  const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
-
-  logger.info('Starting unified message stream', {
-    providerId: provider.id,
-    providerType: provider.type,
-    modelId,
-    stream: params.stream,
-    middlewareCount: middlewares.length,
-    pluginCount: plugins.length
-  })
-
-  try {
-    response.setHeader('Content-Type', 'text/event-stream')
-    response.setHeader('Cache-Control', 'no-cache')
-    response.setHeader('Connection', 'keep-alive')
-    response.setHeader('X-Accel-Buffering', 'no')
-
-    await executeStream({
-      provider,
-      modelId,
-      params,
-      middlewares,
-      plugins,
-      onEvent: (event) => {
-        logger.silly('Streaming event', { eventType: event.type })
-        const sseData = formatSSEEvent(event)
-        response.write(sseData)
-      }
-    })
-
-    // Send done marker
-    response.write(formatSSEDone())
-    response.end()
-
-    logger.info('Unified message stream completed', { providerId: provider.id, modelId })
-    onComplete?.()
-  } catch (error) {
-    logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
-    onError?.(error)
-    throw error
-  }
-}
-
-/**
- * Generate a non-streaming message response
- *
- * Uses simulateStreamingMiddleware to reuse the same streaming logic,
- * similar to renderer's ModernAiProvider pattern.
- */
-export async function generateUnifiedMessage(
-  providerOrConfig: Provider | GenerateUnifiedMessageConfig,
-  modelId?: string,
-  params?: MessageCreateParams
-): Promise<ReturnType<AiSdkToAnthropicSSE['buildNonStreamingResponse']>> {
-  // Support both old signature and new config-based signature
-  let config: GenerateUnifiedMessageConfig
-  if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
-    config = providerOrConfig
-  } else {
-    config = {
-      provider: providerOrConfig as Provider,
-      modelId: modelId!,
-      params: params!
-    }
-  }
-
-  const { provider, middlewares = [], plugins = [] } = config
-
-  logger.info('Starting unified message generation', {
-    providerId: provider.id,
-    providerType: provider.type,
-    modelId: config.modelId,
-    middlewareCount: middlewares.length,
-    pluginCount: plugins.length
-  })
-
-  try {
-    // Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
-    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
-
-    const adapter = await executeStream({
-      provider,
-      modelId: config.modelId,
-      params: config.params,
-      middlewares: allMiddlewares,
-      plugins
-    })
-
-    const finalResponse = adapter.buildNonStreamingResponse()
-
-    logger.info('Unified message generation completed', {
-      providerId: provider.id,
-      modelId: config.modelId
-    })
-
-    return finalResponse
-  } catch (error) {
-    logger.error('Error in unified message generation', error as Error, {
-      providerId: provider.id,
-      modelId: config.modelId
-    })
-    throw error
-  }
-}
-
-export default {
-  streamUnifiedMessages,
-  generateUnifiedMessage
-}
diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts
index 17d3f9f088..20672d2b27 100644
--- a/src/main/apiServer/utils/index.ts
+++ b/src/main/apiServer/utils/index.ts
@@ -29,7 +29,7 @@ export async function getAvailableProviders(): Promise<Provider[]> {
   }
 
   // Support all provider types that AI SDK can handle
-  // The unified-messages service uses AI SDK which supports many providers
+  // The ProxyStreamService uses AI SDK which supports many providers
   const supportedProviders = providers.filter((p: Provider) => p.enabled)
 
   // Cache the filtered results
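
A worked example of the jsonSchemaToZod conversion that the deleted unified-messages module applied to Anthropic tool schemas (the weather-style schema below is illustrative, not taken from the codebase):

  // Typical Anthropic tool input_schema
  const inputSchema: JsonSchemaLike = {
    type: 'object',
    properties: {
      city: { type: 'string', description: 'City name' },
      days: { type: 'integer', minimum: 1, maximum: 14 },
      units: { type: 'string', enum: ['metric', 'imperial'] }
    },
    required: ['city']
  }

  const schema = jsonSchemaToZod(inputSchema)
  schema.parse({ city: 'Berlin', days: 3, units: 'metric' }) // passes
  schema.parse({ days: 3 }) // throws: 'city' is required; non-required keys become .optional()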