mirror of https://github.com/CherryHQ/cherry-studio.git

commit 1755fd9bcb (parent 56dfd1de1e)

refactor: streaming adapter
@@ -1,8 +1,9 @@
import type { RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages'
import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'
import { describe, expect, it, vi } from 'vitest'
import { describe, expect, it } from 'vitest'

import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '../AiSdkToAnthropicSSE'
import { AnthropicSSEFormatter } from '../formatters/AnthropicSSEFormatter'
import { AiSdkToAnthropicSSE } from '../stream/AiSdkToAnthropicSSE'

const createTextDelta = (text: string, id = 'text_0'): TextStreamPart<ToolSet> => ({
  type: 'text-delta',
@@ -24,17 +25,17 @@ const createFinish = (
  finishReason: FinishReason | undefined = 'stop',
  totalUsage?: Partial<LanguageModelUsage>
): TextStreamPart<ToolSet> => {
  const defaultUsage: LanguageModelUsage = {
  const defaultUsage = {
    inputTokens: 0,
    outputTokens: 0,
    totalTokens: 0
  }
  const event: TextStreamPart<ToolSet> = {
  // Cast to TextStreamPart to avoid strict type checking on optional fields
  return {
    type: 'finish',
    finishReason: finishReason || 'stop',
    totalUsage: { ...defaultUsage, ...totalUsage }
  }
  return event
  } as TextStreamPart<ToolSet>
}

// Helper to create stream
@@ -49,19 +50,32 @@ function createMockStream(events: readonly TextStreamPart<ToolSet>[]) {
  })
}

// Helper to collect all events from output stream
async function collectEvents(stream: ReadableStream<RawMessageStreamEvent>): Promise<RawMessageStreamEvent[]> {
  const events: RawMessageStreamEvent[] = []
  const reader = stream.getReader()
  try {
    while (true) {
      const { done, value } = await reader.read()
      if (done) break
      events.push(value)
    }
  } finally {
    reader.releaseLock()
  }
  return events
}

describe('AiSdkToAnthropicSSE', () => {
  describe('Text Processing', () => {
    it('should emit message_start and process text-delta events', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      // Create a mock stream with text events
      const stream = createMockStream([createTextDelta('Hello'), createTextDelta(' world'), createFinish('stop')])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Verify message_start
      expect(events[0]).toMatchObject({
@@ -106,11 +120,7 @@ describe('AiSdkToAnthropicSSE', () => {
    })

    it('should handle text-start and text-end events', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = createMockStream([
        createTextStart(),
@@ -119,7 +129,8 @@ describe('AiSdkToAnthropicSSE', () => {
        createFinish('stop')
      ])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Should have content_block_start, delta, and content_block_stop
      const blockEvents = events.filter((e) => e.type.startsWith('content_block'))
@@ -127,15 +138,12 @@ describe('AiSdkToAnthropicSSE', () => {
    })

    it('should auto-start text block if not explicitly started', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = createMockStream([createTextDelta('Auto-started'), createFinish('stop')])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Should automatically emit content_block_start
      expect(events.some((e) => e.type === 'content_block_start')).toBe(true)
@@ -144,11 +152,7 @@ describe('AiSdkToAnthropicSSE', () => {

  describe('Tool Call Processing', () => {
    it('should emit tool_use block for tool-call events', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = createMockStream([
        {
@@ -160,7 +164,8 @@ describe('AiSdkToAnthropicSSE', () => {
        createFinish('tool-calls')
      ])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Find tool_use block events
      const blockStart = events.find((e) => {
@@ -195,11 +200,7 @@ describe('AiSdkToAnthropicSSE', () => {
    })

    it('should not create duplicate tool blocks', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const toolCallEvent: TextStreamPart<ToolSet> = {
        type: 'tool-call',
@@ -209,7 +210,8 @@ describe('AiSdkToAnthropicSSE', () => {
      }
      const stream = createMockStream([toolCallEvent, toolCallEvent, createFinish()])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Should only have one tool_use block
      const toolBlocks = events.filter((e) => {
@@ -224,11 +226,7 @@ describe('AiSdkToAnthropicSSE', () => {

  describe('Reasoning/Thinking Processing', () => {
    it('should emit thinking block for reasoning events', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = createMockStream([
        { type: 'reasoning-start', id: 'reason_1' },
@@ -237,7 +235,8 @@ describe('AiSdkToAnthropicSSE', () => {
        createFinish()
      ])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Find thinking block events
      const blockStart = events.find((e) => {
@@ -262,11 +261,7 @@ describe('AiSdkToAnthropicSSE', () => {
    })

    it('should handle multiple thinking blocks', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = createMockStream([
        { type: 'reasoning-start', id: 'reason_1' },
@@ -278,7 +273,8 @@ describe('AiSdkToAnthropicSSE', () => {
        createFinish()
      ])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Should have two thinking blocks
      const thinkingBlocks = events.filter((e) => {
@@ -304,15 +300,12 @@ describe('AiSdkToAnthropicSSE', () => {
      ]

      for (const { aiSdkReason, expectedReason } of testCases) {
        const events: RawMessageStreamEvent[] = []
        const adapter = new AiSdkToAnthropicSSE({
          model: 'test:model',
          onEvent: (event) => events.push(event)
        })
        const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

        const stream = createMockStream([createFinish(aiSdkReason)])

        await adapter.processStream(stream)
        const outputStream = adapter.transform(stream)
        const events = await collectEvents(outputStream)

        const messageDelta = events.find((e) => e.type === 'message_delta')
        if (messageDelta && messageDelta.type === 'message_delta') {
@@ -324,11 +317,9 @@ describe('AiSdkToAnthropicSSE', () => {

  describe('Usage Tracking', () => {
    it('should track token usage', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        inputTokens: 100,
        onEvent: (event) => events.push(event)
        inputTokens: 100
      })

      const stream = createMockStream([
@@ -340,7 +331,8 @@ describe('AiSdkToAnthropicSSE', () => {
        })
      ])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      const messageDelta = events.find((e) => e.type === 'message_delta')
      if (messageDelta && messageDelta.type === 'message_delta') {
@@ -355,10 +347,7 @@ describe('AiSdkToAnthropicSSE', () => {

  describe('Non-Streaming Response', () => {
    it('should build complete message for non-streaming', async () => {
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: vi.fn()
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = createMockStream([
        createTextDelta('Hello world'),
@@ -371,7 +360,14 @@ describe('AiSdkToAnthropicSSE', () => {
        createFinish('tool-calls', { inputTokens: 10, outputTokens: 20 })
      ])

      await adapter.processStream(stream)
      // Consume the stream to populate adapter state
      const outputStream = adapter.transform(stream)
      const reader = outputStream.getReader()
      while (true) {
        const { done } = await reader.read()
        if (done) break
      }
      reader.releaseLock()

      const response = adapter.buildNonStreamingResponse()

@@ -403,25 +399,20 @@ describe('AiSdkToAnthropicSSE', () => {

  describe('Error Handling', () => {
    it('should throw on error events', async () => {
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: vi.fn()
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const testError = new Error('Test error')
      const stream = createMockStream([{ type: 'error', error: testError }])

      await expect(adapter.processStream(stream)).rejects.toThrow('Test error')
      const outputStream = adapter.transform(stream)

      await expect(collectEvents(outputStream)).rejects.toThrow('Test error')
    })
  })

  describe('Edge Cases', () => {
    it('should handle empty stream', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = new ReadableStream<TextStreamPart<ToolSet>>({
        start(controller) {
@@ -429,7 +420,8 @@ describe('AiSdkToAnthropicSSE', () => {
        }
      })

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Should still emit message_start, message_delta, and message_stop
      expect(events.some((e) => e.type === 'message_start')).toBe(true)
@@ -438,15 +430,12 @@ describe('AiSdkToAnthropicSSE', () => {
    })

    it('should handle empty text deltas', async () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const stream = createMockStream([createTextDelta(''), createTextDelta(''), createFinish()])

      await adapter.processStream(stream)
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      // Should not emit deltas for empty text
      const deltas = events.filter((e) => e.type === 'content_block_delta')
@@ -454,8 +443,9 @@ describe('AiSdkToAnthropicSSE', () => {
    })
  })

  describe('Utility Functions', () => {
    describe('AnthropicSSEFormatter', () => {
      it('should format SSE events correctly', () => {
        const formatter = new AnthropicSSEFormatter()
        const event: RawMessageStreamEvent = {
          type: 'message_start',
          message: {
@@ -476,7 +466,7 @@ describe('AiSdkToAnthropicSSE', () => {
          }
        }

        const formatted = formatSSEEvent(event)
        const formatted = formatter.formatEvent(event)

        expect(formatted).toContain('event: message_start')
        expect(formatted).toContain('data: ')
@@ -485,7 +475,8 @@ describe('AiSdkToAnthropicSSE', () => {
      })

      it('should format SSE done marker correctly', () => {
        const done = formatSSEDone()
        const formatter = new AnthropicSSEFormatter()
        const done = formatter.formatDone()

        expect(done).toBe('data: [DONE]\n\n')
      })
@@ -495,18 +486,14 @@ describe('AiSdkToAnthropicSSE', () => {
    it('should use provided message ID', () => {
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        messageId: 'custom_msg_123',
        onEvent: vi.fn()
        messageId: 'custom_msg_123'
      })

      expect(adapter.getMessageId()).toBe('custom_msg_123')
    })

    it('should generate message ID if not provided', () => {
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: vi.fn()
      })
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      const messageId = adapter.getMessageId()
      expect(messageId).toMatch(/^msg_/)
@@ -514,23 +501,20 @@ describe('AiSdkToAnthropicSSE', () => {
    })

  describe('Input Tokens', () => {
    it('should allow setting input tokens', () => {
      const events: RawMessageStreamEvent[] = []
      const adapter = new AiSdkToAnthropicSSE({
        model: 'test:model',
        onEvent: (event) => events.push(event)
      })
    it('should allow setting input tokens', async () => {
      const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })

      adapter.setInputTokens(500)

      const stream = createMockStream([createFinish()])

      return adapter.processStream(stream).then(() => {
        const messageStart = events.find((e) => e.type === 'message_start')
        if (messageStart && messageStart.type === 'message_start') {
          expect(messageStart.message.usage.input_tokens).toBe(500)
        }
      })
      const outputStream = adapter.transform(stream)
      const events = await collectEvents(outputStream)

      const messageStart = events.find((e) => e.type === 'message_start')
      if (messageStart && messageStart.type === 'message_start') {
        expect(messageStart.message.usage.input_tokens).toBe(500)
      }
    })
  })
})
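Taken together, the test changes show the shape of the new adapter API: the callback-based onEvent/processStream pair is replaced by a pull-based transform() that returns a ReadableStream, and the free functions formatSSEEvent/formatSSEDone move onto AnthropicSSEFormatter. A minimal consumption sketch, based on the test helpers above (the `res` HTTP response sink is an assumption, not part of this diff):

const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })
const formatter = new AnthropicSSEFormatter()
const outputStream = adapter.transform(aiSdkStream)
const reader = outputStream.getReader()
while (true) {
  const { done, value } = await reader.read()
  if (done) break
  res.write(formatter.formatEvent(value)) // `res` is a hypothetical Express-style response
}
res.write(formatter.formatDone())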
src/main/apiServer/adapters/converters/AnthropicMessageConverter.ts (new file, 256 lines)
@@ -0,0 +1,256 @@
/**
 * Anthropic Message Converter
 *
 * Converts Anthropic Messages API format to AI SDK format.
 * Handles messages, tools, and special content types (images, thinking, tool results).
 */

import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
  ImageBlockParam,
  MessageCreateParams,
  TextBlockParam,
  Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { isGemini3ModelId } from '@shared/aiCore/middlewares'
import type { Provider } from '@types'
import type { ImagePart, JSONValue, ModelMessage, TextPart, Tool as AiSdkTool } from 'ai'
import { tool, zodSchema } from 'ai'

import type { IMessageConverter, StreamTextOptions } from '../interfaces'
import { type JsonSchemaLike, jsonSchemaToZod } from './json-schema-to-zod'
import { mapAnthropicThinkingToProviderOptions } from './provider-options-mapper'

const MAGIC_STRING = 'skip_thought_signature_validator'

/**
 * Sanitize value for JSON serialization
 */
function sanitizeJson(value: unknown): JSONValue {
  return JSON.parse(JSON.stringify(value))
}

/**
 * Convert Anthropic tool result content to AI SDK format
 */
function convertToolResultToAiSdk(
  content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
  if (typeof content === 'string') {
    return { type: 'text', value: content }
  }
  const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
  for (const block of content) {
    if (block.type === 'text') {
      values.push({ type: 'text', text: block.text })
    } else if (block.type === 'image') {
      values.push({
        type: 'media',
        data: block.source.type === 'base64' ? block.source.data : block.source.url,
        mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
      })
    }
  }
  return { type: 'content', value: values }
}

/**
 * Reasoning cache interface for storing provider-specific reasoning state
 */
export interface ReasoningCache {
  get(key: string): unknown
  set(key: string, value: unknown): void
}

/**
 * Anthropic Message Converter
 *
 * Converts Anthropic MessageCreateParams to AI SDK format for unified processing.
 */
export class AnthropicMessageConverter implements IMessageConverter<MessageCreateParams> {
  private googleReasoningCache?: ReasoningCache
  private openRouterReasoningCache?: ReasoningCache

  constructor(options?: { googleReasoningCache?: ReasoningCache; openRouterReasoningCache?: ReasoningCache }) {
    this.googleReasoningCache = options?.googleReasoningCache
    this.openRouterReasoningCache = options?.openRouterReasoningCache
  }

  /**
   * Convert Anthropic MessageCreateParams to AI SDK ModelMessage[]
   */
  toAiSdkMessages(params: MessageCreateParams): ModelMessage[] {
    const messages: ModelMessage[] = []

    // System message
    if (params.system) {
      if (typeof params.system === 'string') {
        messages.push({ role: 'system', content: params.system })
      } else if (Array.isArray(params.system)) {
        const systemText = params.system
          .filter((block) => block.type === 'text')
          .map((block) => block.text)
          .join('\n')
        if (systemText) {
          messages.push({ role: 'system', content: systemText })
        }
      }
    }

    // Build tool call ID to name mapping for tool results
    const toolCallIdToName = new Map<string, string>()
    for (const msg of params.messages) {
      if (Array.isArray(msg.content)) {
        for (const block of msg.content) {
          if (block.type === 'tool_use') {
            toolCallIdToName.set(block.id, block.name)
          }
        }
      }
    }

    // User/assistant messages
    for (const msg of params.messages) {
      if (typeof msg.content === 'string') {
        messages.push({
          role: msg.role === 'user' ? 'user' : 'assistant',
          content: msg.content
        })
      } else if (Array.isArray(msg.content)) {
        const textParts: TextPart[] = []
        const imageParts: ImagePart[] = []
        const reasoningParts: ReasoningPart[] = []
        const toolCallParts: ToolCallPart[] = []
        const toolResultParts: ToolResultPart[] = []

        for (const block of msg.content) {
          if (block.type === 'text') {
            textParts.push({ type: 'text', text: block.text })
          } else if (block.type === 'thinking') {
            reasoningParts.push({ type: 'reasoning', text: block.thinking })
          } else if (block.type === 'redacted_thinking') {
            reasoningParts.push({ type: 'reasoning', text: block.data })
          } else if (block.type === 'image') {
            const source = block.source
            if (source.type === 'base64') {
              imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
            } else if (source.type === 'url') {
              imageParts.push({ type: 'image', image: source.url })
            }
          } else if (block.type === 'tool_use') {
            const options: ProviderOptions = {}
            if (isGemini3ModelId(params.model)) {
              if (this.googleReasoningCache?.get(`google-${block.name}`)) {
                options.google = {
                  thoughtSignature: MAGIC_STRING
                }
              }
            }
            if (this.openRouterReasoningCache?.get(`openrouter-${block.id}`)) {
              options.openrouter = {
                reasoning_details:
                  (sanitizeJson(this.openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || []
              }
            }
            toolCallParts.push({
              type: 'tool-call',
              toolName: block.name,
              toolCallId: block.id,
              input: block.input,
              providerOptions: options
            })
          } else if (block.type === 'tool_result') {
            const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
            toolResultParts.push({
              type: 'tool-result',
              toolCallId: block.tool_use_id,
              toolName,
              output: block.content ? convertToolResultToAiSdk(block.content) : { type: 'text', value: '' }
            })
          }
        }

        if (toolResultParts.length > 0) {
          messages.push({ role: 'tool', content: [...toolResultParts] })
        }

        if (msg.role === 'user') {
          const userContent = [...textParts, ...imageParts]
          if (userContent.length > 0) {
            messages.push({ role: 'user', content: userContent })
          }
        } else {
          const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
          if (assistantContent.length > 0) {
            let providerOptions: ProviderOptions | undefined = undefined
            if (this.openRouterReasoningCache?.get('openrouter')) {
              providerOptions = {
                openrouter: {
                  reasoning_details:
                    (sanitizeJson(this.openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
                }
              }
            } else if (isGemini3ModelId(params.model)) {
              providerOptions = {
                google: {
                  thoughtSignature: MAGIC_STRING
                }
              }
            }
            messages.push({ role: 'assistant', content: assistantContent, providerOptions })
          }
        }
      }
    }

    return messages
  }

  /**
   * Convert Anthropic tools to AI SDK tools
   */
  toAiSdkTools(params: MessageCreateParams): Record<string, AiSdkTool> | undefined {
    const tools = params.tools
    if (!tools || tools.length === 0) return undefined

    const aiSdkTools: Record<string, AiSdkTool> = {}
    for (const anthropicTool of tools) {
      if (anthropicTool.type === 'bash_20250124') continue
      const toolDef = anthropicTool as AnthropicTool
      const rawSchema = toolDef.input_schema
      const schema = jsonSchemaToZod(rawSchema as JsonSchemaLike)

      const aiTool = tool({
        description: toolDef.description || '',
        inputSchema: zodSchema(schema)
      })

      aiSdkTools[toolDef.name] = aiTool
    }
    return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
  }

  /**
   * Extract stream/generation options from Anthropic params
   */
  extractStreamOptions(params: MessageCreateParams): StreamTextOptions {
    return {
      maxOutputTokens: params.max_tokens,
      temperature: params.temperature,
      topP: params.top_p,
      topK: params.top_k,
      stopSequences: params.stop_sequences
    }
  }

  /**
   * Extract provider-specific options from Anthropic params
   * Maps thinking configuration to provider-specific parameters
   */
  extractProviderOptions(provider: Provider, params: MessageCreateParams): ProviderOptions | undefined {
    return mapAnthropicThinkingToProviderOptions(provider, params.thinking)
  }
}

export default AnthropicMessageConverter
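A minimal usage sketch of the converter above, using only fields defined on MessageCreateParams (the model id is an arbitrary example value):

const converter = new AnthropicMessageConverter()
const params: MessageCreateParams = {
  model: 'claude-3-5-sonnet-latest', // example value
  max_tokens: 1024,
  system: 'Be concise.',
  messages: [{ role: 'user', content: 'Hello' }]
}
converter.toAiSdkMessages(params)
// → [{ role: 'system', content: 'Be concise.' }, { role: 'user', content: 'Hello' }]
converter.extractStreamOptions(params)
// → { maxOutputTokens: 1024, temperature: undefined, topP: undefined, topK: undefined, stopSequences: undefined }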
src/main/apiServer/adapters/converters/OpenAIMessageConverter.ts (new file, 281 lines)
@@ -0,0 +1,281 @@
/**
 * OpenAI Message Converter
 *
 * Converts OpenAI Chat Completions API format to AI SDK format.
 * Handles messages, tools, and extended features like reasoning_content.
 */

import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
  ChatCompletionAssistantMessageParam,
  ChatCompletionContentPart,
  ChatCompletionMessageParam,
  ChatCompletionToolMessageParam,
  ChatCompletionUserMessageParam
} from '@cherrystudio/openai/resources'
import type { ChatCompletionCreateParamsBase } from '@cherrystudio/openai/resources/chat/completions'
import type { Provider } from '@types'
import type { ImagePart, ModelMessage, TextPart, Tool as AiSdkTool } from 'ai'
import { tool, zodSchema } from 'ai'

import type { IMessageConverter, StreamTextOptions } from '../interfaces'
import { type JsonSchemaLike, jsonSchemaToZod } from './json-schema-to-zod'
import { mapReasoningEffortToProviderOptions } from './provider-options-mapper'

/**
 * Extended ChatCompletionCreateParams with reasoning_effort support
 * Extends the base OpenAI params to inherit all standard parameters
 */
export interface ExtendedChatCompletionCreateParams extends ChatCompletionCreateParamsBase {
  /**
   * Allow additional provider-specific parameters
   */
  [key: string]: unknown
}

/**
 * Extended assistant message with reasoning_content support (DeepSeek-style)
 */
interface ExtendedAssistantMessage extends ChatCompletionAssistantMessageParam {
  reasoning_content?: string | null
}

/**
 * OpenAI Message Converter
 *
 * Converts OpenAI Chat Completions API format to AI SDK format.
 * Supports standard OpenAI messages plus extended features:
 * - reasoning_content (DeepSeek-style thinking)
 * - reasoning_effort parameter
 */
export class OpenAIMessageConverter implements IMessageConverter<ExtendedChatCompletionCreateParams> {
  /**
   * Convert OpenAI ChatCompletionCreateParams to AI SDK ModelMessage[]
   */
  toAiSdkMessages(params: ExtendedChatCompletionCreateParams): ModelMessage[] {
    const messages: ModelMessage[] = []

    // Build tool call ID to name mapping for tool results
    const toolCallIdToName = new Map<string, string>()
    for (const msg of params.messages) {
      if (msg.role === 'assistant') {
        const assistantMsg = msg as ChatCompletionAssistantMessageParam
        if (assistantMsg.tool_calls) {
          for (const toolCall of assistantMsg.tool_calls) {
            // Only handle function tool calls
            if (toolCall.type === 'function') {
              toolCallIdToName.set(toolCall.id, toolCall.function.name)
            }
          }
        }
      }
    }

    for (const msg of params.messages) {
      const converted = this.convertMessage(msg, toolCallIdToName)
      if (converted) {
        messages.push(...converted)
      }
    }

    return messages
  }

  /**
   * Convert a single OpenAI message to AI SDK message(s)
   */
  private convertMessage(
    msg: ChatCompletionMessageParam,
    toolCallIdToName: Map<string, string>
  ): ModelMessage[] | null {
    switch (msg.role) {
      case 'system':
        return this.convertSystemMessage(msg)
      case 'user':
        return this.convertUserMessage(msg as ChatCompletionUserMessageParam)
      case 'assistant':
        return this.convertAssistantMessage(msg as ExtendedAssistantMessage)
      case 'tool':
        return this.convertToolMessage(msg as ChatCompletionToolMessageParam, toolCallIdToName)
      case 'function':
        // Legacy function messages - skip or handle as needed
        return null
      default:
        return null
    }
  }

  /**
   * Convert system message
   */
  private convertSystemMessage(msg: ChatCompletionMessageParam): ModelMessage[] {
    if (msg.role !== 'system') return []

    // Handle string content
    if (typeof msg.content === 'string') {
      return [{ role: 'system', content: msg.content }]
    }

    // Handle array content (system messages can have text parts)
    if (Array.isArray(msg.content)) {
      const textContent = msg.content
        .filter((part): part is { type: 'text'; text: string } => part.type === 'text')
        .map((part) => part.text)
        .join('\n')
      if (textContent) {
        return [{ role: 'system', content: textContent }]
      }
    }

    return []
  }

  /**
   * Convert user message
   */
  private convertUserMessage(msg: ChatCompletionUserMessageParam): ModelMessage[] {
    // Handle string content
    if (typeof msg.content === 'string') {
      return [{ role: 'user', content: msg.content }]
    }

    // Handle array content (text + images)
    if (Array.isArray(msg.content)) {
      const parts: (TextPart | ImagePart)[] = []

      for (const part of msg.content as ChatCompletionContentPart[]) {
        if (part.type === 'text') {
          parts.push({ type: 'text', text: part.text })
        } else if (part.type === 'image_url') {
          parts.push({ type: 'image', image: part.image_url.url })
        }
      }

      if (parts.length > 0) {
        return [{ role: 'user', content: parts }]
      }
    }

    return []
  }

  /**
   * Convert assistant message
   */
  private convertAssistantMessage(msg: ExtendedAssistantMessage): ModelMessage[] {
    const parts: (TextPart | ReasoningPart | ToolCallPart)[] = []

    // Handle reasoning_content (DeepSeek-style thinking)
    if (msg.reasoning_content) {
      parts.push({ type: 'reasoning', text: msg.reasoning_content })
    }

    // Handle text content
    if (msg.content) {
      if (typeof msg.content === 'string') {
        parts.push({ type: 'text', text: msg.content })
      } else if (Array.isArray(msg.content)) {
        for (const part of msg.content) {
          if (part.type === 'text') {
            parts.push({ type: 'text', text: part.text })
          }
        }
      }
    }

    // Handle tool calls
    if (msg.tool_calls && msg.tool_calls.length > 0) {
      for (const toolCall of msg.tool_calls) {
        // Only handle function tool calls
        if (toolCall.type !== 'function') continue

        let input: unknown
        try {
          input = JSON.parse(toolCall.function.arguments)
        } catch {
          input = { raw: toolCall.function.arguments }
        }

        parts.push({
          type: 'tool-call',
          toolCallId: toolCall.id,
          toolName: toolCall.function.name,
          input
        })
      }
    }

    if (parts.length > 0) {
      return [{ role: 'assistant', content: parts }]
    }

    return []
  }

  /**
   * Convert tool result message
   */
  private convertToolMessage(
    msg: ChatCompletionToolMessageParam,
    toolCallIdToName: Map<string, string>
  ): ModelMessage[] {
    const toolName = toolCallIdToName.get(msg.tool_call_id) || 'unknown'

    const toolResultPart: ToolResultPart = {
      type: 'tool-result',
      toolCallId: msg.tool_call_id,
      toolName,
      output: { type: 'text', value: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) }
    }

    return [{ role: 'tool', content: [toolResultPart] }]
  }

  /**
   * Convert OpenAI tools to AI SDK tools
   */
  toAiSdkTools(params: ExtendedChatCompletionCreateParams): Record<string, AiSdkTool> | undefined {
    const tools = params.tools
    if (!tools || tools.length === 0) return undefined

    const aiSdkTools: Record<string, AiSdkTool> = {}

    for (const toolDef of tools) {
      if (toolDef.type !== 'function') continue

      const rawSchema = toolDef.function.parameters
      const schema = rawSchema ? jsonSchemaToZod(rawSchema as JsonSchemaLike) : jsonSchemaToZod({ type: 'object' })

      const aiTool = tool({
        description: toolDef.function.description || '',
        inputSchema: zodSchema(schema)
      })

      aiSdkTools[toolDef.function.name] = aiTool
    }

    return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
  }

  /**
   * Extract stream/generation options from OpenAI params
   */
  extractStreamOptions(params: ExtendedChatCompletionCreateParams): StreamTextOptions {
    return {
      maxOutputTokens: params.max_tokens as number | undefined,
      temperature: params.temperature as number | undefined,
      topP: params.top_p as number | undefined,
      stopSequences: params.stop as string[] | undefined
    }
  }

  /**
   * Extract provider-specific options from OpenAI params
   * Maps reasoning_effort to provider-specific thinking/reasoning parameters
   */
  extractProviderOptions(provider: Provider, params: ExtendedChatCompletionCreateParams): ProviderOptions | undefined {
    return mapReasoningEffortToProviderOptions(provider, params.reasoning_effort)
  }
}

export default OpenAIMessageConverter
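A sketch of the tool-call round trip this converter handles; the tool name, arguments, and ids are illustrative values, not from this diff:

const converter = new OpenAIMessageConverter()
converter.toAiSdkMessages({
  model: 'gpt-4o', // example value
  messages: [
    { role: 'user', content: 'Weather in Paris?' },
    {
      role: 'assistant',
      content: null,
      tool_calls: [{ id: 'call_1', type: 'function', function: { name: 'get_weather', arguments: '{"city":"Paris"}' } }]
    },
    { role: 'tool', tool_call_id: 'call_1', content: '{"temp_c":18}' }
  ]
})
// The assistant turn becomes a tool-call part with parsed input { city: 'Paris' };
// the tool turn becomes a tool-result part whose toolName ('get_weather') is
// resolved through the tool-call-id map built in toAiSdkMessages.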
src/main/apiServer/adapters/converters/index.ts (new file, 2 lines)
@@ -0,0 +1,2 @@
export { AnthropicMessageConverter } from './AnthropicMessageConverter'
export { type JsonSchemaLike, jsonSchemaToZod } from './json-schema-to-zod'
src/main/apiServer/adapters/converters/json-schema-to-zod.ts (new file, 141 lines)
@@ -0,0 +1,141 @@
/**
 * JSON Schema to Zod Converter
 *
 * Converts JSON Schema definitions to Zod schemas for runtime validation.
 * This is used to convert tool input schemas from Anthropic format to AI SDK format.
 */

import type { JSONSchema7 } from '@ai-sdk/provider'
import * as z from 'zod'

/**
 * JSON Schema type alias
 */
export type JsonSchemaLike = JSONSchema7

/**
 * Convert JSON Schema to Zod schema
 *
 * Handles:
 * - Primitive types (string, number, integer, boolean, null)
 * - Complex types (object, array)
 * - Enums
 * - Union types (type: ["string", "null"])
 * - Required/optional fields
 * - Validation constraints (min/max, pattern, etc.)
 *
 * @example
 * ```typescript
 * const zodSchema = jsonSchemaToZod({
 *   type: 'object',
 *   properties: {
 *     name: { type: 'string' },
 *     age: { type: 'integer', minimum: 0 }
 *   },
 *   required: ['name']
 * })
 * ```
 */
export function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
  const schemaType = schema.type
  const enumValues = schema.enum
  const description = schema.description

  // Handle enum first
  if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
    if (enumValues.every((v) => typeof v === 'string')) {
      const zodEnum = z.enum(enumValues as [string, ...string[]])
      return description ? zodEnum.describe(description) : zodEnum
    }
    // For non-string enums, use union of literals
    const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
    if (literals.length === 1) {
      return description ? literals[0].describe(description) : literals[0]
    }
    const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
    return description ? zodUnion.describe(description) : zodUnion
  }

  // Handle union types (type: ["string", "null"])
  if (Array.isArray(schemaType)) {
    const schemas = schemaType.map((t) =>
      jsonSchemaToZod({
        ...schema,
        type: t,
        enum: undefined
      })
    )
    if (schemas.length === 1) {
      return schemas[0]
    }
    return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
  }

  // Handle by type
  switch (schemaType) {
    case 'string': {
      let zodString = z.string()
      if (typeof schema.minLength === 'number') zodString = zodString.min(schema.minLength)
      if (typeof schema.maxLength === 'number') zodString = zodString.max(schema.maxLength)
      if (typeof schema.pattern === 'string') zodString = zodString.regex(new RegExp(schema.pattern))
      return description ? zodString.describe(description) : zodString
    }

    case 'number':
    case 'integer': {
      let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
      if (typeof schema.minimum === 'number') zodNumber = zodNumber.min(schema.minimum)
      if (typeof schema.maximum === 'number') zodNumber = zodNumber.max(schema.maximum)
      return description ? zodNumber.describe(description) : zodNumber
    }

    case 'boolean': {
      const zodBoolean = z.boolean()
      return description ? zodBoolean.describe(description) : zodBoolean
    }

    case 'null':
      return z.null()

    case 'array': {
      const items = schema.items
      let zodArray: z.ZodArray<z.ZodTypeAny>
      if (items && typeof items === 'object' && !Array.isArray(items)) {
        zodArray = z.array(jsonSchemaToZod(items as JsonSchemaLike))
      } else {
        zodArray = z.array(z.unknown())
      }
      if (typeof schema.minItems === 'number') zodArray = zodArray.min(schema.minItems)
      if (typeof schema.maxItems === 'number') zodArray = zodArray.max(schema.maxItems)
      return description ? zodArray.describe(description) : zodArray
    }

    case 'object': {
      const properties = schema.properties
      const required = schema.required || []

      // Always use z.object() to ensure "properties" field is present in output schema
      // OpenAI requires explicit properties field even for empty objects
      const shape: Record<string, z.ZodTypeAny> = {}
      if (properties && typeof properties === 'object') {
        for (const [key, propSchema] of Object.entries(properties)) {
          if (typeof propSchema === 'boolean') {
            shape[key] = propSchema ? z.unknown() : z.never()
          } else {
            const zodProp = jsonSchemaToZod(propSchema as JsonSchemaLike)
            shape[key] = required.includes(key) ? zodProp : zodProp.optional()
          }
        }
      }

      const zodObject = z.object(shape)
      return description ? zodObject.describe(description) : zodObject
    }

    default:
      // Unknown type, use z.unknown()
      return z.unknown()
  }
}

export default jsonSchemaToZod
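Beyond the construction shown in the docstring's @example, the resulting schema validates at runtime; a short sketch of the behavior the converter above produces:

const schema = jsonSchemaToZod({
  type: 'object',
  properties: {
    city: { type: 'string' },
    days: { type: 'integer', minimum: 1 }
  },
  required: ['city']
})
schema.parse({ city: 'Paris' })           // ok: 'days' is optional
schema.parse({ city: 'Paris', days: 0 })  // throws: below minimum 1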
src/main/apiServer/adapters/converters/provider-options-mapper.ts (new file, 194 lines)
@@ -0,0 +1,194 @@
/**
 * Provider Options Mapper
 *
 * Maps input format-specific thinking/reasoning configuration to
 * AI SDK provider-specific options.
 *
 * TODO: Refactor this module:
 * 1. Move shared reasoning config from src/renderer/src/config/models/reasoning.ts to @shared
 * 2. Reuse MODEL_SUPPORTED_REASONING_EFFORT for budgetMap instead of hardcoding
 * 3. For unsupported providers, pass through reasoning params in OpenAI-compatible format
 *    instead of returning undefined (all requests should transparently forward reasoning config)
 * 4. Both Anthropic and OpenAI converters should handle OpenAI-compatible mapping
 */

import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { ProviderOptions } from '@ai-sdk/provider-utils'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages'
import { ReasoningEffort } from '@cherrystudio/openai/resources'
import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
import { SystemProviderIds } from '@shared/types'
import { isAnthropicProvider, isAwsBedrockProvider, isGeminiProvider, isOpenAIProvider } from '@shared/utils/provider'
import type { Provider } from '@types'

/**
 * Map Anthropic thinking configuration to AI SDK provider options
 *
 * Converts Anthropic's thinking.type and budget_tokens to provider-specific
 * parameters for various AI providers.
 */
export function mapAnthropicThinkingToProviderOptions(
  provider: Provider,
  config: MessageCreateParams['thinking']
): ProviderOptions | undefined {
  if (!config) return undefined

  // Anthropic provider
  if (isAnthropicProvider(provider)) {
    return {
      anthropic: {
        thinking: {
          type: config.type,
          budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
        }
      } as AnthropicProviderOptions
    }
  }

  // Google/Gemini provider
  if (isGeminiProvider(provider)) {
    return {
      google: {
        thinkingConfig: {
          thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
          includeThoughts: config.type === 'enabled'
        }
      } as GoogleGenerativeAIProviderOptions
    }
  }

  // OpenAI provider (Responses API)
  if (isOpenAIProvider(provider)) {
    return {
      openai: {
        reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
      } as OpenAIResponsesProviderOptions
    }
  }

  // OpenRouter provider
  if (provider.id === SystemProviderIds.openrouter) {
    return {
      openrouter: {
        reasoning: {
          enabled: config.type === 'enabled',
          effort: 'high'
        }
      } as OpenRouterProviderOptions
    }
  }

  // XAI/Grok provider
  if (provider.id === SystemProviderIds.grok) {
    return {
      xai: {
        reasoningEffort: config.type === 'enabled' ? 'high' : undefined
      } as XaiProviderOptions
    }
  }

  // AWS Bedrock provider
  if (isAwsBedrockProvider(provider)) {
    return {
      bedrock: {
        reasoningConfig: {
          type: config.type,
          budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
        }
      } as BedrockProviderOptions
    }
  }

  // For other providers, thinking options are not automatically mapped
  return undefined
}

/**
 * Map OpenAI-style reasoning_effort to AI SDK provider options
 *
 * Converts reasoning_effort (low/medium/high) to provider-specific
 * thinking/reasoning parameters.
 */
export function mapReasoningEffortToProviderOptions(
  provider: Provider,
  reasoningEffort?: ReasoningEffort
): ProviderOptions | undefined {
  if (!reasoningEffort) return undefined

  // TODO: Import from @shared/config/reasoning instead of hardcoding
  // Should reuse MODEL_SUPPORTED_REASONING_EFFORT from reasoning.ts
  const budgetMap = { low: 5000, medium: 10000, high: 20000 }

  // Anthropic: Map to thinking.budgetTokens
  if (isAnthropicProvider(provider)) {
    return {
      anthropic: {
        thinking: {
          type: 'enabled',
          budgetTokens: budgetMap[reasoningEffort]
        }
      } as AnthropicProviderOptions
    }
  }

  // Google/Gemini: Map to thinkingConfig.thinkingBudget
  if (isGeminiProvider(provider)) {
    return {
      google: {
        thinkingConfig: {
          thinkingBudget: budgetMap[reasoningEffort],
          includeThoughts: true
        }
      } as GoogleGenerativeAIProviderOptions
    }
  }

  // OpenAI: Use reasoningEffort directly
  if (isOpenAIProvider(provider)) {
    return {
      openai: {
        reasoningEffort: reasoningEffort === 'low' ? 'none' : reasoningEffort
      } as OpenAIResponsesProviderOptions
    }
  }

  // OpenRouter: Map to reasoning.effort
  if (provider.id === SystemProviderIds.openrouter) {
    return {
      openrouter: {
        reasoning: {
          enabled: true,
          effort: reasoningEffort
        }
      } as OpenRouterProviderOptions
    }
  }

  // XAI/Grok: Map to reasoningEffort
  if (provider.id === SystemProviderIds.grok) {
    return {
      xai: {
        reasoningEffort: reasoningEffort === 'low' ? undefined : reasoningEffort
      } as XaiProviderOptions
    }
  }

  // AWS Bedrock: Map to reasoningConfig
  if (isAwsBedrockProvider(provider)) {
    return {
      bedrock: {
        reasoningConfig: {
          type: 'enabled',
          budgetTokens: budgetMap[reasoningEffort]
        }
      } as BedrockProviderOptions
    }
  }

  // For other providers, reasoning effort is not automatically mapped
  return undefined
}
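Two worked mappings from the functions above; `gemini` and `anthropic` are assumed Provider values recognized by isGeminiProvider and isAnthropicProvider respectively:

// Assuming isGeminiProvider(gemini) is true:
mapAnthropicThinkingToProviderOptions(gemini, { type: 'enabled', budget_tokens: 8192 })
// → { google: { thinkingConfig: { thinkingBudget: 8192, includeThoughts: true } } }

// Assuming isAnthropicProvider(anthropic) is true; 'medium' → budgetMap.medium = 10000:
mapReasoningEffortToProviderOptions(anthropic, 'medium')
// → { anthropic: { thinking: { type: 'enabled', budgetTokens: 10000 } } }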
src/main/apiServer/adapters/factory/MessageConverterFactory.ts (new file, 82 lines)
@@ -0,0 +1,82 @@
/**
 * Message Converter Factory
 *
 * Factory for creating message converters based on input format.
 * Uses generics for type-safe converter creation.
 */

import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages'

import { AnthropicMessageConverter, type ReasoningCache } from '../converters/AnthropicMessageConverter'
import { type ExtendedChatCompletionCreateParams, OpenAIMessageConverter } from '../converters/OpenAIMessageConverter'
import type { IMessageConverter, InputFormat } from '../interfaces'

/**
 * Type mapping from input format to parameter type
 */
export type InputParamsMap = {
  openai: ExtendedChatCompletionCreateParams
  anthropic: MessageCreateParams
}

/**
 * Options for creating converters
 */
export interface ConverterOptions {
  googleReasoningCache?: ReasoningCache
  openRouterReasoningCache?: ReasoningCache
}

/**
 * Message Converter Factory
 *
 * Creates message converters for different input formats with type safety.
 *
 * @example
 * ```typescript
 * const converter = MessageConverterFactory.create('anthropic', {
 *   googleReasoningCache,
 *   openRouterReasoningCache
 * })
 * // converter is typed as IMessageConverter<MessageCreateParams>
 * const messages = converter.toAiSdkMessages(params)
 * const options = converter.extractStreamOptions(params)
 * ```
 */
export class MessageConverterFactory {
  /**
   * Create a message converter for the specified input format
   *
   * @param format - The input format ('openai' | 'anthropic')
   * @param options - Optional converter options
   * @returns A typed message converter instance
   */
  static create<T extends InputFormat>(
    format: T,
    options: ConverterOptions = {}
  ): IMessageConverter<InputParamsMap[T]> {
    if (format === 'openai') {
      return new OpenAIMessageConverter() as IMessageConverter<InputParamsMap[T]>
    }
    return new AnthropicMessageConverter({
      googleReasoningCache: options.googleReasoningCache,
      openRouterReasoningCache: options.openRouterReasoningCache
    }) as IMessageConverter<InputParamsMap[T]>
  }

  /**
   * Check if a format is supported
   */
  static supportsFormat(format: string): format is InputFormat {
    return format === 'openai' || format === 'anthropic'
  }

  /**
   * Get list of all supported formats
   */
  static getSupportedFormats(): InputFormat[] {
    return ['openai', 'anthropic']
  }
}

export default MessageConverterFactory
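The supportsFormat type guard lets route handlers narrow an untrusted string before calling create(); a hypothetical routing sketch (the `req` object and path check are assumptions, not part of this diff):

const format = req.path.startsWith('/v1/messages') ? 'anthropic' : 'openai' // hypothetical route check
if (!MessageConverterFactory.supportsFormat(format)) {
  throw new Error(`Unsupported input format: ${format}`)
}
const converter = MessageConverterFactory.create(format) // narrowed to InputFormat here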
src/main/apiServer/adapters/factory/StreamAdapterFactory.ts (new file, 127 lines)
@@ -0,0 +1,127 @@
/**
 * Stream Adapter Factory
 *
 * Factory for creating stream adapters based on output format.
 * Uses a registry pattern for extensibility.
 */

import { AnthropicSSEFormatter } from '../formatters/AnthropicSSEFormatter'
import { OpenAISSEFormatter } from '../formatters/OpenAISSEFormatter'
import type { ISSEFormatter, IStreamAdapter, OutputFormat, StreamAdapterOptions } from '../interfaces'
import { AiSdkToAnthropicSSE } from '../stream/AiSdkToAnthropicSSE'
import { AiSdkToOpenAISSE } from '../stream/AiSdkToOpenAISSE'

/**
 * Registry entry for adapter and formatter classes
 */
interface RegistryEntry {
  adapterClass: new (options: StreamAdapterOptions) => IStreamAdapter
  formatterClass: new () => ISSEFormatter
}

/**
 * Stream Adapter Factory
 *
 * Creates stream adapters and formatters for different output formats.
 *
 * @example
 * ```typescript
 * const adapter = StreamAdapterFactory.createAdapter('anthropic', { model: 'claude-3' })
 * const outputStream = adapter.transform(aiSdkStream)
 *
 * const formatter = StreamAdapterFactory.getFormatter('anthropic')
 * for await (const event of outputStream) {
 *   response.write(formatter.formatEvent(event))
 * }
 * response.write(formatter.formatDone())
 * ```
 */
export class StreamAdapterFactory {
  private static registry = new Map<OutputFormat, RegistryEntry>([
    [
      'anthropic',
      {
        adapterClass: AiSdkToAnthropicSSE,
        formatterClass: AnthropicSSEFormatter
      }
    ],
    [
      'openai',
      {
        adapterClass: AiSdkToOpenAISSE,
        formatterClass: OpenAISSEFormatter
      }
    ]
  ])

  /**
   * Create a stream adapter for the specified output format
   *
   * @param format - The target output format
   * @param options - Adapter options (model, messageId, etc.)
   * @returns A stream adapter instance
   * @throws Error if format is not supported
   */
  static createAdapter(format: OutputFormat, options: StreamAdapterOptions): IStreamAdapter {
    const entry = this.registry.get(format)
    if (!entry) {
      throw new Error(
        `Unsupported output format: ${format}. Supported formats: ${this.getSupportedFormats().join(', ')}`
      )
    }
    return new entry.adapterClass(options)
  }

  /**
   * Get an SSE formatter for the specified output format
   *
   * @param format - The target output format
   * @returns An SSE formatter instance
   * @throws Error if format is not supported
   */
  static getFormatter(format: OutputFormat): ISSEFormatter {
    const entry = this.registry.get(format)
    if (!entry) {
      throw new Error(
        `Unsupported output format: ${format}. Supported formats: ${this.getSupportedFormats().join(', ')}`
      )
    }
    return new entry.formatterClass()
  }

  /**
   * Check if a format is supported
   *
   * @param format - The format to check
   * @returns true if the format is supported
   */
  static supportsFormat(format: OutputFormat): boolean {
    return this.registry.has(format)
  }

  /**
   * Get list of all supported formats
   *
   * @returns Array of supported format names
   */
  static getSupportedFormats(): OutputFormat[] {
    return Array.from(this.registry.keys())
  }

  /**
   * Register a new adapter and formatter for a format
   *
   * @param format - The format name
   * @param adapterClass - The adapter class constructor
   * @param formatterClass - The formatter class constructor
   */
  static registerAdapter(
    format: OutputFormat,
    adapterClass: new (options: StreamAdapterOptions) => IStreamAdapter,
    formatterClass: new () => ISSEFormatter
  ): void {
    this.registry.set(format, { adapterClass, formatterClass })
  }
}

export default StreamAdapterFactory

src/main/apiServer/adapters/factory/index.ts (new file, 1 line)
@@ -0,0 +1 @@
export { StreamAdapterFactory } from './StreamAdapterFactory'
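The registry pattern means a new output format only needs one registerAdapter call rather than edits to the factory; a hypothetical extension sketch (GeminiSSEAdapter and GeminiSSEFormatter do not exist in this commit and are invented names for illustration):

// Hypothetical classes implementing IStreamAdapter and ISSEFormatter:
StreamAdapterFactory.registerAdapter('gemini' as OutputFormat, GeminiSSEAdapter, GeminiSSEFormatter)
StreamAdapterFactory.supportsFormat('gemini' as OutputFormat) // → true after registration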
src/main/apiServer/adapters/formatters/AnthropicSSEFormatter.ts (new file, 36 lines)
@@ -0,0 +1,36 @@
/**
 * Anthropic SSE Formatter
 *
 * Formats Anthropic message stream events for Server-Sent Events.
 */

import type { RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages'

import type { ISSEFormatter } from '../interfaces'

/**
 * Anthropic SSE Formatter
 *
 * Formats events according to Anthropic's streaming API specification:
 * - event: {type}\n
 * - data: {json}\n\n
 *
 * @see https://docs.anthropic.com/en/api/messages-streaming
 */
export class AnthropicSSEFormatter implements ISSEFormatter<RawMessageStreamEvent> {
  /**
   * Format an Anthropic event for SSE streaming
   */
  formatEvent(event: RawMessageStreamEvent): string {
    return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
  }

  /**
   * Format the stream termination marker
   */
  formatDone(): string {
    return 'data: [DONE]\n\n'
  }
}

export default AnthropicSSEFormatter
42
src/main/apiServer/adapters/formatters/OpenAISSEFormatter.ts
Normal file
42
src/main/apiServer/adapters/formatters/OpenAISSEFormatter.ts
Normal file
@ -0,0 +1,42 @@
/**
 * OpenAI Compatible SSE Formatter
 *
 * Formats OpenAI-compatible chat completion stream events for Server-Sent Events.
 * Supports extended features like reasoning_content used by DeepSeek and other providers.
 */

import type { ISSEFormatter } from '../interfaces'
import type { OpenAICompatibleChunk } from '../stream/AiSdkToOpenAISSE'

/**
 * Re-export the OpenAI-compatible chunk type for convenience
 */
export type { OpenAICompatibleChunk as ChatCompletionChunk } from '../stream/AiSdkToOpenAISSE'

/**
 * OpenAI Compatible SSE Formatter
 *
 * Formats events according to OpenAI's streaming API specification:
 * - data: {json}\n\n
 *
 * Supports extended fields like reasoning_content for OpenAI-compatible providers.
 *
 * @see https://platform.openai.com/docs/api-reference/chat/streaming
 */
export class OpenAISSEFormatter implements ISSEFormatter<OpenAICompatibleChunk> {
  /**
   * Format an OpenAI-compatible event for SSE streaming
   */
  formatEvent(event: OpenAICompatibleChunk): string {
    return `data: ${JSON.stringify(event)}\n\n`
  }

  /**
   * Format the stream termination marker
   */
  formatDone(): string {
    return 'data: [DONE]\n\n'
  }
}

export default OpenAISSEFormatter
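Unlike the Anthropic formatter there is no event: line; a DeepSeek-style reasoning delta, for example, serializes as a single data: line (id, created, and model values here are illustrative, and the payload is abbreviated):

data: {"id":"chatcmpl-...","object":"chat.completion.chunk","created":1700000000,"model":"deepseek-reasoner","choices":[{"index":0,"delta":{"reasoning_content":"Let me think..."},"finish_reason":null}]}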
src/main/apiServer/adapters/formatters/index.ts (new file, 2 lines)
@ -0,0 +1,2 @@
export { AnthropicSSEFormatter } from './AnthropicSSEFormatter'
export { type ChatCompletionChunk, OpenAISSEFormatter } from './OpenAISSEFormatter'
@ -1,13 +1,48 @@
/**
 * Shared Adapters
 * API Server Adapters
 *
 * This module exports adapters for converting between different AI API formats.
 * This module provides adapters for converting between different AI API formats.
 *
 * Architecture:
 * - Stream adapters: Convert AI SDK streams to various output formats (Anthropic, OpenAI)
 * - Message converters: Convert input message formats to AI SDK format
 * - SSE formatters: Format events for Server-Sent Events streaming
 * - Factory: Creates adapters and formatters based on output format
 */

// Stream Adapters
export { AiSdkToAnthropicSSE } from './stream/AiSdkToAnthropicSSE'
export { AiSdkToOpenAISSE } from './stream/AiSdkToOpenAISSE'
export { BaseStreamAdapter } from './stream/BaseStreamAdapter'

// Message Converters
export { AnthropicMessageConverter, type ReasoningCache } from './converters/AnthropicMessageConverter'
export { type JsonSchemaLike, jsonSchemaToZod } from './converters/json-schema-to-zod'
export { type ExtendedChatCompletionCreateParams, OpenAIMessageConverter } from './converters/OpenAIMessageConverter'

// SSE Formatters
export { AnthropicSSEFormatter } from './formatters/AnthropicSSEFormatter'
export { type ChatCompletionChunk, OpenAISSEFormatter } from './formatters/OpenAISSEFormatter'

// Factory
export {
  AiSdkToAnthropicSSE,
  type AiSdkToAnthropicSSEOptions,
  formatSSEDone,
  formatSSEEvent,
  type SSEEventCallback
} from './AiSdkToAnthropicSSE'
export {
  type ConverterOptions,
  type InputParamsMap,
  MessageConverterFactory
} from './factory/MessageConverterFactory'
export { StreamAdapterFactory } from './factory/StreamAdapterFactory'

// Interfaces
export type {
  AdapterRegistryEntry,
  AdapterState,
  ContentBlockState,
  IMessageConverter,
  InputFormat,
  ISSEFormatter,
  IStreamAdapter,
  OutputFormat,
  StreamAdapterConstructor,
  StreamAdapterOptions,
  StreamTextOptions
} from './interfaces'
src/main/apiServer/adapters/interfaces.ts (new file, 182 lines)
@ -0,0 +1,182 @@
/**
 * Core interfaces for the API Server adapter system
 *
 * This module defines the contracts for:
 * - Stream adapters: Transform AI SDK streams to various output formats
 * - Message converters: Convert between API message formats
 * - SSE formatters: Format events for Server-Sent Events
 */

import type { ProviderOptions } from '@ai-sdk/provider-utils'
import type { Provider } from '@types'
import type { ModelMessage, TextStreamPart, ToolSet } from 'ai'

/**
 * Supported output formats for stream adapters
 */
export type OutputFormat = 'anthropic' | 'openai' | 'gemini' | 'openai-responses'

/**
 * Supported input formats for message converters
 */
export type InputFormat = 'anthropic' | 'openai'

/**
 * Stream text options extracted from input params
 * These are the common parameters used by AI SDK's streamText/generateText
 */
export interface StreamTextOptions {
  maxOutputTokens?: number
  temperature?: number
  topP?: number
  topK?: number
  stopSequences?: string[]
}

/**
 * Stream Adapter Interface
 *
 * Uses TransformStream pattern for composability:
 * ```
 * input.pipeThrough(adapter1.getTransformStream()).pipeThrough(adapter2.getTransformStream())
 * ```
 */
export interface IStreamAdapter<TOutputEvent = unknown> {
  /**
   * Transform AI SDK stream to target format stream
   * @param input - ReadableStream from AI SDK's fullStream
   * @returns ReadableStream of formatted output events
   */
  transform(input: ReadableStream<TextStreamPart<ToolSet>>): ReadableStream<TOutputEvent>

  /**
   * Get the internal TransformStream for advanced use cases
   */
  getTransformStream(): TransformStream<TextStreamPart<ToolSet>, TOutputEvent>

  /**
   * Build a non-streaming response from accumulated state
   * Call after stream is fully consumed
   */
  buildNonStreamingResponse(): unknown

  /**
   * Get the message ID for this adapter instance
   */
  getMessageId(): string

  /**
   * Set input token count (for usage tracking)
   */
  setInputTokens(count: number): void
}

/**
 * Options for creating stream adapters
 */
export interface StreamAdapterOptions {
  /** Model identifier (e.g., "anthropic:claude-3-opus") */
  model: string
  /** Optional message ID, auto-generated if not provided */
  messageId?: string
  /** Initial input token count */
  inputTokens?: number
}

/**
 * Message Converter Interface
 *
 * Converts between different API message formats and AI SDK format.
 * Each converter handles a specific input format (OpenAI, Anthropic, etc.)
 */
export interface IMessageConverter<TInputParams = unknown> {
  /**
   * Convert input params to AI SDK ModelMessage[]
   */
  toAiSdkMessages(params: TInputParams): ModelMessage[]

  /**
   * Convert input tools to AI SDK tools format
   */
  toAiSdkTools?(params: TInputParams): ToolSet | undefined

  /**
   * Extract stream/generation options from input params
   * Maps format-specific parameters to AI SDK common options
   */
  extractStreamOptions(params: TInputParams): StreamTextOptions

  /**
   * Extract provider-specific options from input params
   * Handles thinking/reasoning configuration based on provider type
   */
  extractProviderOptions(provider: Provider, params: TInputParams): ProviderOptions | undefined
}

/**
 * SSE Formatter Interface
 *
 * Formats events for Server-Sent Events streaming
 */
export interface ISSEFormatter<TEvent = unknown> {
  /**
   * Format an event for SSE streaming
   * @returns Formatted string like "event: type\ndata: {...}\n\n"
   */
  formatEvent(event: TEvent): string

  /**
   * Format the stream termination marker
   * @returns Done marker like "data: [DONE]\n\n"
   */
  formatDone(): string
}

/**
 * Content block state for tracking streaming content
 */
export interface ContentBlockState {
  type: 'text' | 'tool_use' | 'thinking'
  index: number
  started: boolean
  content: string
  // For tool_use blocks
  toolId?: string
  toolName?: string
  toolInput?: string
}

/**
 * Adapter state for tracking stream processing
 */
export interface AdapterState {
  messageId: string
  model: string
  inputTokens: number
  outputTokens: number
  cacheInputTokens: number
  currentBlockIndex: number
  blocks: Map<number, ContentBlockState>
  textBlockIndex: number | null
  thinkingBlocks: Map<string, number>
  currentThinkingId: string | null
  toolBlocks: Map<string, number>
  stopReason: string | null
  hasEmittedMessageStart: boolean
}

/**
 * Constructor type for stream adapters
 */
export type StreamAdapterConstructor<TOutputEvent = unknown> = new (
  options: StreamAdapterOptions
) => IStreamAdapter<TOutputEvent>

/**
 * Registry entry for adapter factory
 */
export interface AdapterRegistryEntry<TOutputEvent = unknown> {
  format: OutputFormat
  adapterClass: StreamAdapterConstructor<TOutputEvent>
  formatterClass: new () => ISSEFormatter<TOutputEvent>
}
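These interfaces are exactly what the factory registry stores, so plugging in a new format is one registration call. A sketch, assuming registration at startup; the GeminiSSEFormatter is hypothetical, and reusing AiSdkToOpenAISSE as the adapter class is purely for illustration:

import { StreamAdapterFactory } from './factory/StreamAdapterFactory'
import type { ISSEFormatter } from './interfaces'
import { AiSdkToOpenAISSE } from './stream/AiSdkToOpenAISSE'

// Hypothetical formatter - not part of this change.
class GeminiSSEFormatter implements ISSEFormatter<unknown> {
  formatEvent(event: unknown): string {
    return `data: ${JSON.stringify(event)}\n\n`
  }
  formatDone(): string {
    return 'data: [DONE]\n\n'
  }
}

// One registry entry pairs an adapter class with its formatter under a format key.
StreamAdapterFactory.registerAdapter('gemini', AiSdkToOpenAISSE, GeminiSSEFormatter)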
@ -36,109 +36,67 @@ import type {
  Usage
} from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger'
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'

import { googleReasoningCache, openRouterReasoningCache } from '../services/reasoning-cache'
import { googleReasoningCache, openRouterReasoningCache } from '../../services/reasoning-cache'
import type { StreamAdapterOptions } from '../interfaces'
import { BaseStreamAdapter } from './BaseStreamAdapter'

const logger = loggerService.withContext('AiSdkToAnthropicSSE')

interface ContentBlockState {
  type: 'text' | 'tool_use' | 'thinking'
  index: number
  started: boolean
  content: string
  // For tool_use blocks
  toolId?: string
  toolName?: string
  toolInput?: string
}

interface AdapterState {
  messageId: string
  model: string
  inputTokens: number
  outputTokens: number
  cacheInputTokens: number
  currentBlockIndex: number
  blocks: Map<number, ContentBlockState>
  textBlockIndex: number | null
  // Track multiple thinking blocks by their reasoning ID
  thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
  currentThinkingId: string | null // Currently active thinking block ID
  toolBlocks: Map<string, number> // toolCallId -> blockIndex
  stopReason: StopReason | null
  hasEmittedMessageStart: boolean
}

export type SSEEventCallback = (event: RawMessageStreamEvent) => void

export interface AiSdkToAnthropicSSEOptions {
  model: string
  messageId?: string
  inputTokens?: number
  onEvent: SSEEventCallback
}

/**
 * Adapter that converts AI SDK fullStream events to Anthropic SSE events
 *
 * Uses TransformStream for composable stream processing:
 * ```
 * const adapter = new AiSdkToAnthropicSSE({ model: 'claude-3' })
 * const outputStream = adapter.transform(aiSdkStream)
 * ```
 */
export class AiSdkToAnthropicSSE {
  private state: AdapterState
  private onEvent: SSEEventCallback

  constructor(options: AiSdkToAnthropicSSEOptions) {
    this.onEvent = options.onEvent
    this.state = {
      messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
      model: options.model,
      inputTokens: options.inputTokens || 0,
      outputTokens: 0,
      cacheInputTokens: 0,
      currentBlockIndex: 0,
      blocks: new Map(),
      textBlockIndex: null,
      thinkingBlocks: new Map(),
      currentThinkingId: null,
      toolBlocks: new Map(),
      stopReason: null,
      hasEmittedMessageStart: false
    }
export class AiSdkToAnthropicSSE extends BaseStreamAdapter<RawMessageStreamEvent> {
  constructor(options: StreamAdapterOptions) {
    super(options)
  }

  /**
   * Process the AI SDK stream and emit Anthropic SSE events
   * Emit the initial message_start event
   */
  async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
    const reader = fullStream.getReader()
  protected emitMessageStart(): void {
    if (this.state.hasEmittedMessageStart) return

    try {
      // Emit message_start at the beginning
      this.emitMessageStart()
    this.state.hasEmittedMessageStart = true

      while (true) {
        const { done, value } = await reader.read()

        if (done) {
          break
        }

        this.processChunk(value)
      }

      // Ensure all blocks are closed and emit final events
      this.finalize()
    } catch (error) {
      await reader.cancel()
      throw error
    } finally {
      reader.releaseLock()
    const usage: Usage = {
      input_tokens: this.state.inputTokens,
      output_tokens: 0,
      cache_creation_input_tokens: 0,
      cache_read_input_tokens: 0,
      server_tool_use: null
    }

    const message: Message = {
      id: this.state.messageId,
      type: 'message',
      role: 'assistant',
      content: [],
      model: this.state.model,
      stop_reason: null,
      stop_sequence: null,
      usage
    }

    const event: RawMessageStartEvent = {
      type: 'message_start',
      message
    }

    this.emit(event)
  }

  /**
   * Process a single AI SDK chunk and emit corresponding Anthropic events
   */
  private processChunk(chunk: TextStreamPart<ToolSet>): void {
  protected processChunk(chunk: TextStreamPart<ToolSet>): void {
    logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
    switch (chunk.type) {
      // === Text Events ===
@ -200,13 +158,7 @@ export class AiSdkToAnthropicSSE {
        break

      case 'tool-result':
        // this.handleToolResult({
        //   type: 'tool-result',
        //   toolCallId: chunk.toolCallId,
        //   toolName: chunk.toolName,
        //   args: chunk.input,
        //   result: chunk.output
        // })
        // Tool results are handled differently in Anthropic format
        break

      case 'finish-step':
@ -222,49 +174,15 @@ export class AiSdkToAnthropicSSE {
      case 'error':
        throw chunk.error

      // Ignore other event types
      default:
        break
    }
  }

  private emitMessageStart(): void {
    if (this.state.hasEmittedMessageStart) return

    this.state.hasEmittedMessageStart = true

    const usage: Usage = {
      input_tokens: this.state.inputTokens,
      output_tokens: 0,
      cache_creation_input_tokens: 0,
      cache_read_input_tokens: 0,
      server_tool_use: null
    }

    const message: Message = {
      id: this.state.messageId,
      type: 'message',
      role: 'assistant',
      content: [],
      model: this.state.model,
      stop_reason: null,
      stop_sequence: null,
      usage
    }

    const event: RawMessageStartEvent = {
      type: 'message_start',
      message
    }

    this.onEvent(event)
  }

  private startTextBlock(): void {
    // If we already have a text block, don't create another
    if (this.state.textBlockIndex !== null) return

    const index = this.state.currentBlockIndex++
    const index = this.allocateBlockIndex()
    this.state.textBlockIndex = index
    this.state.blocks.set(index, {
      type: 'text',
@ -285,13 +203,12 @@ export class AiSdkToAnthropicSSE {
      content_block: contentBlock
    }

    this.onEvent(event)
    this.emit(event)
  }

  private emitTextDelta(text: string): void {
    if (!text) return

    // Auto-start text block if not started
    if (this.state.textBlockIndex === null) {
      this.startTextBlock()
    }
@ -313,7 +230,7 @@ export class AiSdkToAnthropicSSE {
      delta
    }

    this.onEvent(event)
    this.emit(event)
  }

  private stopTextBlock(): void {
@ -326,15 +243,14 @@ export class AiSdkToAnthropicSSE {
      index
    }

    this.onEvent(event)
    this.emit(event)
    this.state.textBlockIndex = null
  }

  private startThinkingBlock(reasoningId: string): void {
    // Check if this thinking block already exists
    if (this.state.thinkingBlocks.has(reasoningId)) return

    const index = this.state.currentBlockIndex++
    const index = this.allocateBlockIndex()
    this.state.thinkingBlocks.set(reasoningId, index)
    this.state.currentThinkingId = reasoningId
    this.state.blocks.set(index, {
@ -356,16 +272,14 @@ export class AiSdkToAnthropicSSE {
      content_block: contentBlock
    }

    this.onEvent(event)
    this.emit(event)
  }

  private emitThinkingDelta(text: string, reasoningId?: string): void {
    if (!text) return

    // Determine which thinking block to use
    const targetId = reasoningId || this.state.currentThinkingId
    if (!targetId) {
      // Auto-start thinking block if not started
      const newId = `reasoning_${Date.now()}`
      this.startThinkingBlock(newId)
      return this.emitThinkingDelta(text, newId)
@ -373,7 +287,6 @@ export class AiSdkToAnthropicSSE {

    const index = this.state.thinkingBlocks.get(targetId)
    if (index === undefined) {
      // If the block doesn't exist, create it
      this.startThinkingBlock(targetId)
      return this.emitThinkingDelta(text, targetId)
    }
@ -394,7 +307,7 @@ export class AiSdkToAnthropicSSE {
      delta
    }

    this.onEvent(event)
    this.emit(event)
  }

  private stopThinkingBlock(reasoningId?: string): void {
@ -409,12 +322,10 @@ export class AiSdkToAnthropicSSE {
      index
    }

    this.onEvent(event)
    this.emit(event)
    this.state.thinkingBlocks.delete(targetId)

    // Update currentThinkingId if we just closed the current one
    if (this.state.currentThinkingId === targetId) {
      // Set to the most recent remaining thinking block, or null if none
      const remaining = Array.from(this.state.thinkingBlocks.keys())
      this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
    }
@ -423,12 +334,11 @@ export class AiSdkToAnthropicSSE {
  private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
    const { toolCallId, toolName, args } = chunk

    // Check if we already have this tool call
    if (this.state.toolBlocks.has(toolCallId)) {
      return
    }

    const index = this.state.currentBlockIndex++
    const index = this.allocateBlockIndex()
    this.state.toolBlocks.set(toolCallId, index)

    const inputJson = JSON.stringify(args)
@ -457,9 +367,9 @@ export class AiSdkToAnthropicSSE {
      content_block: contentBlock
    }

    this.onEvent(startEvent)
    this.emit(startEvent)

    // Emit the full input as a delta (Anthropic streams JSON incrementally)
    // Emit the full input as a delta
    const delta: InputJSONDelta = {
      type: 'input_json_delta',
      partial_json: inputJson
@ -471,7 +381,7 @@ export class AiSdkToAnthropicSSE {
      delta
    }

    this.onEvent(deltaEvent)
    this.emit(deltaEvent)

    // Emit content_block_stop
    const stopEvent: RawContentBlockStopEvent = {
@ -479,21 +389,18 @@ export class AiSdkToAnthropicSSE {
      index
    }

    this.onEvent(stopEvent)
    this.emit(stopEvent)

    // Mark that we have tool use
    this.state.stopReason = 'tool_use'
  }

  private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
    // Update usage
    if (chunk.totalUsage) {
      this.state.inputTokens = chunk.totalUsage.inputTokens || 0
      this.state.outputTokens = chunk.totalUsage.outputTokens || 0
      this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
    }

    // Determine finish reason
    if (!this.state.stopReason) {
      switch (chunk.finishReason) {
        case 'stop':
@ -514,7 +421,10 @@ export class AiSdkToAnthropicSSE {
      }
    }
  }

  private finalize(): void {
  /**
   * Finalize the stream and emit closing events
   */
  protected finalize(): void {
    // Close any open blocks
    if (this.state.textBlockIndex !== null) {
      this.stopTextBlock()
@ -536,34 +446,20 @@ export class AiSdkToAnthropicSSE {
    const messageDeltaEvent: RawMessageDeltaEvent = {
      type: 'message_delta',
      delta: {
        stop_reason: this.state.stopReason || 'end_turn',
        stop_reason: (this.state.stopReason as StopReason) || 'end_turn',
        stop_sequence: null
      },
      usage
    }

    this.onEvent(messageDeltaEvent)
    this.emit(messageDeltaEvent)

    // Emit message_stop
    const messageStopEvent: RawMessageStopEvent = {
      type: 'message_stop'
    }

    this.onEvent(messageStopEvent)
  }

  /**
   * Set input token count (typically from prompt)
   */
  setInputTokens(count: number): void {
    this.state.inputTokens = count
  }

  /**
   * Get the current message ID
   */
  getMessageId(): string {
    return this.state.messageId
    this.emit(messageStopEvent)
  }

  /**
@ -572,7 +468,6 @@ export class AiSdkToAnthropicSSE {
  buildNonStreamingResponse(): Message {
    const content: ContentBlock[] = []

    // Collect all content blocks in order
    const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)

    for (const block of sortedBlocks) {
@ -607,7 +502,7 @@ export class AiSdkToAnthropicSSE {
      role: 'assistant',
      content,
      model: this.state.model,
      stop_reason: this.state.stopReason || 'end_turn',
      stop_reason: (this.state.stopReason as StopReason) || 'end_turn',
      stop_sequence: null,
      usage: {
        input_tokens: this.state.inputTokens,
@ -620,18 +515,4 @@ export class AiSdkToAnthropicSSE {
    }
  }

/**
 * Format an Anthropic SSE event for HTTP streaming
 */
export function formatSSEEvent(event: RawMessageStreamEvent): string {
  return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}

/**
 * Create a done marker for SSE stream
 */
export function formatSSEDone(): string {
  return 'data: [DONE]\n\n'
}

export default AiSdkToAnthropicSSE
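With the callback gone, consuming the adapter now means reading its output stream and serializing each event. A minimal sketch, where the fullStream argument and write callback stand in for an AI SDK streamText() result and an HTTP response:

import type { TextStreamPart, ToolSet } from 'ai'

import { AnthropicSSEFormatter } from '../formatters/AnthropicSSEFormatter'
import { AiSdkToAnthropicSSE } from './AiSdkToAnthropicSSE'

// Illustrative consumer: pipe an AI SDK fullStream to an HTTP response as Anthropic SSE.
async function pipeAnthropicSSE(
  fullStream: ReadableStream<TextStreamPart<ToolSet>>,
  write: (s: string) => void
): Promise<void> {
  const adapter = new AiSdkToAnthropicSSE({ model: 'test:model' })
  const formatter = new AnthropicSSEFormatter()
  const reader = adapter.transform(fullStream).getReader()
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    write(formatter.formatEvent(value)) // event: ...\ndata: {...}\n\n
  }
  write(formatter.formatDone())
}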
src/main/apiServer/adapters/stream/AiSdkToOpenAISSE.ts (new file, 416 lines)
@ -0,0 +1,416 @@
/**
 * AI SDK to OpenAI Compatible SSE Adapter
 *
 * Converts AI SDK's fullStream (TextStreamPart) events to OpenAI-compatible Chat Completions API SSE format.
 * This enables any AI provider supported by AI SDK to be exposed via OpenAI-compatible API.
 *
 * Supports extended features used by OpenAI-compatible providers:
 * - reasoning_content: DeepSeek-style reasoning/thinking content
 * - Standard OpenAI fields: content, tool_calls, finish_reason, usage
 *
 * OpenAI SSE Event Flow:
 * 1. data: {chunk with role} - First chunk with assistant role
 * 2. data: {chunk with content/reasoning_content delta} - Incremental content updates
 * 3. data: {chunk with tool_calls} - Tool call deltas
 * 4. data: {chunk with finish_reason} - Final chunk with finish reason
 * 5. data: [DONE] - Stream complete
 *
 * @see https://platform.openai.com/docs/api-reference/chat/streaming
 */

import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'

import type { StreamAdapterOptions } from '../interfaces'
import { BaseStreamAdapter } from './BaseStreamAdapter'

const logger = loggerService.withContext('AiSdkToOpenAISSE')

/**
 * Use official OpenAI SDK types as base
 */
type ChatCompletionChunkBase = OpenAI.Chat.Completions.ChatCompletionChunk
type ChatCompletion = OpenAI.Chat.Completions.ChatCompletion

/**
 * Extended delta type with reasoning_content support (DeepSeek-style)
 */
interface OpenAICompatibleDelta {
  role?: 'assistant'
  content?: string | null
  reasoning_content?: string | null
  tool_calls?: ChatCompletionChunkBase['choices'][0]['delta']['tool_calls']
}

/**
 * Extended ChatCompletionChunk with reasoning_content support
 */
export interface OpenAICompatibleChunk extends Omit<ChatCompletionChunkBase, 'choices'> {
  choices: Array<{
    index: number
    delta: OpenAICompatibleDelta
    finish_reason: ChatCompletionChunkBase['choices'][0]['finish_reason']
    logprobs?: ChatCompletionChunkBase['choices'][0]['logprobs']
  }>
}

/**
 * Extended ChatCompletion message with reasoning_content support
 */
interface OpenAICompatibleMessage extends OpenAI.Chat.Completions.ChatCompletionMessage {
  reasoning_content?: string | null
}

/**
 * Extended ChatCompletion with reasoning_content support
 */
export interface OpenAICompatibleCompletion extends Omit<ChatCompletion, 'choices'> {
  choices: Array<{
    index: number
    message: OpenAICompatibleMessage
    finish_reason: ChatCompletion['choices'][0]['finish_reason']
    logprobs: ChatCompletion['choices'][0]['logprobs']
  }>
}

/**
 * OpenAI finish reasons
 */
type OpenAIFinishReason = 'stop' | 'length' | 'tool_calls' | 'content_filter' | null

/**
 * Tool call state for tracking incremental tool calls
 */
interface ToolCallState {
  index: number
  id: string
  name: string
  arguments: string
}

/**
 * Adapter that converts AI SDK fullStream events to OpenAI-compatible SSE events
 *
 * Uses TransformStream for composable stream processing:
 * ```
 * const adapter = new AiSdkToOpenAISSE({ model: 'gpt-4' })
 * const outputStream = adapter.transform(aiSdkStream)
 * ```
 */
export class AiSdkToOpenAISSE extends BaseStreamAdapter<OpenAICompatibleChunk> {
  private createdTimestamp: number
  private toolCalls: Map<string, ToolCallState> = new Map()
  private currentToolCallIndex = 0
  private finishReason: OpenAIFinishReason = null
  private reasoningContent = ''

  constructor(options: StreamAdapterOptions) {
    super(options)
    this.createdTimestamp = Math.floor(Date.now() / 1000)
  }

  /**
   * Create a base chunk structure
   */
  private createBaseChunk(delta: OpenAICompatibleDelta): OpenAICompatibleChunk {
    return {
      id: `chatcmpl-${this.state.messageId}`,
      object: 'chat.completion.chunk',
      created: this.createdTimestamp,
      model: this.state.model,
      choices: [
        {
          index: 0,
          delta,
          finish_reason: null
        }
      ]
    }
  }

  /**
   * Emit the initial message start event (with role)
   */
  protected emitMessageStart(): void {
    if (this.state.hasEmittedMessageStart) return

    this.state.hasEmittedMessageStart = true

    // Emit initial chunk with role
    const chunk = this.createBaseChunk({ role: 'assistant' })
    this.emit(chunk)
  }

  /**
   * Process a single AI SDK chunk and emit corresponding OpenAI events
   */
  protected processChunk(chunk: TextStreamPart<ToolSet>): void {
    logger.silly('AiSdkToOpenAISSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
    switch (chunk.type) {
      // === Text Events ===
      case 'text-start':
        // OpenAI doesn't have a separate start event
        break

      case 'text-delta':
        this.emitContentDelta(chunk.text || '')
        break

      case 'text-end':
        // OpenAI doesn't have a separate end event
        break

      // === Reasoning/Thinking Events ===
      // Support DeepSeek-style reasoning_content
      case 'reasoning-start':
        // No separate start event needed
        break

      case 'reasoning-delta':
        this.emitReasoningDelta(chunk.text || '')
        break

      case 'reasoning-end':
        // No separate end event needed
        break

      // === Tool Events ===
      case 'tool-call':
        this.handleToolCall({
          toolCallId: chunk.toolCallId,
          toolName: chunk.toolName,
          args: chunk.input
        })
        break

      case 'tool-result':
        // Tool results are not part of streaming output
        break

      case 'finish-step':
        if (chunk.finishReason === 'tool-calls') {
          this.finishReason = 'tool_calls'
        }
        break

      case 'finish':
        this.handleFinish(chunk)
        break

      case 'error':
        throw chunk.error

      default:
        break
    }
  }

  private emitContentDelta(content: string): void {
    if (!content) return

    // Track content in state
    let textBlock = this.state.blocks.get(0)
    if (!textBlock) {
      textBlock = {
        type: 'text',
        index: 0,
        started: true,
        content: ''
      }
      this.state.blocks.set(0, textBlock)
    }
    textBlock.content += content

    const chunk = this.createBaseChunk({ content })
    this.emit(chunk)
  }

  private emitReasoningDelta(reasoningContent: string): void {
    if (!reasoningContent) return

    // Track reasoning content
    this.reasoningContent += reasoningContent

    // Also track in state blocks for non-streaming response
    let thinkingBlock = this.state.blocks.get(-1) // Use -1 for thinking block
    if (!thinkingBlock) {
      thinkingBlock = {
        type: 'thinking',
        index: -1,
        started: true,
        content: ''
      }
      this.state.blocks.set(-1, thinkingBlock)
    }
    thinkingBlock.content += reasoningContent

    // Emit chunk with reasoning_content (DeepSeek-style)
    const chunk = this.createBaseChunk({ reasoning_content: reasoningContent })
    this.emit(chunk)
  }

  private handleToolCall(params: { toolCallId: string; toolName: string; args: unknown }): void {
    const { toolCallId, toolName, args } = params

    if (this.toolCalls.has(toolCallId)) {
      return
    }

    const index = this.currentToolCallIndex++
    const argsString = JSON.stringify(args)

    this.toolCalls.set(toolCallId, {
      index,
      id: toolCallId,
      name: toolName,
      arguments: argsString
    })

    // Track in state
    const blockIndex = this.allocateBlockIndex()
    this.state.blocks.set(blockIndex, {
      type: 'tool_use',
      index: blockIndex,
      started: true,
      content: argsString,
      toolId: toolCallId,
      toolName,
      toolInput: argsString
    })

    // Emit tool call chunk
    const chunk = this.createBaseChunk({
      tool_calls: [
        {
          index,
          id: toolCallId,
          type: 'function',
          function: {
            name: toolName,
            arguments: argsString
          }
        }
      ]
    })

    this.emit(chunk)
    this.finishReason = 'tool_calls'
  }

  private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
    if (chunk.totalUsage) {
      this.state.inputTokens = chunk.totalUsage.inputTokens || 0
      this.state.outputTokens = chunk.totalUsage.outputTokens || 0
    }

    if (!this.finishReason) {
      switch (chunk.finishReason) {
        case 'stop':
          this.finishReason = 'stop'
          break
        case 'length':
          this.finishReason = 'length'
          break
        case 'tool-calls':
          this.finishReason = 'tool_calls'
          break
        case 'content-filter':
          this.finishReason = 'content_filter'
          break
        default:
          this.finishReason = 'stop'
      }
    }

    this.state.stopReason = this.finishReason
  }

  /**
   * Finalize the stream and emit closing events
   */
  protected finalize(): void {
    // Emit final chunk with finish_reason and usage
    const finalChunk: OpenAICompatibleChunk = {
      id: `chatcmpl-${this.state.messageId}`,
      object: 'chat.completion.chunk',
      created: this.createdTimestamp,
      model: this.state.model,
      choices: [
        {
          index: 0,
          delta: {},
          finish_reason: this.finishReason || 'stop'
        }
      ],
      usage: {
        prompt_tokens: this.state.inputTokens,
        completion_tokens: this.state.outputTokens,
        total_tokens: this.state.inputTokens + this.state.outputTokens
      }
    }

    this.emit(finalChunk)
  }

  /**
   * Build a complete ChatCompletion object for non-streaming responses
   */
  buildNonStreamingResponse(): OpenAICompatibleCompletion {
    // Collect text content
    let content: string | null = null
    const textBlock = this.state.blocks.get(0)
    if (textBlock && textBlock.type === 'text' && textBlock.content) {
      content = textBlock.content
    }

    // Collect reasoning content
    let reasoningContent: string | null = null
    const thinkingBlock = this.state.blocks.get(-1)
    if (thinkingBlock && thinkingBlock.type === 'thinking' && thinkingBlock.content) {
      reasoningContent = thinkingBlock.content
    }

    // Collect tool calls
    const toolCallsArray: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = Array.from(
      this.toolCalls.values()
    ).map((tc) => ({
      id: tc.id,
      type: 'function' as const,
      function: {
        name: tc.name,
        arguments: tc.arguments
      }
    }))

    const message: OpenAICompatibleMessage = {
      role: 'assistant',
      content,
      refusal: null,
      ...(reasoningContent ? { reasoning_content: reasoningContent } : {}),
      ...(toolCallsArray.length > 0 ? { tool_calls: toolCallsArray } : {})
    }

    return {
      id: `chatcmpl-${this.state.messageId}`,
      object: 'chat.completion',
      created: this.createdTimestamp,
      model: this.state.model,
      choices: [
        {
          index: 0,
          message,
          finish_reason: this.finishReason || 'stop',
          logprobs: null
        }
      ],
      usage: {
        prompt_tokens: this.state.inputTokens,
        completion_tokens: this.state.outputTokens,
        total_tokens: this.state.inputTokens + this.state.outputTokens
      }
    }
  }
}

export default AiSdkToOpenAISSE
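On the wire, the role-announcing first chunk from emitMessageStart serializes (via OpenAISSEFormatter) roughly as follows; the id, created, and model values are illustrative:

data: {"id":"chatcmpl-msg_1700000000000_ab12cd34e","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}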
src/main/apiServer/adapters/stream/BaseStreamAdapter.ts (new file, 161 lines)
@ -0,0 +1,161 @@
/**
 * Base Stream Adapter
 *
 * Abstract base class for stream adapters that provides:
 * - Shared state management (messageId, tokens, blocks, etc.)
 * - TransformStream implementation
 * - Common utility methods
 */

import type { TextStreamPart, ToolSet } from 'ai'

import type { AdapterState, ContentBlockState, IStreamAdapter, StreamAdapterOptions } from '../interfaces'

/**
 * Abstract base class for stream adapters
 *
 * Subclasses must implement:
 * - processChunk(): Handle individual stream chunks
 * - emitMessageStart(): Emit initial message event
 * - finalize(): Clean up and emit final events
 * - buildNonStreamingResponse(): Build complete response object
 */
export abstract class BaseStreamAdapter<TOutputEvent> implements IStreamAdapter<TOutputEvent> {
  protected state: AdapterState
  protected controller: TransformStreamDefaultController<TOutputEvent> | null = null
  private transformStream: TransformStream<TextStreamPart<ToolSet>, TOutputEvent>

  constructor(options: StreamAdapterOptions) {
    this.state = this.createInitialState(options)
    this.transformStream = this.createTransformStream()
  }

  /**
   * Create initial adapter state
   */
  protected createInitialState(options: StreamAdapterOptions): AdapterState {
    return {
      messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
      model: options.model,
      inputTokens: options.inputTokens || 0,
      outputTokens: 0,
      cacheInputTokens: 0,
      currentBlockIndex: 0,
      blocks: new Map(),
      textBlockIndex: null,
      thinkingBlocks: new Map(),
      currentThinkingId: null,
      toolBlocks: new Map(),
      stopReason: null,
      hasEmittedMessageStart: false
    }
  }

  /**
   * Create the TransformStream for processing
   */
  private createTransformStream(): TransformStream<TextStreamPart<ToolSet>, TOutputEvent> {
    return new TransformStream<TextStreamPart<ToolSet>, TOutputEvent>({
      start: (controller) => {
        this.controller = controller
        // Note: emitMessageStart is called lazily in transform or finalize
        // to allow configuration changes (like setInputTokens) after construction
      },
      transform: (chunk, _controller) => {
        // Ensure message_start is emitted before processing chunks
        this.emitMessageStart()
        this.processChunk(chunk)
      },
      flush: (_controller) => {
        // Ensure message_start is emitted even for empty streams
        this.emitMessageStart()
        this.finalize()
      }
    })
  }

  /**
   * Transform input stream to output stream
   */
  transform(input: ReadableStream<TextStreamPart<ToolSet>>): ReadableStream<TOutputEvent> {
    return input.pipeThrough(this.transformStream)
  }

  /**
   * Get the internal TransformStream
   */
  getTransformStream(): TransformStream<TextStreamPart<ToolSet>, TOutputEvent> {
    return this.transformStream
  }

  /**
   * Get message ID
   */
  getMessageId(): string {
    return this.state.messageId
  }

  /**
   * Set input token count
   */
  setInputTokens(count: number): void {
    this.state.inputTokens = count
  }

  /**
   * Emit an event to the output stream
   */
  protected emit(event: TOutputEvent): void {
    if (this.controller) {
      this.controller.enqueue(event)
    }
  }

  /**
   * Get or create a content block
   */
  protected getOrCreateBlock(index: number, type: ContentBlockState['type']): ContentBlockState {
    let block = this.state.blocks.get(index)
    if (!block) {
      block = {
        type,
        index,
        started: false,
        content: ''
      }
      this.state.blocks.set(index, block)
    }
    return block
  }

  /**
   * Allocate a new block index
   */
  protected allocateBlockIndex(): number {
    return this.state.currentBlockIndex++
  }

  // ===== Abstract methods to be implemented by subclasses =====

  /**
   * Process a single chunk from the AI SDK stream
   */
  protected abstract processChunk(chunk: TextStreamPart<ToolSet>): void

  /**
   * Emit the initial message start event
   */
  protected abstract emitMessageStart(): void

  /**
   * Finalize the stream and emit closing events
   */
  protected abstract finalize(): void

  /**
   * Build a non-streaming response from accumulated state
   */
  abstract buildNonStreamingResponse(): unknown
}

export default BaseStreamAdapter
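The template-method split means a new output format only fills in the four hooks. A minimal, hypothetical subclass for illustration (it emits plain text lines rather than any real provider format):

import type { TextStreamPart, ToolSet } from 'ai'

import { BaseStreamAdapter } from './BaseStreamAdapter'

// Illustration only - not part of this change.
class PlainTextAdapter extends BaseStreamAdapter<string> {
  protected emitMessageStart(): void {
    if (this.state.hasEmittedMessageStart) return
    this.state.hasEmittedMessageStart = true
    this.emit(`[message ${this.state.messageId} start]`)
  }

  protected processChunk(chunk: TextStreamPart<ToolSet>): void {
    if (chunk.type === 'text-delta') {
      // Accumulate into shared block state so the non-streaming path works too.
      this.getOrCreateBlock(0, 'text').content += chunk.text || ''
      this.emit(chunk.text || '')
    }
  }

  protected finalize(): void {
    this.emit('[message stop]')
  }

  buildNonStreamingResponse(): string {
    return Array.from(this.state.blocks.values())
      .map((b) => b.content)
      .join('')
  }
}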
src/main/apiServer/adapters/stream/index.ts (new file, 3 lines)
@ -0,0 +1,3 @@
export { AiSdkToAnthropicSSE } from './AiSdkToAnthropicSSE'
export { AiSdkToOpenAISSE } from './AiSdkToOpenAISSE'
export { BaseStreamAdapter } from './BaseStreamAdapter'
@ -1,13 +1,10 @@
import type { ChatCompletionCreateParams } from '@cherrystudio/openai/resources'
import type { Request, Response } from 'express'
import express from 'express'

import { loggerService } from '../../services/LoggerService'
import {
  ChatCompletionModelError,
  chatCompletionService,
  ChatCompletionValidationError
} from '../services/chat-completion'
import type { ExtendedChatCompletionCreateParams } from '../adapters'
import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
import { validateModelId } from '../utils'

const logger = loggerService.withContext('ApiServerChatRoutes')

@ -22,44 +19,17 @@ interface ErrorResponseBody {
}

const mapChatCompletionError = (error: unknown): { status: number; body: ErrorResponseBody } => {
  if (error instanceof ChatCompletionValidationError) {
    logger.warn('Chat completion validation error', {
      errors: error.errors
    })

    return {
      status: 400,
      body: {
        error: {
          message: error.errors.join('; '),
          type: 'invalid_request_error',
          code: 'validation_failed'
        }
      }
    }
  }

  if (error instanceof ChatCompletionModelError) {
    logger.warn('Chat completion model error', error.error)

    return {
      status: 400,
      body: {
        error: {
          message: error.error.message,
          type: 'invalid_request_error',
          code: error.error.code
        }
      }
    }
  }

  if (error instanceof Error) {
    let statusCode = 500
    let errorType = 'server_error'
    let errorCode = 'internal_error'

    if (error.message.includes('API key') || error.message.includes('authentication')) {
    // Model validation errors
    if (error.message.includes('Model') && error.message.includes('not found')) {
      statusCode = 400
      errorType = 'invalid_request_error'
      errorCode = 'model_not_found'
    } else if (error.message.includes('API key') || error.message.includes('authentication')) {
      statusCode = 401
      errorType = 'authentication_error'
      errorCode = 'invalid_api_key'
@ -182,7 +152,7 @@ const mapChatCompletionError = (error: unknown): { status: number; body: ErrorRe
 */
router.post('/completions', async (req: Request, res: Response) => {
  try {
    const request: ChatCompletionCreateParams = req.body
    const request = req.body as ExtendedChatCompletionCreateParams

    if (!request) {
      return res.status(400).json({
@ -194,6 +164,26 @@ router.post('/completions', async (req: Request, res: Response) => {
      })
    }

    if (!request.model) {
      return res.status(400).json({
        error: {
          message: 'Model is required',
          type: 'invalid_request_error',
          code: 'missing_model'
        }
      })
    }

    if (!request.messages || request.messages.length === 0) {
      return res.status(400).json({
        error: {
          message: 'Messages are required',
          type: 'invalid_request_error',
          code: 'missing_messages'
        }
      })
    }

    logger.debug('Chat completion request', {
      model: request.model,
      messageCount: request.messages?.length || 0,
@ -201,40 +191,51 @@ router.post('/completions', async (req: Request, res: Response) => {
      temperature: request.temperature
    })

    // Validate model and get provider
    const modelValidation = await validateModelId(request.model)
    if (!modelValidation.valid) {
      return res.status(400).json({
        error: {
          message: modelValidation.error?.message || 'Model not found',
          type: 'invalid_request_error',
          code: modelValidation.error?.code || 'model_not_found'
        }
      })
    }

    const provider = modelValidation.provider!
    const modelId = modelValidation.modelId!
    const isStreaming = !!request.stream

    if (isStreaming) {
      const { stream } = await chatCompletionService.processStreamingCompletion(request)

      res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
      res.setHeader('Cache-Control', 'no-cache, no-transform')
      res.setHeader('Connection', 'keep-alive')
      res.setHeader('X-Accel-Buffering', 'no')
      res.flushHeaders()

      try {
        for await (const chunk of stream) {
          res.write(`data: ${JSON.stringify(chunk)}\n\n`)
        }
        res.write('data: [DONE]\n\n')
      } catch (streamError: any) {
        await streamToResponse({
          response: res,
          provider,
          modelId,
          params: request,
          inputFormat: 'openai',
          outputFormat: 'openai'
        })
      } catch (streamError) {
        logger.error('Stream error', { error: streamError })
        res.write(
          `data: ${JSON.stringify({
            error: {
              message: 'Stream processing error',
              type: 'server_error',
              code: 'stream_error'
            }
          })}\n\n`
        )
      } finally {
        res.end()
        // If headers weren't sent yet, return JSON error
        if (!res.headersSent) {
          const { status, body } = mapChatCompletionError(streamError)
          return res.status(status).json(body)
        }
        // Otherwise the error is already handled by streamToResponse
      }
      return
    }

    const { response } = await chatCompletionService.processCompletion(request)
    const response = await generateMessage({
      provider,
      modelId,
      params: request,
      inputFormat: 'openai',
      outputFormat: 'openai'
    })
    return res.json(response)
  } catch (error: unknown) {
    const { status, body } = mapChatCompletionError(error)

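A quick way to exercise the rewritten route end to end; the base URL, mount path, and API key below are placeholders, since the router's mount point is not shown in this diff:

// Hypothetical client call against the local API server.
const res = await fetch('http://localhost:PORT/v1/chat/completions', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json', Authorization: 'Bearer <api-key>' },
  body: JSON.stringify({
    model: 'my-provider:my-model',
    messages: [{ role: 'user', content: 'Hello' }],
    stream: true
  })
})
// The body is an SSE stream of data: {...}\n\n chunks terminated by data: [DONE].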
@ -8,7 +8,7 @@ import express from 'express'
import { approximateTokenSize } from 'tokenx'

import { messagesService } from '../services/messages'
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'

/**
@ -322,7 +322,7 @@ async function handleUnifiedProcessing({
  })

  if (request.stream) {
    await streamUnifiedMessages({
    await streamToResponse({
      response: res,
      provider,
      modelId: actualModelId,
@ -336,7 +336,7 @@ async function handleUnifiedProcessing({
      }
    })
  } else {
    const response = await generateUnifiedMessage({
    const response = await generateMessage({
      provider,
      modelId: actualModelId,
      params: request,

src/main/apiServer/services/ProxyStreamService.ts (new file, 465 lines)
@ -0,0 +1,465 @@
|
||||
/**
|
||||
* Proxy Stream Service
|
||||
*
|
||||
* Handles proxying AI requests through the unified AI SDK pipeline,
|
||||
* converting between different API formats using the adapter system.
|
||||
*/
|
||||
|
||||
import type { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
|
||||
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
|
||||
import anthropicService from '@main/services/AnthropicService'
|
||||
import copilotService from '@main/services/CopilotService'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
import {
|
||||
type AiSdkConfig,
|
||||
type AiSdkConfigContext,
|
||||
formatProviderApiHost,
|
||||
initializeSharedProviders,
|
||||
type ProviderFormatContext,
|
||||
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
|
||||
resolveActualProvider
|
||||
} from '@shared/aiCore'
|
||||
import { COPILOT_DEFAULT_HEADERS } from '@shared/aiCore/constant'
|
||||
import type { MinimalProvider } from '@shared/types'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import type { Provider } from '@types'
|
||||
import type { Provider as AiSdkProvider } from 'ai'
|
||||
import { simulateStreamingMiddleware, stepCountIs, wrapLanguageModel } from 'ai'
|
||||
import { net } from 'electron'
|
||||
import type { Response } from 'express'
|
||||
|
||||
import type { InputFormat, InputParamsMap, IStreamAdapter } from '../adapters'
|
||||
import { MessageConverterFactory, type OutputFormat, StreamAdapterFactory } from '../adapters'
|
||||
import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache'
|
||||
|
||||
const logger = loggerService.withContext('ProxyStreamService')
|
||||
|
||||
initializeSharedProviders({
|
||||
warn: (message) => logger.warn(message),
|
||||
error: (message, error) => logger.error(message, error)
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Configuration Interfaces
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Middleware type alias
|
||||
*/
|
||||
type LanguageModelMiddleware = LanguageModelV2Middleware
|
||||
|
||||
/**
|
||||
* Union type for all supported input params
|
||||
*/
|
||||
type InputParams = InputParamsMap[InputFormat]
|
||||
|
||||
/**
|
||||
* Configuration for streaming message requests
|
||||
*/
|
||||
export interface StreamConfig {
|
||||
response: Response
|
||||
provider: Provider
|
||||
modelId: string
|
||||
params: InputParams
|
||||
inputFormat?: InputFormat
|
||||
outputFormat?: OutputFormat
|
||||
onError?: (error: unknown) => void
|
||||
onComplete?: () => void
|
||||
middlewares?: LanguageModelMiddleware[]
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for non-streaming message generation
|
||||
*/
|
||||
export interface GenerateConfig {
|
||||
provider: Provider
|
||||
modelId: string
|
||||
params: InputParams
|
||||
inputFormat?: InputFormat
|
||||
outputFormat?: OutputFormat
|
||||
middlewares?: LanguageModelMiddleware[]
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal configuration for stream execution
|
||||
*/
|
||||
interface ExecuteStreamConfig {
|
||||
provider: Provider
|
||||
modelId: string
|
||||
params: InputParams
|
||||
inputFormat: InputFormat
|
||||
outputFormat: OutputFormat
|
||||
middlewares?: LanguageModelMiddleware[]
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Configuration
|
||||
// ============================================================================
|
||||
|
||||
function getMainProcessFormatContext(): ProviderFormatContext {
|
||||
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
|
||||
return {
|
||||
vertex: {
|
||||
project: vertexSettings?.projectId || 'default-project',
|
||||
location: vertexSettings?.location || 'us-central1'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function isSupportStreamOptionsProvider(provider: MinimalProvider): boolean {
|
||||
const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const
|
||||
return !NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id)
|
||||
}
|
||||
|
||||
const mainProcessSdkContext: AiSdkConfigContext = {
|
||||
isSupportStreamOptionsProvider,
|
||||
getIncludeUsageSetting: () =>
|
||||
reduxService.selectSync<boolean | undefined>('state.settings.openAI?.streamOptions?.includeUsage'),
|
||||
fetch: net.fetch as typeof globalThis.fetch
|
||||
}
|
||||
|
||||
function getActualProvider(provider: Provider, modelId: string): Provider {
|
||||
const model = provider.models?.find((m) => m.id === modelId)
|
||||
if (!model) return provider
|
||||
return resolveActualProvider(provider, model)
|
||||
}
|
||||
|
||||
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
|
||||
const actualProvider = getActualProvider(provider, modelId)
|
||||
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
|
||||
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create AI SDK provider instance from config
|
||||
*/
|
||||
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
|
||||
let providerId = config.providerId
|
||||
|
||||
// Handle special provider modes
|
||||
if (providerId === 'openai' && config.options?.mode === 'chat') {
|
||||
providerId = 'openai-chat'
|
||||
} else if (providerId === 'azure' && config.options?.mode === 'responses') {
|
||||
providerId = 'azure-responses'
|
||||
} else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
|
||||
providerId = 'cherryin-chat'
|
||||
}
|
||||
|
||||
const provider = await createProviderCore(providerId, config.options)
|
||||
return provider
|
||||
}
|
||||
|
||||
/**
 * Prepare special provider configuration for providers that need dynamic tokens
 */
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
  switch (provider.id) {
    case 'copilot': {
      const storedHeaders =
        ((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
      const headers: Record<string, string> = {
        ...COPILOT_DEFAULT_HEADERS,
        ...storedHeaders
      }

      try {
        const { token } = await copilotService.getToken(null as never, headers)
        config.options.apiKey = token
        const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
        config.options.headers = {
          ...headers,
          ...existingHeaders
        }
      } catch (error) {
        logger.error('Failed to get Copilot token', error as Error)
        throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
      }
      break
    }
    case 'anthropic': {
      if (provider.authType === 'oauth') {
        try {
          const oauthToken = await anthropicService.getValidAccessToken()
          if (!oauthToken) {
            throw new Error('Anthropic OAuth token not available. Please re-authorize.')
          }
          config.options = {
            ...config.options,
            headers: {
              ...(config.options.headers ? config.options.headers : {}),
              'Content-Type': 'application/json',
              'anthropic-version': '2023-06-01',
              'anthropic-beta': 'oauth-2025-04-20',
              Authorization: `Bearer ${oauthToken}`
            },
            baseURL: 'https://api.anthropic.com/v1',
            apiKey: ''
          }
        } catch (error) {
          logger.error('Failed to get Anthropic OAuth token', error as Error)
          throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
        }
      }
      break
    }
    case 'cherryai': {
      const baseFetch = net.fetch as typeof globalThis.fetch
      config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
        if (!options?.body) {
          return baseFetch(url, options)
        }
        const signature = cherryaiGenerateSignature({
          method: 'POST',
          path: '/chat/completions',
          query: '',
          body: JSON.parse(options.body as string)
        })
        return baseFetch(url, {
          ...options,
          headers: {
            ...(options.headers as Record<string, string>),
            ...signature
          }
        })
      }
      break
    }
  }
  return config
}

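/*
 * Illustrative sketch (assumed shapes, not part of this commit): for an
 * Anthropic provider configured with authType 'oauth', the branch above swaps
 * static API-key auth for a Bearer token fetched at request time.
 *
 *   let sdkConfig = providerToAiSdkConfig(provider, 'claude-3-opus')
 *   sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
 *   // sdkConfig.options.headers now carries Authorization: `Bearer <token>`
 *   // plus the anthropic-version / anthropic-beta headers, and apiKey is ''.
 */
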
// ============================================================================
// Core Stream Execution
// ============================================================================

/**
 * Execute stream and return adapter with output stream
 *
 * Uses MessageConverterFactory to create the appropriate converter
 * based on input format, eliminating format-specific if-else logic.
 */
async function executeStream(config: ExecuteStreamConfig): Promise<{
  adapter: IStreamAdapter
  outputStream: ReadableStream<unknown>
}> {
  const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [] } = config

  // Convert provider config to AI SDK config
  let sdkConfig = providerToAiSdkConfig(provider, modelId)
  sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)

  // Create provider instance and get language model
  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
  const baseModel = aiSdkProvider.languageModel(modelId)

  // Apply middlewares if present
  const model =
    middlewares.length > 0 && typeof baseModel === 'object'
      ? (wrapLanguageModel({ model: baseModel, middleware: middlewares as never }) as typeof baseModel)
      : baseModel

  // Create executor with plugins
  const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)

  const converter = MessageConverterFactory.create(inputFormat, {
    googleReasoningCache,
    openRouterReasoningCache
  })

  // Convert messages, tools, and extract options using unified interface
  const coreMessages = converter.toAiSdkMessages(params)
  const tools = converter.toAiSdkTools?.(params)
  const streamOptions = converter.extractStreamOptions(params)
  const providerOptions = converter.extractProviderOptions(provider, params)

  // Create adapter via factory
  const adapter = StreamAdapterFactory.createAdapter(outputFormat, {
    model: `${provider.id}:${modelId}`
  })

  // Execute AI SDK stream with extracted options
  const result = await executor.streamText({
    model,
    messages: coreMessages,
    ...streamOptions,
    stopWhen: stepCountIs(100),
    headers: defaultAppHeaders(),
    tools,
    providerOptions
  })

  // Transform stream using adapter
  const outputStream = adapter.transform(result.fullStream)

  return { adapter, outputStream }
}

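/*
 * Illustrative sketch (hypothetical call site, not part of this commit):
 * executeStream hides all format-specific branching behind the two factories,
 * so a caller only names the input and output formats.
 *
 *   const { adapter, outputStream } = await executeStream({
 *     provider,
 *     modelId: 'claude-3-opus',
 *     params,
 *     inputFormat: 'anthropic',
 *     outputFormat: 'anthropic'
 *   })
 *   // outputStream is a ReadableStream of output-format events; the adapter
 *   // keeps enough state to later build a non-streaming response.
 */
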
// ============================================================================
// Public API
// ============================================================================

/**
 * Stream a message request and write to HTTP response
 *
 * Uses TransformStream-based adapters for efficient streaming.
 *
 * @example
 * ```typescript
 * await streamToResponse({
 *   response: res,
 *   provider,
 *   modelId: 'claude-3-opus',
 *   params: messageCreateParams,
 *   outputFormat: 'anthropic'
 * })
 * ```
 */
export async function streamToResponse(config: StreamConfig): Promise<void> {
  const {
    response,
    provider,
    modelId,
    params,
    inputFormat = 'anthropic',
    outputFormat = 'anthropic',
    onError,
    onComplete,
    middlewares = [],
    plugins = []
  } = config

  logger.info('Starting proxy stream', {
    providerId: provider.id,
    providerType: provider.type,
    modelId,
    inputFormat,
    outputFormat,
    middlewareCount: middlewares.length,
    pluginCount: plugins.length
  })

  try {
    // Set SSE headers
    response.setHeader('Content-Type', 'text/event-stream')
    response.setHeader('Cache-Control', 'no-cache')
    response.setHeader('Connection', 'keep-alive')
    response.setHeader('X-Accel-Buffering', 'no')

    const { outputStream } = await executeStream({
      provider,
      modelId,
      params,
      inputFormat,
      outputFormat,
      middlewares,
      plugins
    })

    // Get formatter for the output format
    const formatter = StreamAdapterFactory.getFormatter(outputFormat)

    // Stream events to response
    const reader = outputStream.getReader()
    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) break
        response.write(formatter.formatEvent(value))
      }
    } finally {
      reader.releaseLock()
    }

    // Send done marker and end response
    response.write(formatter.formatDone())
    response.end()

    logger.info('Proxy stream completed', { providerId: provider.id, modelId })
    onComplete?.()
  } catch (error) {
    logger.error('Error in proxy stream', error as Error, { providerId: provider.id, modelId })
    onError?.(error)
    throw error
  }
}

/**
 * Generate a non-streaming message response
 *
 * Uses simulateStreamingMiddleware to reuse the same streaming logic.
 *
 * @example
 * ```typescript
 * const message = await generateMessage({
 *   provider,
 *   modelId: 'claude-3-opus',
 *   params: messageCreateParams,
 *   outputFormat: 'anthropic'
 * })
 * ```
 */
export async function generateMessage(config: GenerateConfig): Promise<unknown> {
  const {
    provider,
    modelId,
    params,
    inputFormat = 'anthropic',
    outputFormat = 'anthropic',
    middlewares = [],
    plugins = []
  } = config

  logger.info('Starting message generation', {
    providerId: provider.id,
    providerType: provider.type,
    modelId,
    inputFormat,
    outputFormat,
    middlewareCount: middlewares.length,
    pluginCount: plugins.length
  })

  try {
    // Add simulateStreamingMiddleware to reuse streaming logic
    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]

    const { adapter, outputStream } = await executeStream({
      provider,
      modelId,
      params,
      inputFormat,
      outputFormat,
      middlewares: allMiddlewares,
      plugins
    })

    // Consume the stream to populate adapter state
    const reader = outputStream.getReader()
    while (true) {
      const { done } = await reader.read()
      if (done) break
    }
    reader.releaseLock()

    // Build final response from adapter
    const finalResponse = adapter.buildNonStreamingResponse()

    logger.info('Message generation completed', { providerId: provider.id, modelId })

    return finalResponse
  } catch (error) {
    logger.error('Error in message generation', error as Error, { providerId: provider.id, modelId })
    throw error
  }
}

export default {
  streamToResponse,
  generateMessage
}
@ -1,7 +1,7 @@
import { describe, expect, it } from 'vitest'
import * as z from 'zod'

import { type JsonSchemaLike, jsonSchemaToZod } from '../unified-messages'
import { type JsonSchemaLike, jsonSchemaToZod } from '../../adapters/converters/json-schema-to-zod'

describe('jsonSchemaToZod', () => {
  describe('Basic Types', () => {

@ -1,10 +1,18 @@
import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages'
import { describe, expect, it } from 'vitest'

import { convertAnthropicToAiMessages, convertAnthropicToolsToAiSdk } from '../unified-messages'
import { AnthropicMessageConverter } from '../../adapters/converters/AnthropicMessageConverter'

describe('unified-messages', () => {
  describe('convertAnthropicToolsToAiSdk', () => {
// Create a converter instance for testing
const converter = new AnthropicMessageConverter()

// Helper functions that wrap the converter methods
const convertAnthropicToAiMessages = (params: MessageCreateParams) => converter.toAiSdkMessages(params)
const convertAnthropicToolsToAiSdk = (tools: MessageCreateParams['tools']) =>
  converter.toAiSdkTools({ model: 'test', max_tokens: 100, messages: [], tools })

describe('AnthropicMessageConverter', () => {
  describe('toAiSdkTools', () => {
    it('should return undefined for empty tools array', () => {
      const result = convertAnthropicToolsToAiSdk([])
      expect(result).toBeUndefined()
@ -135,7 +143,7 @@ describe('unified-messages', () => {
    })
  })

  describe('convertAnthropicToAiMessages', () => {
  describe('toAiSdkMessages', () => {
    describe('System Messages', () => {
      it('should convert string system message', () => {
        const params: MessageCreateParams = {

@ -1,260 +0,0 @@
import OpenAI from '@cherrystudio/openai'
import type { ChatCompletionCreateParams, ChatCompletionCreateParamsStreaming } from '@cherrystudio/openai/resources'
import type { Provider } from '@types'

import { loggerService } from '../../services/LoggerService'
import type { ModelValidationError } from '../utils'
import { validateModelId } from '../utils'

const logger = loggerService.withContext('ChatCompletionService')

export interface ValidationResult {
  isValid: boolean
  errors: string[]
}

export class ChatCompletionValidationError extends Error {
  constructor(public readonly errors: string[]) {
    super(`Request validation failed: ${errors.join('; ')}`)
    this.name = 'ChatCompletionValidationError'
  }
}

export class ChatCompletionModelError extends Error {
  constructor(public readonly error: ModelValidationError) {
    super(`Model validation failed: ${error.message}`)
    this.name = 'ChatCompletionModelError'
  }
}

export type PrepareRequestResult =
  | { status: 'validation_error'; errors: string[] }
  | { status: 'model_error'; error: ModelValidationError }
  | {
      status: 'ok'
      provider: Provider
      modelId: string
      client: OpenAI
      providerRequest: ChatCompletionCreateParams
    }

export class ChatCompletionService {
  async resolveProviderContext(
    model: string
  ): Promise<
    { ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI }
  > {
    const modelValidation = await validateModelId(model)
    if (!modelValidation.valid) {
      return {
        ok: false,
        error: modelValidation.error!
      }
    }

    const provider = modelValidation.provider!

    if (provider.type !== 'openai') {
      return {
        ok: false,
        error: {
          type: 'unsupported_provider_type',
          message: `Provider '${provider.id}' of type '${provider.type}' is not supported for OpenAI chat completions`,
          code: 'unsupported_provider_type'
        }
      }
    }

    const modelId = modelValidation.modelId!

    const client = new OpenAI({
      baseURL: provider.apiHost,
      apiKey: provider.apiKey
    })

    return {
      ok: true,
      provider,
      modelId,
      client
    }
  }

  async prepareRequest(request: ChatCompletionCreateParams, stream: boolean): Promise<PrepareRequestResult> {
    const requestValidation = this.validateRequest(request)
    if (!requestValidation.isValid) {
      return {
        status: 'validation_error',
        errors: requestValidation.errors
      }
    }

    const providerContext = await this.resolveProviderContext(request.model!)
    if (!providerContext.ok) {
      return {
        status: 'model_error',
        error: providerContext.error
      }
    }

    const { provider, modelId, client } = providerContext

    logger.debug('Model validation successful', {
      provider: provider.id,
      providerType: provider.type,
      modelId,
      fullModelId: request.model
    })

    return {
      status: 'ok',
      provider,
      modelId,
      client,
      providerRequest: stream
        ? {
            ...request,
            model: modelId,
            stream: true as const
          }
        : {
            ...request,
            model: modelId,
            stream: false as const
          }
    }
  }

  validateRequest(request: ChatCompletionCreateParams): ValidationResult {
    const errors: string[] = []

    // Validate messages
    if (!request.messages) {
      errors.push('Messages array is required')
    } else if (!Array.isArray(request.messages)) {
      errors.push('Messages must be an array')
    } else if (request.messages.length === 0) {
      errors.push('Messages array cannot be empty')
    } else {
      // Validate each message
      request.messages.forEach((message, index) => {
        if (!message.role) {
          errors.push(`Message ${index}: role is required`)
        }
        if (!message.content) {
          errors.push(`Message ${index}: content is required`)
        }
      })
    }

    // Validate optional parameters

    return {
      isValid: errors.length === 0,
      errors
    }
  }

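  /*
   * Illustrative sketch (hypothetical input, not part of this commit): the
   * validator accumulates every problem instead of failing fast.
   *
   *   const result = new ChatCompletionService().validateRequest({
   *     model: 'gpt-4o',
   *     messages: [{ role: 'user', content: '' } as any]
   *   } as ChatCompletionCreateParams)
   *   // result => { isValid: false, errors: ['Message 0: content is required'] }
   */
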
  async processCompletion(request: ChatCompletionCreateParams): Promise<{
    provider: Provider
    modelId: string
    response: OpenAI.Chat.Completions.ChatCompletion
  }> {
    try {
      logger.debug('Processing chat completion request', {
        model: request.model,
        messageCount: request.messages.length,
        stream: request.stream
      })

      const preparation = await this.prepareRequest(request, false)
      if (preparation.status === 'validation_error') {
        throw new ChatCompletionValidationError(preparation.errors)
      }

      if (preparation.status === 'model_error') {
        throw new ChatCompletionModelError(preparation.error)
      }

      const { provider, modelId, client, providerRequest } = preparation

      logger.debug('Sending request to provider', {
        provider: provider.id,
        model: modelId,
        apiHost: provider.apiHost
      })

      const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion

      logger.info('Chat completion processed', {
        modelId,
        provider: provider.id
      })
      return {
        provider,
        modelId,
        response
      }
    } catch (error: any) {
      logger.error('Error processing chat completion', {
        error,
        model: request.model
      })
      throw error
    }
  }

  async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{
    provider: Provider
    modelId: string
    stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
  }> {
    try {
      logger.debug('Processing streaming chat completion request', {
        model: request.model,
        messageCount: request.messages.length
      })

      const preparation = await this.prepareRequest(request, true)
      if (preparation.status === 'validation_error') {
        throw new ChatCompletionValidationError(preparation.errors)
      }

      if (preparation.status === 'model_error') {
        throw new ChatCompletionModelError(preparation.error)
      }

      const { provider, modelId, client, providerRequest } = preparation

      logger.debug('Sending streaming request to provider', {
        provider: provider.id,
        model: modelId,
        apiHost: provider.apiHost
      })

      const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
      const stream = (await client.chat.completions.create(
        streamRequest
      )) as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>

      logger.info('Streaming chat completion started', {
        modelId,
        provider: provider.id
      })
      return {
        provider,
        modelId,
        stream
      }
    } catch (error: any) {
      logger.error('Error processing streaming chat completion', {
        error,
        model: request.model
      })
      throw error
    }
  }
}

// Export singleton instance
export const chatCompletionService = new ChatCompletionService()
@ -1,762 +0,0 @@
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { JSONSchema7, LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
  ImageBlockParam,
  MessageCreateParams,
  TextBlockParam,
  Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters'
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
import anthropicService from '@main/services/AnthropicService'
import copilotService from '@main/services/CopilotService'
import { reduxService } from '@main/services/ReduxService'
import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
import {
  type AiSdkConfig,
  type AiSdkConfigContext,
  formatProviderApiHost,
  initializeSharedProviders,
  type ProviderFormatContext,
  providerToAiSdkConfig as sharedProviderToAiSdkConfig,
  resolveActualProvider
} from '@shared/aiCore'
import { COPILOT_DEFAULT_HEADERS } from '@shared/aiCore/constant'
import { isGemini3ModelId } from '@shared/aiCore/middlewares'
import type { MinimalProvider } from '@shared/types'
import { SystemProviderIds } from '@shared/types'
import { defaultAppHeaders } from '@shared/utils'
import { isAnthropicProvider, isGeminiProvider, isOpenAIProvider } from '@shared/utils/provider'
import type { Provider } from '@types'
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
import * as z from 'zod'

import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache'

const logger = loggerService.withContext('UnifiedMessagesService')

const MAGIC_STRING = 'skip_thought_signature_validator'

function sanitizeJson(value: unknown): JSONValue {
  return JSON.parse(JSON.stringify(value))
}

initializeSharedProviders({
  warn: (message) => logger.warn(message),
  error: (message, error) => logger.error(message, error)
})

/**
 * Configuration for unified message streaming
 */
export interface UnifiedStreamConfig {
  response: Response
  provider: Provider
  modelId: string
  params: MessageCreateParams
  onError?: (error: unknown) => void
  onComplete?: () => void
  /**
   * Optional AI SDK middlewares to apply
   */
  middlewares?: LanguageModelV2Middleware[]
  /**
   * Optional AI Core plugins to use with the executor
   */
  plugins?: AiPlugin[]
}

/**
 * Configuration for non-streaming message generation
 */
export interface GenerateUnifiedMessageConfig {
  provider: Provider
  modelId: string
  params: MessageCreateParams
  middlewares?: LanguageModelV2Middleware[]
  plugins?: AiPlugin[]
}

function getMainProcessFormatContext(): ProviderFormatContext {
  const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
  return {
    vertex: {
      project: vertexSettings?.projectId || 'default-project',
      location: vertexSettings?.location || 'us-central1'
    }
  }
}

function isSupportStreamOptionsProvider(provider: MinimalProvider): boolean {
  const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const
  return !NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id)
}

const mainProcessSdkContext: AiSdkConfigContext = {
  isSupportStreamOptionsProvider,
  getIncludeUsageSetting: () =>
    reduxService.selectSync<boolean | undefined>('state.settings.openAI?.streamOptions?.includeUsage'),
  fetch: net.fetch as typeof globalThis.fetch
}

function getActualProvider(provider: Provider, modelId: string): Provider {
  const model = provider.models?.find((m) => m.id === modelId)
  if (!model) return provider
  return resolveActualProvider(provider, model)
}

function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
  const actualProvider = getActualProvider(provider, modelId)
  const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
  return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}

function convertAnthropicToolResultToAiSdk(
  content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
  if (typeof content === 'string') {
    return { type: 'text', value: content }
  }
  const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
  for (const block of content) {
    if (block.type === 'text') {
      values.push({ type: 'text', text: block.text })
    } else if (block.type === 'image') {
      values.push({
        type: 'media',
        data: block.source.type === 'base64' ? block.source.data : block.source.url,
        mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
      })
    }
  }
  return { type: 'content', value: values }
}

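/*
 * Illustrative sketch (not part of this commit): a plain string becomes a
 * single text output, while block arrays become a content list.
 *
 *   convertAnthropicToolResultToAiSdk('42')
 *   // => { type: 'text', value: '42' }
 *
 *   convertAnthropicToolResultToAiSdk([{ type: 'text', text: 'ok' }])
 *   // => { type: 'content', value: [{ type: 'text', text: 'ok' }] }
 */
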
/**
 * JSON Schema type for tool input schemas
 */
export type JsonSchemaLike = JSONSchema7

/**
 * Convert JSON Schema to Zod schema
 * This avoids non-standard fields like input_examples that Anthropic doesn't support
 * TODO: Anthropic/beta support input_examples
 */
export function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
  const schemaType = schema.type
  const enumValues = schema.enum
  const description = schema.description

  // Handle enum first
  if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
    if (enumValues.every((v) => typeof v === 'string')) {
      const zodEnum = z.enum(enumValues as [string, ...string[]])
      return description ? zodEnum.describe(description) : zodEnum
    }
    // For non-string enums, use union of literals
    const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
    if (literals.length === 1) {
      return description ? literals[0].describe(description) : literals[0]
    }
    const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
    return description ? zodUnion.describe(description) : zodUnion
  }

  // Handle union types (type: ["string", "null"])
  if (Array.isArray(schemaType)) {
    const schemas = schemaType.map((t) =>
      jsonSchemaToZod({
        ...schema,
        type: t,
        enum: undefined
      })
    )
    if (schemas.length === 1) {
      return schemas[0]
    }
    return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
  }

  // Handle by type
  switch (schemaType) {
    case 'string': {
      let zodString = z.string()
      if (typeof schema.minLength === 'number') zodString = zodString.min(schema.minLength)
      if (typeof schema.maxLength === 'number') zodString = zodString.max(schema.maxLength)
      if (typeof schema.pattern === 'string') zodString = zodString.regex(new RegExp(schema.pattern))
      return description ? zodString.describe(description) : zodString
    }

    case 'number':
    case 'integer': {
      let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
      if (typeof schema.minimum === 'number') zodNumber = zodNumber.min(schema.minimum)
      if (typeof schema.maximum === 'number') zodNumber = zodNumber.max(schema.maximum)
      return description ? zodNumber.describe(description) : zodNumber
    }

    case 'boolean': {
      const zodBoolean = z.boolean()
      return description ? zodBoolean.describe(description) : zodBoolean
    }

    case 'null':
      return z.null()

    case 'array': {
      const items = schema.items
      let zodArray: z.ZodArray<z.ZodTypeAny>
      if (items && typeof items === 'object' && !Array.isArray(items)) {
        zodArray = z.array(jsonSchemaToZod(items as JsonSchemaLike))
      } else {
        zodArray = z.array(z.unknown())
      }
      if (typeof schema.minItems === 'number') zodArray = zodArray.min(schema.minItems)
      if (typeof schema.maxItems === 'number') zodArray = zodArray.max(schema.maxItems)
      return description ? zodArray.describe(description) : zodArray
    }

    case 'object': {
      const properties = schema.properties
      const required = schema.required || []

      // Always use z.object() to ensure "properties" field is present in output schema
      // OpenAI requires explicit properties field even for empty objects
      const shape: Record<string, z.ZodTypeAny> = {}
      if (properties && typeof properties === 'object') {
        for (const [key, propSchema] of Object.entries(properties)) {
          if (typeof propSchema === 'boolean') {
            shape[key] = propSchema ? z.unknown() : z.never()
          } else {
            const zodProp = jsonSchemaToZod(propSchema as JsonSchemaLike)
            shape[key] = required.includes(key) ? zodProp : zodProp.optional()
          }
        }
      }

      const zodObject = z.object(shape)
      return description ? zodObject.describe(description) : zodObject
    }

    default:
      // Unknown type, use z.unknown()
      return z.unknown()
  }
}

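/*
 * Illustrative sketch (hypothetical schema, not part of this commit): a JSON
 * Schema object with one required and one optional property maps to a Zod
 * object whose optional members carry .optional().
 *
 *   const schema = jsonSchemaToZod({
 *     type: 'object',
 *     properties: {
 *       city: { type: 'string', description: 'City name' },
 *       days: { type: 'integer', minimum: 1 }
 *     },
 *     required: ['city']
 *   })
 *   // Roughly equivalent to:
 *   // z.object({ city: z.string().describe('City name'),
 *   //            days: z.number().int().min(1).optional() })
 */
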
export function convertAnthropicToolsToAiSdk(
  tools: MessageCreateParams['tools']
): Record<string, AiSdkTool> | undefined {
  if (!tools || tools.length === 0) return undefined

  const aiSdkTools: Record<string, AiSdkTool> = {}
  for (const anthropicTool of tools) {
    if (anthropicTool.type === 'bash_20250124') continue
    const toolDef = anthropicTool as AnthropicTool
    const rawSchema = toolDef.input_schema
    // Convert Anthropic's InputSchema to JSONSchema7-compatible format
    const schema = jsonSchemaToZod(rawSchema as JsonSchemaLike)

    // Use tool() with inputSchema (AI SDK v5 API)
    const aiTool = tool({
      description: toolDef.description || '',
      inputSchema: zodSchema(schema)
    })

    aiSdkTools[toolDef.name] = aiTool
  }
  return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}

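/*
 * Illustrative sketch (hypothetical tool definition, not part of this commit):
 * each Anthropic tool becomes an AI SDK v5 tool keyed by its name.
 *
 *   const tools = convertAnthropicToolsToAiSdk([
 *     {
 *       name: 'get_weather',
 *       description: 'Look up the weather',
 *       input_schema: {
 *         type: 'object',
 *         properties: { city: { type: 'string' } },
 *         required: ['city']
 *       }
 *     }
 *   ])
 *   // => { get_weather: tool({ description: 'Look up the weather', inputSchema: ... }) }
 */
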
export function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
  const messages: ModelMessage[] = []

  // System message
  if (params.system) {
    if (typeof params.system === 'string') {
      messages.push({ role: 'system', content: params.system })
    } else if (Array.isArray(params.system)) {
      const systemText = params.system
        .filter((block) => block.type === 'text')
        .map((block) => block.text)
        .join('\n')
      if (systemText) {
        messages.push({ role: 'system', content: systemText })
      }
    }
  }

  const toolCallIdToName = new Map<string, string>()
  for (const msg of params.messages) {
    if (Array.isArray(msg.content)) {
      for (const block of msg.content) {
        if (block.type === 'tool_use') {
          toolCallIdToName.set(block.id, block.name)
        }
      }
    }
  }

  // User/assistant messages
  for (const msg of params.messages) {
    if (typeof msg.content === 'string') {
      messages.push({
        role: msg.role === 'user' ? 'user' : 'assistant',
        content: msg.content
      })
    } else if (Array.isArray(msg.content)) {
      const textParts: TextPart[] = []
      const imageParts: ImagePart[] = []
      const reasoningParts: ReasoningPart[] = []
      const toolCallParts: ToolCallPart[] = []
      const toolResultParts: ToolResultPart[] = []

      for (const block of msg.content) {
        if (block.type === 'text') {
          textParts.push({ type: 'text', text: block.text })
        } else if (block.type === 'thinking') {
          reasoningParts.push({ type: 'reasoning', text: block.thinking })
        } else if (block.type === 'redacted_thinking') {
          reasoningParts.push({ type: 'reasoning', text: block.data })
        } else if (block.type === 'image') {
          const source = block.source
          if (source.type === 'base64') {
            imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
          } else if (source.type === 'url') {
            imageParts.push({ type: 'image', image: source.url })
          }
        } else if (block.type === 'tool_use') {
          const options: ProviderOptions = {}
          logger.debug('Processing tool call block', { block, msgRole: msg.role, model: params.model })
          if (isGemini3ModelId(params.model)) {
            if (googleReasoningCache.get(`google-${block.name}`)) {
              options.google = {
                thoughtSignature: MAGIC_STRING
              }
            }
          }
          if (openRouterReasoningCache.get(`openrouter-${block.id}`)) {
            options.openrouter = {
              reasoning_details:
                (sanitizeJson(openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || []
            }
          }
          toolCallParts.push({
            type: 'tool-call',
            toolName: block.name,
            toolCallId: block.id,
            input: block.input,
            providerOptions: options
          })
        } else if (block.type === 'tool_result') {
          // Look up toolName from the pre-built map (covers cross-message references)
          const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
          toolResultParts.push({
            type: 'tool-result',
            toolCallId: block.tool_use_id,
            toolName,
            output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
          })
        }
      }

      if (toolResultParts.length > 0) {
        messages.push({ role: 'tool', content: [...toolResultParts] })
      }

      if (msg.role === 'user') {
        const userContent = [...textParts, ...imageParts]
        if (userContent.length > 0) {
          messages.push({ role: 'user', content: userContent })
        }
      } else {
        const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
        if (assistantContent.length > 0) {
          let providerOptions: ProviderOptions | undefined = undefined
          if (openRouterReasoningCache.get('openrouter')) {
            providerOptions = {
              openrouter: {
                reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
              }
            }
          } else if (isGemini3ModelId(params.model)) {
            providerOptions = {
              google: {
                thoughtSignature: MAGIC_STRING
              }
            }
          }
          messages.push({ role: 'assistant', content: assistantContent, providerOptions })
        }
      }
    }
  }

  return messages
}

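/*
 * Illustrative sketch (hypothetical params, not part of this commit): a
 * tool_use block and the tool_result that answers it in a later message are
 * linked through the pre-built toolCallIdToName map.
 *
 *   const messages = convertAnthropicToAiMessages({
 *     model: 'claude-3-opus',
 *     max_tokens: 1024,
 *     messages: [
 *       { role: 'assistant', content: [{ type: 'tool_use', id: 'tu_1', name: 'get_weather', input: { city: 'Paris' } }] },
 *       { role: 'user', content: [{ type: 'tool_result', tool_use_id: 'tu_1', content: 'sunny' }] }
 *     ]
 *   })
 *   // => [ assistant message with a tool-call part,
 *   //      tool message whose tool-result part carries toolName 'get_weather' ]
 */
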
interface ExecuteStreamConfig {
  provider: Provider
  modelId: string
  params: MessageCreateParams
  middlewares?: LanguageModelV2Middleware[]
  plugins?: AiPlugin[]
  onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
}

/**
 * Create AI SDK provider instance from config
 * Similar to renderer's createAiSdkProvider
 */
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
  let providerId = config.providerId

  // Handle special provider modes (same as renderer)
  if (providerId === 'openai' && config.options?.mode === 'chat') {
    providerId = 'openai-chat'
  } else if (providerId === 'azure' && config.options?.mode === 'responses') {
    providerId = 'azure-responses'
  } else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
    providerId = 'cherryin-chat'
  }

  const provider = await createProviderCore(providerId, config.options)

  return provider
}

/**
 * Prepare special provider configuration for providers that need dynamic tokens
 * Similar to renderer's prepareSpecialProviderConfig
 */
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
  switch (provider.id) {
    case 'copilot': {
      const storedHeaders =
        ((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
      const headers: Record<string, string> = {
        ...COPILOT_DEFAULT_HEADERS,
        ...storedHeaders
      }

      try {
        const { token } = await copilotService.getToken(null as any, headers)
        config.options.apiKey = token
        const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
        config.options.headers = {
          ...headers,
          ...existingHeaders
        }
      } catch (error) {
        logger.error('Failed to get Copilot token', error as Error)
        throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
      }
      break
    }
    case 'anthropic': {
      if (provider.authType === 'oauth') {
        try {
          const oauthToken = await anthropicService.getValidAccessToken()
          if (!oauthToken) {
            throw new Error('Anthropic OAuth token not available. Please re-authorize.')
          }
          config.options = {
            ...config.options,
            headers: {
              ...(config.options.headers ? config.options.headers : {}),
              'Content-Type': 'application/json',
              'anthropic-version': '2023-06-01',
              'anthropic-beta': 'oauth-2025-04-20',
              Authorization: `Bearer ${oauthToken}`
            },
            baseURL: 'https://api.anthropic.com/v1',
            apiKey: ''
          }
        } catch (error) {
          logger.error('Failed to get Anthropic OAuth token', error as Error)
          throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
        }
      }
      break
    }
    case 'cherryai': {
      // Create a signed fetch wrapper for cherryai
      const baseFetch = net.fetch as typeof globalThis.fetch
      config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
        if (!options?.body) {
          return baseFetch(url, options)
        }
        const signature = cherryaiGenerateSignature({
          method: 'POST',
          path: '/chat/completions',
          query: '',
          body: JSON.parse(options.body as string)
        })
        return baseFetch(url, {
          ...options,
          headers: {
            ...(options.headers as Record<string, string>),
            ...signature
          }
        })
      }
      break
    }
  }
  return config
}

function mapAnthropicThinkToAISdkProviderOptions(
  provider: Provider,
  config: MessageCreateParams['thinking']
): ProviderOptions | undefined {
  if (!config) return undefined
  if (isAnthropicProvider(provider)) {
    return {
      anthropic: {
        ...mapToAnthropicProviderOptions(config)
      }
    }
  }
  if (isGeminiProvider(provider)) {
    return {
      google: {
        ...mapToGeminiProviderOptions(config)
      }
    }
  }
  if (isOpenAIProvider(provider)) {
    return {
      openai: {
        ...mapToOpenAIProviderOptions(config)
      }
    }
  }
  if (provider.id === SystemProviderIds.openrouter) {
    return {
      openrouter: {
        ...mapToOpenRouterProviderOptions(config)
      }
    }
  }
  return undefined
}

function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
  return {
    thinking: {
      type: config.type,
      budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
    }
  }
}

function mapToGeminiProviderOptions(
  config: NonNullable<MessageCreateParams['thinking']>
): GoogleGenerativeAIProviderOptions {
  return {
    thinkingConfig: {
      thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
      includeThoughts: config.type === 'enabled'
    }
  }
}

function mapToOpenAIProviderOptions(
  config: NonNullable<MessageCreateParams['thinking']>
): OpenAIResponsesProviderOptions {
  return {
    reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
  }
}

function mapToOpenRouterProviderOptions(
  config: NonNullable<MessageCreateParams['thinking']>
): OpenRouterProviderOptions {
  return {
    reasoning: {
      enabled: config.type === 'enabled',
      effort: 'high'
    }
  }
}

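/*
 * Illustrative sketch (hypothetical provider values, not part of this
 * commit): the same Anthropic thinking config fans out to provider-specific
 * options.
 *
 *   const thinking = { type: 'enabled', budget_tokens: 2048 } as const
 *   mapAnthropicThinkToAISdkProviderOptions(anthropicProvider, thinking)
 *   // => { anthropic: { thinking: { type: 'enabled', budgetTokens: 2048 } } }
 *   mapAnthropicThinkToAISdkProviderOptions(geminiProvider, thinking)
 *   // => { google: { thinkingConfig: { thinkingBudget: 2048, includeThoughts: true } } }
 */
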
/**
 * Core stream execution function - single source of truth for AI SDK calls
 */
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
  const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config

  // Convert provider config to AI SDK config
  let sdkConfig = providerToAiSdkConfig(provider, modelId)

  // Prepare special provider config (Copilot, Anthropic OAuth, etc.)
  sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)

  // Create provider instance and get language model
  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
  const baseModel = aiSdkProvider.languageModel(modelId)

  // Apply middlewares if present
  const model =
    middlewares.length > 0 && typeof baseModel === 'object'
      ? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
      : baseModel

  // Create executor with plugins
  const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)

  // Convert messages and tools
  const coreMessages = convertAnthropicToAiMessages(params)
  const tools = convertAnthropicToolsToAiSdk(params.tools)

  // Create the adapter
  const adapter = new AiSdkToAnthropicSSE({
    model: `${provider.id}:${modelId}`,
    onEvent: onEvent || (() => {})
  })

  const result = await executor.streamText({
    model,
    messages: coreMessages,
    // FIXME: The max_tokens passed in by Claude Code can exceed some models' limits and needs
    // special handling; this may be easier to fix in v2, and maintaining a workaround now is costly.
    // Known affected model: Doubao
    maxOutputTokens: params.max_tokens,
    temperature: params.temperature,
    topP: params.top_p,
    topK: params.top_k,
    stopSequences: params.stop_sequences,
    stopWhen: stepCountIs(100),
    headers: defaultAppHeaders(),
    tools,
    providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking)
  })

  // Process the stream through the adapter
  await adapter.processStream(result.fullStream)

  return adapter
}

/**
 * Stream a message request using AI SDK executor and convert to Anthropic SSE format
 */
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
  const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config

  logger.info('Starting unified message stream', {
    providerId: provider.id,
    providerType: provider.type,
    modelId,
    stream: params.stream,
    middlewareCount: middlewares.length,
    pluginCount: plugins.length
  })

  try {
    response.setHeader('Content-Type', 'text/event-stream')
    response.setHeader('Cache-Control', 'no-cache')
    response.setHeader('Connection', 'keep-alive')
    response.setHeader('X-Accel-Buffering', 'no')

    await executeStream({
      provider,
      modelId,
      params,
      middlewares,
      plugins,
      onEvent: (event) => {
        logger.silly('Streaming event', { eventType: event.type })
        const sseData = formatSSEEvent(event)
        response.write(sseData)
      }
    })

    // Send done marker
    response.write(formatSSEDone())
    response.end()

    logger.info('Unified message stream completed', { providerId: provider.id, modelId })
    onComplete?.()
  } catch (error) {
    logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
    onError?.(error)
    throw error
  }
}

/**
 * Generate a non-streaming message response
 *
 * Uses simulateStreamingMiddleware to reuse the same streaming logic,
 * similar to renderer's ModernAiProvider pattern.
 */
export async function generateUnifiedMessage(
  providerOrConfig: Provider | GenerateUnifiedMessageConfig,
  modelId?: string,
  params?: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
  // Support both old signature and new config-based signature
  let config: GenerateUnifiedMessageConfig
  if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
    config = providerOrConfig
  } else {
    config = {
      provider: providerOrConfig as Provider,
      modelId: modelId!,
      params: params!
    }
  }

  const { provider, middlewares = [], plugins = [] } = config

  logger.info('Starting unified message generation', {
    providerId: provider.id,
    providerType: provider.type,
    modelId: config.modelId,
    middlewareCount: middlewares.length,
    pluginCount: plugins.length
  })

  try {
    // Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]

    const adapter = await executeStream({
      provider,
      modelId: config.modelId,
      params: config.params,
      middlewares: allMiddlewares,
      plugins
    })

    const finalResponse = adapter.buildNonStreamingResponse()

    logger.info('Unified message generation completed', {
      providerId: provider.id,
      modelId: config.modelId
    })

    return finalResponse
  } catch (error) {
    logger.error('Error in unified message generation', error as Error, {
      providerId: provider.id,
      modelId: config.modelId
    })
    throw error
  }
}

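/*
 * Illustrative sketch (not part of this commit): both call forms accepted by
 * the signature handling above.
 *
 *   // Legacy positional form
 *   await generateUnifiedMessage(provider, 'claude-3-opus', params)
 *   // Config-based form
 *   await generateUnifiedMessage({ provider, modelId: 'claude-3-opus', params })
 */
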
export default {
  streamUnifiedMessages,
  generateUnifiedMessage
}
@ -29,7 +29,7 @@ export async function getAvailableProviders(): Promise<Provider[]> {
  }

  // Support all provider types that AI SDK can handle
  // The unified-messages service uses AI SDK which supports many providers
  // The ProxyStreamService uses AI SDK which supports many providers
  const supportedProviders = providers.filter((p: Provider) => p.enabled)

  // Cache the filtered results