This commit is contained in:
SuYao 2025-12-18 20:37:57 +08:00 committed by GitHub
commit af7bf97b35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
74 changed files with 6485 additions and 1358 deletions

2
.gitignore vendored
View File

@ -73,3 +73,5 @@ test-results
YOUR_MEMORY_FILE_PATH YOUR_MEMORY_FILE_PATH
.sessions/ .sessions/
.next/
*.tsbuildinfo

View File

@ -25,7 +25,10 @@ export default defineConfig({
'@shared': resolve('packages/shared'), '@shared': resolve('packages/shared'),
'@logger': resolve('src/main/services/LoggerService'), '@logger': resolve('src/main/services/LoggerService'),
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node') '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'),
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
'@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src')
} }
}, },
build: { build: {

View File

@ -0,0 +1,15 @@
/**
* Shared AI SDK Middlewares
*
* Environment-agnostic middlewares that can be used in both
* renderer process and main process (API server).
*/
export {
buildSharedMiddlewares,
getReasoningTagName,
isGemini3ModelId,
openrouterReasoningMiddleware,
type SharedMiddlewareConfig,
skipGeminiThoughtSignatureMiddleware
} from './middlewares'

View File

@ -0,0 +1,205 @@
/**
* Shared AI SDK Middlewares
*
* These middlewares are environment-agnostic and can be used in both
* renderer process and main process (API server).
*/
import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
import { extractReasoningMiddleware } from 'ai'
/**
* Configuration for building shared middlewares
*/
export interface SharedMiddlewareConfig {
/**
* Whether to enable reasoning extraction
*/
enableReasoning?: boolean
/**
* Tag name for reasoning extraction
* Defaults based on model ID
*/
reasoningTagName?: string
/**
* Model ID - used to determine default reasoning tag and model detection
*/
modelId?: string
/**
* Provider ID (Cherry Studio provider ID)
* Used for provider-specific middlewares like OpenRouter
*/
providerId?: string
/**
* AI SDK Provider ID
* Used for Gemini thought signature middleware
* e.g., 'google', 'google-vertex'
*/
aiSdkProviderId?: string
}
/**
* Check if model ID represents a Gemini 3 (2.5) model
* that requires thought signature handling
*
* @param modelId - The model ID string (not Model object)
*/
export function isGemini3ModelId(modelId?: string): boolean {
if (!modelId) return false
const lowerModelId = modelId.toLowerCase()
return lowerModelId.includes('gemini-3')
}
/**
* Get the default reasoning tag name based on model ID
*
* Different models use different tags for reasoning content:
* - Most models: 'think'
* - GPT-OSS models: 'reasoning'
* - Gemini models: 'thought'
* - Seed models: 'seed:think'
*/
export function getReasoningTagName(modelId?: string): string {
if (!modelId) return 'think'
const lowerModelId = modelId.toLowerCase()
if (lowerModelId.includes('gpt-oss')) return 'reasoning'
if (lowerModelId.includes('gemini')) return 'thought'
if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
return 'think'
}
/**
* Skip Gemini Thought Signature Middleware
*
* Due to the complexity of multi-model client requests (which can switch
* to other models mid-process), this middleware skips all Gemini 3
* thinking signatures validation.
*
* @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex')
* @returns LanguageModelV2Middleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
/**
* OpenRouter Reasoning Middleware
*
* Filters out [REDACTED] blocks from OpenRouter reasoning responses.
* OpenRouter may include [REDACTED] markers in reasoning content that
* should be removed for cleaner output.
*
* @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
* @returns LanguageModelV2Middleware
*/
export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
/**
* Build shared middlewares based on configuration
*
* This function builds a set of middlewares that are commonly needed
* across different environments (renderer, API server).
*
* @param config - Configuration for middleware building
* @returns Array of AI SDK middlewares
*
* @example
* ```typescript
* import { buildSharedMiddlewares } from '@shared/middleware'
*
* const middlewares = buildSharedMiddlewares({
* enableReasoning: true,
* modelId: 'gemini-2.5-pro',
* providerId: 'openrouter',
* aiSdkProviderId: 'google'
* })
* ```
*/
export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
const middlewares: LanguageModelV2Middleware[] = []
// 1. Reasoning extraction middleware
if (config.enableReasoning) {
const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
middlewares.push(extractReasoningMiddleware({ tagName }))
}
// 2. OpenRouter-specific: filter [REDACTED] blocks
if (config.providerId === 'openrouter' && config.enableReasoning) {
middlewares.push(openrouterReasoningMiddleware())
}
// 3. Gemini 3 (2.5) specific: skip thought signature validation
if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
}
return middlewares
}

View File

@ -9,13 +9,27 @@
*/ */
import Anthropic from '@anthropic-ai/sdk' import Anthropic from '@anthropic-ai/sdk'
import type { TextBlockParam } from '@anthropic-ai/sdk/resources' import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import type { Provider } from '@types' import { type Provider, SystemProviderIds } from '@types'
import type { ModelMessage } from 'ai' import type { ModelMessage } from 'ai'
const logger = loggerService.withContext('anthropic-sdk') const logger = loggerService.withContext('anthropic-sdk')
/**
* Context for Anthropic SDK client creation.
* This allows the shared module to be used in different environments
* by providing environment-specific implementations.
*/
export interface AnthropicSdkContext {
/**
* Custom fetch function to use for HTTP requests.
* In Electron main process, this should be `net.fetch`.
* In other environments, can use the default fetch or a custom implementation.
*/
fetch?: typeof globalThis.fetch
}
const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.` const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.`
const defaultClaudeCodeSystem: Array<TextBlockParam> = [ const defaultClaudeCodeSystem: Array<TextBlockParam> = [
@ -58,8 +72,11 @@ const defaultClaudeCodeSystem: Array<TextBlockParam> = [
export function getSdkClient( export function getSdkClient(
provider: Provider, provider: Provider,
oauthToken?: string | null, oauthToken?: string | null,
extraHeaders?: Record<string, string | string[]> extraHeaders?: Record<string, string | string[]>,
context?: AnthropicSdkContext
): Anthropic { ): Anthropic {
const customFetch = context?.fetch
if (provider.authType === 'oauth') { if (provider.authType === 'oauth') {
if (!oauthToken) { if (!oauthToken) {
throw new Error('OAuth token is not available') throw new Error('OAuth token is not available')
@ -85,7 +102,8 @@ export function getSdkClient(
'x-stainless-runtime': 'node', 'x-stainless-runtime': 'node',
'x-stainless-runtime-version': 'v22.18.0', 'x-stainless-runtime-version': 'v22.18.0',
...extraHeaders ...extraHeaders
} },
fetch: customFetch
}) })
} }
const baseURL = const baseURL =
@ -101,11 +119,12 @@ export function getSdkClient(
baseURL, baseURL,
dangerouslyAllowBrowser: true, dangerouslyAllowBrowser: true,
defaultHeaders: { defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19', 'anthropic-beta': 'interleaved-thinking-2025-05-14',
'APP-Code': 'MLTG2087', 'APP-Code': 'MLTG2087',
...provider.extra_headers, ...provider.extra_headers,
...extraHeaders ...extraHeaders
} },
fetch: customFetch
}) })
} }
@ -115,9 +134,11 @@ export function getSdkClient(
baseURL, baseURL,
dangerouslyAllowBrowser: true, dangerouslyAllowBrowser: true,
defaultHeaders: { defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19', 'anthropic-beta': 'interleaved-thinking-2025-05-14',
Authorization: provider.id === SystemProviderIds.longcat ? `Bearer ${provider.apiKey}` : undefined,
...provider.extra_headers ...provider.extra_headers
} },
fetch: customFetch
}) })
} }
@ -168,3 +189,31 @@ export function buildClaudeCodeSystemModelMessage(system?: string | Array<TextBl
content: block.text content: block.text
})) }))
} }
/**
* Sanitize tool definitions for Anthropic API.
*
* Removes non-standard fields like `input_examples` from tool definitions
* that Anthropic's API doesn't support. This prevents validation errors when
* tools with extended fields are passed to the Anthropic SDK.
*
* @param tools - Array of tool definitions from MessageCreateParams
* @returns Sanitized tools array with non-standard fields removed
*
* @example
* ```typescript
* const sanitizedTools = sanitizeToolsForAnthropic(request.tools)
* ```
*/
export function sanitizeToolsForAnthropic(tools?: MessageCreateParams['tools']): MessageCreateParams['tools'] {
if (!tools || tools.length === 0) return tools
return tools.map((tool) => {
if ('type' in tool && tool.type !== 'custom') return tool
// oxlint-disable-next-line no-unused-vars
const { input_examples, ...sanitizedTool } = tool as AnthropicTool & { input_examples?: unknown }
return sanitizedTool as typeof tool
})
}

View File

@ -43,6 +43,35 @@ export function isSiliconAnthropicCompatibleModel(modelId: string): boolean {
} }
/** /**
* Silicon provider's Anthropic API host URL. * PPIO provider models that support Anthropic API endpoint.
* These models can be used with Claude Code via the Anthropic-compatible API.
*
* @see https://ppio.com/docs/model/llm-anthropic-compatibility
*/ */
export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn' export const PPIO_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [
'moonshotai/kimi-k2-thinking',
'minimax/minimax-m2',
'deepseek/deepseek-v3.2-exp',
'deepseek/deepseek-v3.1-terminus',
'zai-org/glm-4.6',
'moonshotai/kimi-k2-0905',
'deepseek/deepseek-v3.1',
'moonshotai/kimi-k2-instruct',
'qwen/qwen3-next-80b-a3b-instruct',
'qwen/qwen3-next-80b-a3b-thinking'
]
/**
* Creates a Set for efficient lookup of PPIO Anthropic-compatible model IDs.
*/
const PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(PPIO_ANTHROPIC_COMPATIBLE_MODELS)
/**
* Checks if a model ID is compatible with Anthropic API on PPIO provider.
*
* @param modelId - The model ID to check
* @returns true if the model supports Anthropic API endpoint
*/
export function isPpioAnthropicCompatibleModel(modelId: string): boolean {
return PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId)
}

View File

@ -1,13 +1,13 @@
/** /**
* AiHubMix规则集 * AiHubMix规则集
*/ */
import { isOpenAILLMModel } from '@renderer/config/models' import { getLowerBaseModelName } from '@shared/utils/naming'
import type { Provider } from '@renderer/types'
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper' import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types' import type { RuleSet } from './types'
const extraProviderConfig = (provider: Provider) => { const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
return { return {
...provider, ...provider,
extra_headers: { extra_headers: {
@ -17,11 +17,23 @@ const extraProviderConfig = (provider: Provider) => {
} }
} }
function isOpenAILLMModel<M extends MinimalModel>(model: M): boolean {
const modelId = getLowerBaseModelName(model.id)
const reasonings = ['o1', 'o3', 'o4', 'gpt-oss']
if (reasonings.some((r) => modelId.includes(r))) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
const AIHUBMIX_RULES: RuleSet = { const AIHUBMIX_RULES: RuleSet = {
rules: [ rules: [
{ {
match: startsWith('claude'), match: startsWith('claude'),
provider: (provider: Provider) => { provider: (provider) => {
return extraProviderConfig({ return extraProviderConfig({
...provider, ...provider,
type: 'anthropic' type: 'anthropic'
@ -34,7 +46,7 @@ const AIHUBMIX_RULES: RuleSet = {
!model.id.endsWith('-nothink') && !model.id.endsWith('-nothink') &&
!model.id.endsWith('-search') && !model.id.endsWith('-search') &&
!model.id.includes('embedding'), !model.id.includes('embedding'),
provider: (provider: Provider) => { provider: (provider) => {
return extraProviderConfig({ return extraProviderConfig({
...provider, ...provider,
type: 'gemini', type: 'gemini',
@ -44,7 +56,7 @@ const AIHUBMIX_RULES: RuleSet = {
}, },
{ {
match: isOpenAILLMModel, match: isOpenAILLMModel,
provider: (provider: Provider) => { provider: (provider) => {
return extraProviderConfig({ return extraProviderConfig({
...provider, ...provider,
type: 'openai-response' type: 'openai-response'
@ -52,7 +64,8 @@ const AIHUBMIX_RULES: RuleSet = {
} }
} }
], ],
fallbackRule: (provider: Provider) => extraProviderConfig(provider) fallbackRule: (provider) => extraProviderConfig(provider)
} }
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES) export const aihubmixProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AIHUBMIX_RULES, model, provider)

View File

@ -0,0 +1,22 @@
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
type: 'anthropic' as ProviderType,
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const azureAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AZURE_ANTHROPIC_RULES, model, provider)

View File

@ -0,0 +1,32 @@
import type { MinimalModel, MinimalProvider } from '../types'
import type { RuleSet } from './types'
export const startsWith =
(prefix: string) =>
<M extends MinimalModel>(model: M) =>
model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs =
(type: string) =>
<M extends MinimalModel>(model: M) =>
model.endpoint_type === type
/**
* Provider
* @param ruleSet
* @param model
* @param provider provider对象
* @returns provider对象
*/
export function provider2Provider<M extends MinimalModel, R extends MinimalProvider, P extends R = R>(
ruleSet: RuleSet<M, R>,
model: M,
provider: P
): P {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider) as P
}
}
return ruleSet.fallbackRule(provider) as P
}

View File

@ -0,0 +1,6 @@
export { aihubmixProviderCreator } from './aihubmix'
export { azureAnthropicProviderCreator } from './azure-anthropic'
export { endpointIs, provider2Provider, startsWith } from './helper'
export { newApiResolverCreator } from './newApi'
export type { RuleSet } from './types'
export { vertexAnthropicProviderCreator } from './vertex-anthropic'

View File

@ -1,8 +1,7 @@
/** /**
* NewAPI规则集 * NewAPI规则集
*/ */
import type { Provider } from '@renderer/types' import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import { endpointIs, provider2Provider } from './helper' import { endpointIs, provider2Provider } from './helper'
import type { RuleSet } from './types' import type { RuleSet } from './types'
@ -10,42 +9,43 @@ const NEWAPI_RULES: RuleSet = {
rules: [ rules: [
{ {
match: endpointIs('anthropic'), match: endpointIs('anthropic'),
provider: (provider: Provider) => { provider: (provider) => {
return { return {
...provider, ...provider,
type: 'anthropic' type: 'anthropic' as ProviderType
} }
} }
}, },
{ {
match: endpointIs('gemini'), match: endpointIs('gemini'),
provider: (provider: Provider) => { provider: (provider) => {
return { return {
...provider, ...provider,
type: 'gemini' type: 'gemini' as ProviderType
} }
} }
}, },
{ {
match: endpointIs('openai-response'), match: endpointIs('openai-response'),
provider: (provider: Provider) => { provider: (provider) => {
return { return {
...provider, ...provider,
type: 'openai-response' type: 'openai-response' as ProviderType
} }
} }
}, },
{ {
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model), match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider: Provider) => { provider: (provider) => {
return { return {
...provider, ...provider,
type: 'openai' type: 'openai' as ProviderType
} }
} }
} }
], ],
fallbackRule: (provider: Provider) => provider fallbackRule: (provider) => provider
} }
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES) export const newApiResolverCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(NEWAPI_RULES, model, provider)

View File

@ -0,0 +1,9 @@
import type { MinimalModel, MinimalProvider } from '../types'
export interface RuleSet<M extends MinimalModel = MinimalModel, P extends MinimalProvider = MinimalProvider> {
rules: Array<{
match: (model: M) => boolean
provider: (provider: P) => P
}>
fallbackRule: (provider: P) => P
}

View File

@ -0,0 +1,19 @@
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const vertexAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(VERTEX_ANTHROPIC_RULES, model, provider)

View File

@ -0,0 +1,26 @@
import { getLowerBaseModelName } from '@shared/utils/naming'
import type { MinimalModel } from './types'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini']
export function isCopilotResponsesModel<M extends MinimalModel>(model: M): boolean {
const normalizedId = getLowerBaseModelName(model.id)
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target)
}

View File

@ -0,0 +1,101 @@
/**
* Provider Type Detection Utilities
*
* Functions to detect provider types based on provider configuration.
* These are pure functions that only depend on provider.type and provider.id.
*
* NOTE: These functions should match the logic in @renderer/utils/provider.ts
*/
import type { MinimalProvider } from './types'
/**
* Check if provider is Anthropic type
*/
export function isAnthropicProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'anthropic'
}
/**
* Check if provider is OpenAI Response type (openai-response)
* NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts
*/
export function isOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'openai-response'
}
/**
* Check if provider is Gemini type
*/
export function isGeminiProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'gemini'
}
/**
* Check if provider is Azure OpenAI type
*/
export function isAzureOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'azure-openai'
}
/**
* Check if provider is Vertex AI type
*/
export function isVertexProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'vertexai'
}
/**
* Check if provider is AWS Bedrock type
*/
export function isAwsBedrockProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'aws-bedrock'
}
export function isAIGatewayProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'gateway'
}
export function isOllamaProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'ollama'
}
/**
* Check if Azure OpenAI provider uses responses endpoint
* Matches isAzureResponsesEndpoint in renderer/utils/provider.ts
*/
export function isAzureResponsesEndpoint<P extends MinimalProvider>(provider: P): boolean {
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
}
/**
* Check if provider is Cherry AI type
* Matches isCherryAIProvider in renderer/utils/provider.ts
*/
export function isCherryAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'cherryai'
}
/**
* Check if provider is Perplexity type
* Matches isPerplexityProvider in renderer/utils/provider.ts
*/
export function isPerplexityProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'perplexity'
}
/**
* Check if provider is new-api type (supports multiple backends)
* Matches isNewApiProvider in renderer/utils/provider.ts
*/
export function isNewApiProvider<P extends MinimalProvider>(provider: P): boolean {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string)
}
/**
* Check if provider is OpenAI compatible
* Matches isOpenAICompatibleProvider in renderer/utils/provider.ts
*/
export function isOpenAICompatibleProvider<P extends MinimalProvider>(provider: P): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}

View File

@ -0,0 +1,141 @@
/**
* Provider API Host Formatting
*
* Utilities for formatting provider API hosts to work with AI SDK.
* These handle the differences between how Cherry Studio stores API hosts
* and how AI SDK expects them.
*/
import {
formatApiHost,
formatAzureOpenAIApiHost,
formatOllamaApiHost,
formatVertexApiHost,
isWithTrailingSharp,
routeToEndpoint,
withoutTrailingSlash
} from '../utils/url'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isOllamaProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* Interface for environment-specific implementations
* Renderer and Main process can provide their own implementations
*/
export interface ProviderFormatContext {
vertex: {
project: string
location: string
}
}
/**
* Default Azure OpenAI API host formatter
*/
export function defaultFormatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// AI SDK will add /v1
return formatApiHost(normalizedHost + '/openai', false)
}
/**
* Format provider API host for AI SDK
*
* This function normalizes the apiHost to work with AI SDK.
* Different providers have different requirements:
* - Most providers: add /v1 suffix
* - Gemini: add /v1beta suffix
* - Some providers: no suffix needed
*
* @param provider - The provider to format
* @param context - Optional context with environment-specific implementations
* @returns Provider with formatted apiHost (and anthropicApiHost if applicable)
*/
export function formatProviderApiHost<T extends MinimalProvider>(provider: T, context: ProviderFormatContext): T {
const formatted = { ...provider }
const appendApiVersion = !isWithTrailingSharp(provider.apiHost)
// Format anthropicApiHost if present
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion)
}
// Format based on provider type
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL
formatted.apiHost = formatApiHost(baseHost, appendApiVersion)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isOllamaProvider(formatted)) {
formatted.apiHost = formatOllamaApiHost(formatted.apiHost)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion)
}
return formatted
}
/**
* Get the base URL for AI SDK from a formatted provider
*
* This extracts the baseURL that AI SDK expects, handling
* the '#' endpoint routing format if present.
*
* @param formattedApiHost - The formatted apiHost (after formatProviderApiHost)
* @returns The baseURL for AI SDK
*/
export function getBaseUrlForAiSdk(formattedApiHost: string): string {
const { baseURL } = routeToEndpoint(formattedApiHost)
return baseURL
}
/**
* Get rotated API key from comma-separated keys
*
* This is the interface for API key rotation. The actual implementation
* depends on the environment (renderer uses window.keyv, main uses its own storage).
*/
export interface ApiKeyRotator {
/**
* Get the next API key in rotation
* @param providerId - The provider ID for tracking rotation
* @param keys - Comma-separated API keys
* @returns The next API key to use
*/
getRotatedKey(providerId: string, keys: string): string
}
/**
* Simple API key rotator that always returns the first key
* Use this when rotation is not needed
*/
export const simpleKeyRotator: ApiKeyRotator = {
getRotatedKey(_providerId: string, keys: string): string {
const keyList = keys.split(',').map((k) => k.trim())
return keyList[0] || keys
}
}

View File

@ -0,0 +1,49 @@
/**
* Shared Provider Utilities
*
* This module exports utilities for working with AI providers
* that can be shared between main process and renderer process.
*/
// Type definitions
export type { MinimalProvider, ProviderType, SystemProviderId } from './types'
export { SystemProviderIds } from './types'
// Provider type detection
export {
isAIGatewayProvider,
isAnthropicProvider,
isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOllamaProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
// API host formatting
export type { ApiKeyRotator, ProviderFormatContext } from './format'
export {
defaultFormatAzureOpenAIApiHost,
formatProviderApiHost,
getBaseUrlForAiSdk,
simpleKeyRotator
} from './format'
// Provider ID mapping
export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping'
// AI SDK configuration
export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config'
export { providerToAiSdkConfig } from './sdk-config'
// Provider resolution
export { resolveActualProvider } from './resolve'
// Provider initialization
export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization'

View File

@ -0,0 +1,114 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
type ProviderInitializationLogger = {
warn?: (message: string) => void
error?: (message: string, error: Error) => void
}
export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'gateway',
name: 'Vercel AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['ai-gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
},
{
id: 'ollama',
name: 'Ollama',
import: () => import('ollama-ai-provider-v2'),
creatorFunctionName: 'createOllama',
supportsImageGeneration: false
}
] as const
export function initializeSharedProviders(logger?: ProviderInitializationLogger): void {
try {
const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS)
if (successCount < SHARED_PROVIDER_CONFIGS.length) {
logger?.warn?.('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger?.error?.('Failed to initialize shared providers', error as Error)
}
}

View File

@ -0,0 +1,95 @@
/**
* Provider ID Mapping
*
* Maps Cherry Studio provider IDs/types to AI SDK provider IDs.
* This logic should match @renderer/aiCore/provider/factory.ts
*/
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection'
import type { MinimalProvider } from './types'
/**
* Static mapping from Cherry Studio provider ID/type to AI SDK provider ID
* Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts
*/
export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* Try to resolve a provider identifier to an AI SDK provider ID
* Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts
*
* @param identifier - The provider ID or type to resolve
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The resolved AI SDK provider ID, or null if not found
*/
export function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查AiCore是否支持包括别名支持
if (hasProviderConfigByAlias(identifier)) {
// 解析为真实的Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* Get the AI SDK Provider ID for a Cherry Studio provider
* Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts
*
* Logic:
* 1. Handle Azure OpenAI specially (check responses endpoint)
* 2. Try to resolve from provider.id
* 3. Try to resolve from provider.type (but not for generic 'openai' type)
* 4. Check for OpenAI API host pattern
* 5. Fallback to provider's own ID
*
* @param provider - The Cherry Studio provider
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The AI SDK provider ID to use
*/
export function getAiSdkProviderId(provider: MinimalProvider): ProviderId {
// 1. Handle Azure OpenAI specially - check this FIRST before other resolution
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
return 'azure'
}
// 2. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) {
return resolvedFromId
}
// 3. 尝试解析provider.type
// 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
// 4. Check for OpenAI API host pattern
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 5. 最后的fallback使用provider本身的id
return provider.id
}

View File

@ -0,0 +1,43 @@
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { isAzureOpenAIProvider, isNewApiProvider } from './detection'
import type { MinimalModel, MinimalProvider } from './types'
export interface ResolveActualProviderOptions<P extends MinimalProvider> {
isSystemProvider?: (provider: P) => boolean
}
const defaultIsSystemProvider = <P extends MinimalProvider>(provider: P): boolean => {
if ('isSystem' in provider) {
return Boolean((provider as unknown as { isSystem?: boolean }).isSystem)
}
return false
}
export function resolveActualProvider<M extends MinimalModel, P extends MinimalProvider>(
provider: P,
model: M,
options: ResolveActualProviderOptions<P> = {}
): P {
let resolvedProvider = provider
if (isNewApiProvider(resolvedProvider)) {
resolvedProvider = newApiResolverCreator(model, resolvedProvider)
}
const isSystemProvider = options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider)
if (isSystemProvider && resolvedProvider.id === 'aihubmix') {
resolvedProvider = aihubmixProviderCreator(model, resolvedProvider)
}
if (isSystemProvider && resolvedProvider.id === 'vertexai') {
resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider)
}
if (isAzureOpenAIProvider(resolvedProvider)) {
resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider)
}
return resolvedProvider
}

View File

@ -0,0 +1,289 @@
/**
* AI SDK Configuration
*
* Shared utilities for converting Cherry Studio Provider to AI SDK configuration.
* Environment-specific logic (renderer/main) is injected via context interfaces.
*/
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'
import { routeToEndpoint } from '../utils/url'
import { isAzureOpenAIProvider, isOllamaProvider } from './detection'
import { getAiSdkProviderId } from './mapping'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* AI SDK configuration result
*/
export interface AiSdkConfig {
providerId: string
options: Record<string, unknown>
}
/**
* Context for environment-specific implementations
*/
export interface AiSdkConfigContext {
/**
* Check if a model uses chat completion only (for OpenAI response mode)
* Default: returns false
*/
isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean
/**
* Check if provider supports stream options
* Default: returns true
*/
isSupportStreamOptionsProvider?: (provider: MinimalProvider) => boolean
/**
* Get includeUsage setting for stream options
* Default: returns undefined
*/
getIncludeUsageSetting?: () => boolean | undefined | Promise<boolean | undefined>
/**
* Get Copilot default headers (constants)
* Default: returns empty object
*/
getCopilotDefaultHeaders?: () => Record<string, string>
/**
* Get Copilot stored headers from state
* Default: returns empty object
*/
getCopilotStoredHeaders?: () => Record<string, string>
/**
* Get AWS Bedrock configuration
* Default: returns undefined (not configured)
*/
getAwsBedrockConfig?: () =>
| {
authType: 'apiKey' | 'iam'
region: string
apiKey?: string
accessKeyId?: string
secretAccessKey?: string
}
| undefined
/**
* Get Vertex AI configuration
* Default: returns undefined (not configured)
*/
getVertexConfig?: (provider: MinimalProvider) =>
| {
project: string
location: string
googleCredentials: {
privateKey: string
clientEmail: string
}
}
| undefined
/**
* Get endpoint type for cherryin provider
*/
getEndpointType?: (modelId: string) => string | undefined
/**
* Custom fetch implementation
* Main process: use Electron net.fetch
* Renderer process: use browser fetch (default)
*/
fetch?: typeof globalThis.fetch
/**
* Get CherryAI signed fetch wrapper
* Returns a fetch function that adds signature headers to requests
*/
getCherryAISignedFetch?: () => typeof globalThis.fetch
}
/**
* Convert Cherry Studio Provider to AI SDK configuration
*
* @param provider - The formatted provider (after formatProviderApiHost)
* @param modelId - The model ID to use
* @param context - Environment-specific implementations
* @returns AI SDK configuration
*/
export function providerToAiSdkConfig(
provider: MinimalProvider,
modelId: string,
context: AiSdkConfigContext = {}
): AiSdkConfig {
const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false)
const isSupportStreamOptionsProvider = context.isSupportStreamOptionsProvider || (() => true)
const getIncludeUsageSetting = context.getIncludeUsageSetting || (() => undefined)
const aiSdkProviderId = getAiSdkProviderId(provider)
// Build base config
const { baseURL, endpoint } = routeToEndpoint(provider.apiHost)
const baseConfig = {
baseURL,
apiKey: provider.apiKey
}
let includeUsage: boolean | undefined = undefined
if (isSupportStreamOptionsProvider(provider)) {
const setting = getIncludeUsageSetting()
includeUsage = setting instanceof Promise ? undefined : setting
}
// Handle Copilot specially
if (provider.id === SystemProviderIds.copilot) {
const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
const copilotExtraOptions: Record<string, unknown> = {
headers: {
...defaultHeaders,
...storedHeaders,
...provider.extra_headers
},
name: provider.id,
includeUsage
}
if (context.fetch) {
copilotExtraOptions.fetch = context.fetch
}
const options = ProviderConfigFactory.fromProvider(
'github-copilot-openai-compatible',
baseConfig,
copilotExtraOptions
)
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
if (isOllamaProvider(provider)) {
return {
providerId: 'ollama',
options: {
...baseConfig,
headers: {
...provider.extra_headers,
Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined
}
}
}
}
// Build extra options
const extraOptions: Record<string, unknown> = {}
if (endpoint) {
extraOptions.endpoint = endpoint
}
// Handle OpenAI mode
if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// Add extra headers
const headers: Record<string, string | undefined> = {
...defaultAppHeaders(),
...provider.extra_headers
}
if (aiSdkProviderId === 'openai') {
headers['X-Api-Key'] = baseConfig.apiKey
}
extraOptions.headers = headers
// Handle Azure modes
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
if (isAzureOpenAIProvider(provider)) {
const apiVersion = provider.apiVersion?.trim()
if (apiVersion) {
extraOptions.apiVersion = apiVersion
if (!['preview', 'v1'].includes(apiVersion)) {
extraOptions.useDeploymentBasedUrls = true
}
}
}
// Handle AWS Bedrock
if (aiSdkProviderId === 'bedrock') {
const bedrockConfig = context.getAwsBedrockConfig?.()
if (bedrockConfig) {
extraOptions.region = bedrockConfig.region
if (bedrockConfig.authType === 'apiKey') {
extraOptions.apiKey = bedrockConfig.apiKey
} else {
extraOptions.accessKeyId = bedrockConfig.accessKeyId
extraOptions.secretAccessKey = bedrockConfig.secretAccessKey
}
}
}
// Handle Vertex AI
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
const vertexConfig = context.getVertexConfig?.(provider)
if (vertexConfig) {
extraOptions.project = vertexConfig.project
extraOptions.location = vertexConfig.location
extraOptions.googleCredentials = {
...vertexConfig.googleCredentials,
privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
}
// Handle cherryin endpoint type
if (aiSdkProviderId === 'cherryin') {
const endpointType = context.getEndpointType?.(modelId)
if (endpointType) {
extraOptions.endpointType = endpointType
}
}
// Handle cherryai signed fetch
if (provider.id === 'cherryai') {
const signedFetch = context.getCherryAISignedFetch?.()
if (signedFetch) {
extraOptions.fetch = signedFetch
}
} else if (context.fetch) {
extraOptions.fetch = context.fetch
}
// Check if AI SDK supports this provider natively
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// Fallback to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: provider.id,
...extraOptions,
includeUsage
}
}
}

View File

@ -0,0 +1,177 @@
import * as z from 'zod'
export const ProviderTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'gateway',
'ollama'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
/**
* Minimal provider interface for shared utilities
* This is the subset of Provider that shared code needs
*/
export type MinimalProvider = {
id: string
type: ProviderType
apiKey: string
apiHost: string
anthropicApiHost?: string
apiVersion?: string
extra_headers?: Record<string, string>
}
/**
* Minimal model interface for shared utilities
* This is the subset of Model that shared code needs
*/
export type MinimalModel = {
id: string
endpoint_type?: string
}
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'gateway',
'cerebras',
'mimo'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
gateway: 'gateway',
cerebras: 'cerebras',
mimo: 'mimo'
} as const satisfies Record<SystemProviderId, SystemProviderId>
export type SystemProviderIdTypeMap = typeof SystemProviderIds

View File

@ -0,0 +1,3 @@
export { defaultAppHeaders } from './headers'
export { getBaseModelName, getLowerBaseModelName } from './naming'
export * from './url'

View File

@ -0,0 +1,36 @@
/**
* ID
*
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* ID
*
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
// for cherryin
if (baseModelName.endsWith('(free)')) {
return baseModelName.replace('(free)', '')
}
return baseModelName
}

View File

@ -0,0 +1,293 @@
/**
* Shared API Utilities
*
* Common utilities for API URL formatting and validation.
* Used by both main process (API Server) and renderer.
*/
import type { MinimalProvider } from '@shared/provider'
import { trim } from 'lodash'
// Supported endpoints for routing
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Removes the trailing slash from a URL string if it exists.
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Removes the trailing '#' from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing '#'
*
* @example
* ```ts
* withoutTrailingSharp('https://example.com#') // 'https://example.com'
* withoutTrailingSharp('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSharp<T extends string>(url: T): T {
return url.replace(/#$/, '') as T
}
/**
* Checks if a URL string ends with a trailing '#' character.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to check
* @returns {boolean} True if the URL ends with '#', false otherwise
*
* @example
* ```ts
* isWithTrailingSharp('https://example.com#') // true
* isWithTrailingSharp('https://example.com') // false
* ```
*/
export function isWithTrailingSharp<T extends string>(url: T): boolean {
return url.endsWith('#')
}
/**
* Matches a version segment in a path that starts with `/v<number>` and optionally
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
* of the string (useful for cases like `/v3alpha/resources`).
*/
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
/**
* Matches an API version at the end of a URL (with optional trailing slash).
* Used to detect and extract versions only from the trailing position.
*/
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
/**
* host path /v1/v2beta
*
* @param host - host path
* @returns path true false
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
try {
const url = new URL(host)
return regex.test(url.pathname)
} catch {
// 若无法作为完整 URL 解析,则当作路径直接检测
return regex.test(host)
}
}
/**
* Azure OpenAI API
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: AISDK会添加上`v1`
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(
provider: MinimalProvider,
project: string = 'test-project',
location: string = 'us-central1'
): string {
const { apiHost } = provider
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location === 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
/**
* Ollama API
*/
export function formatOllamaApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
?.replace(/\/api$/, '')
?.replace(/\/chat$/, '')
return formatApiHost(normalizedHost + '/api', false)
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host with trailing '#' removed.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
const shouldAppendApiVersion = !(normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost))
if (shouldAppendApiVersion) {
return `${normalizedHost}/${apiVersion}`
} else {
return withoutTrailingSharp(normalizedHost)
}
}
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
*
* @param apiHost - The API host string to parse
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = (apiHost || '').trim()
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// Remove trailing #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case)
return { baseURL, endpoint: endpointMatch }
}
/**
* Gets the AI SDK compatible base URL from a provider's apiHost.
*
* AI SDK expects baseURL WITH version suffix (e.g., /v1).
* This function:
* 1. Handles '#' endpoint routing format
* 2. Ensures the URL has a version suffix (adds /v1 if missing)
*
* @param apiHost - The provider's apiHost value (may or may not have /v1)
* @param apiVersion - The API version to use if missing. Defaults to 'v1'.
* @returns The baseURL suitable for AI SDK (with version suffix)
*
* @example
* getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com'
*/
export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string {
// First handle '#' endpoint routing format
const { baseURL } = routeToEndpoint(apiHost)
// If already has version, return as-is
if (hasAPIVersion(baseURL)) {
return withoutTrailingSlash(baseURL)
}
// Add version suffix
return `${withoutTrailingSlash(baseURL)}/${apiVersion}`
}
/**
* Validates an API host address.
*
* @param apiHost - The API host address to validate
* @returns true if valid URL with http/https protocol, false otherwise
*/
export function validateApiHost(apiHost: string): boolean {
if (!apiHost || !apiHost.trim()) {
return true // Allow empty
}
try {
const url = new URL(apiHost.trim())
return url.protocol === 'http:' || url.protocol === 'https:'
} catch {
return false
}
}
/**
* Extracts the trailing API version segment from a URL path.
*
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
* Only versions at the end of the path are extracted, not versions in the middle.
* The returned version string does not include leading or trailing slashes.
*
* @param {string} url - The URL string to parse.
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
*
* @example
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
* getTrailingApiVersion('https://api.example.com') // undefined
*/
export function getTrailingApiVersion(url: string): string | undefined {
const match = url.match(TRAILING_VERSION_REGEX)
if (match) {
// Extract version without leading slash and trailing slash
return match[0].replace(/^\//, '').replace(/\/$/, '')
}
return undefined
}
/**
* Removes the trailing API version segment from a URL path.
*
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
* Only versions at the end of the path are removed, not versions in the middle.
*
* @param {string} url - The URL string to process.
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
*
* @example
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
*/
export function withoutTrailingApiVersion(url: string): string {
return url.replace(TRAILING_VERSION_REGEX, '')
}

View File

@ -0,0 +1,637 @@
/**
* AI SDK to Anthropic SSE Adapter
*
* Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format.
* This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API.
*
* Anthropic SSE Event Flow:
* 1. message_start - Initial message with metadata
* 2. content_block_start - Begin a content block (text, tool_use, thinking)
* 3. content_block_delta - Incremental content updates
* 4. content_block_stop - End a content block
* 5. message_delta - Updates to overall message (stop_reason, usage)
* 6. message_stop - Stream complete
*
* @see https://docs.anthropic.com/en/api/messages-streaming
*/
import type {
ContentBlock,
InputJSONDelta,
Message,
MessageDeltaUsage,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
StopReason,
TextBlock,
TextDelta,
ThinkingBlock,
ThinkingDelta,
ToolUseBlock,
Usage
} from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger'
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
import { googleReasoningCache, openRouterReasoningCache } from '../services/reasoning-cache'
const logger = loggerService.withContext('AiSdkToAnthropicSSE')
interface ContentBlockState {
type: 'text' | 'tool_use' | 'thinking'
index: number
started: boolean
content: string
// For tool_use blocks
toolId?: string
toolName?: string
toolInput?: string
}
interface AdapterState {
messageId: string
model: string
inputTokens: number
outputTokens: number
cacheInputTokens: number
currentBlockIndex: number
blocks: Map<number, ContentBlockState>
textBlockIndex: number | null
// Track multiple thinking blocks by their reasoning ID
thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
currentThinkingId: string | null // Currently active thinking block ID
toolBlocks: Map<string, number> // toolCallId -> blockIndex
stopReason: StopReason | null
hasEmittedMessageStart: boolean
}
export type SSEEventCallback = (event: RawMessageStreamEvent) => void
export interface AiSdkToAnthropicSSEOptions {
model: string
messageId?: string
inputTokens?: number
onEvent: SSEEventCallback
}
/**
* Adapter that converts AI SDK fullStream events to Anthropic SSE events
*/
export class AiSdkToAnthropicSSE {
private state: AdapterState
private onEvent: SSEEventCallback
constructor(options: AiSdkToAnthropicSSEOptions) {
this.onEvent = options.onEvent
this.state = {
messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
model: options.model,
inputTokens: options.inputTokens || 0,
outputTokens: 0,
cacheInputTokens: 0,
currentBlockIndex: 0,
blocks: new Map(),
textBlockIndex: null,
thinkingBlocks: new Map(),
currentThinkingId: null,
toolBlocks: new Map(),
stopReason: null,
hasEmittedMessageStart: false
}
}
/**
* Process the AI SDK stream and emit Anthropic SSE events
*/
async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
const reader = fullStream.getReader()
try {
// Emit message_start at the beginning
this.emitMessageStart()
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
this.processChunk(value)
}
// Ensure all blocks are closed and emit final events
this.finalize()
} catch (error) {
await reader.cancel()
throw error
} finally {
reader.releaseLock()
}
}
/**
* Process a single AI SDK chunk and emit corresponding Anthropic events
*/
private processChunk(chunk: TextStreamPart<ToolSet>): void {
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
switch (chunk.type) {
// === Text Events ===
case 'text-start':
this.startTextBlock()
break
case 'text-delta':
this.emitTextDelta(chunk.text || '')
break
case 'text-end':
this.stopTextBlock()
break
// === Reasoning/Thinking Events ===
case 'reasoning-start': {
const reasoningId = chunk.id
this.startThinkingBlock(reasoningId)
break
}
case 'reasoning-delta': {
const reasoningId = chunk.id
this.emitThinkingDelta(chunk.text || '', reasoningId)
break
}
case 'reasoning-end': {
const reasoningId = chunk.id
this.stopThinkingBlock(reasoningId)
break
}
// === Tool Events ===
case 'tool-call':
if (googleReasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
googleReasoningCache.set(
`google-${chunk.toolName}`,
chunk.providerMetadata?.google?.thoughtSignature as string
)
}
if (
openRouterReasoningCache &&
chunk.providerMetadata?.openrouter?.reasoning_details &&
Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
) {
openRouterReasoningCache.set(
`openrouter-${chunk.toolCallId}`,
JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details))
)
}
this.handleToolCall({
type: 'tool-call',
toolCallId: chunk.toolCallId,
toolName: chunk.toolName,
args: chunk.input
})
break
case 'tool-result':
// this.handleToolResult({
// type: 'tool-result',
// toolCallId: chunk.toolCallId,
// toolName: chunk.toolName,
// args: chunk.input,
// result: chunk.output
// })
break
case 'finish-step':
if (chunk.finishReason === 'tool-calls') {
this.state.stopReason = 'tool_use'
}
break
case 'finish':
this.handleFinish(chunk)
break
case 'error':
throw chunk.error
// Ignore other event types
default:
break
}
}
private emitMessageStart(): void {
if (this.state.hasEmittedMessageStart) return
this.state.hasEmittedMessageStart = true
const usage: Usage = {
input_tokens: this.state.inputTokens,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
const message: Message = {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content: [],
model: this.state.model,
stop_reason: null,
stop_sequence: null,
usage
}
const event: RawMessageStartEvent = {
type: 'message_start',
message
}
this.onEvent(event)
}
private startTextBlock(): void {
// If we already have a text block, don't create another
if (this.state.textBlockIndex !== null) return
const index = this.state.currentBlockIndex++
this.state.textBlockIndex = index
this.state.blocks.set(index, {
type: 'text',
index,
started: true,
content: ''
})
const contentBlock: TextBlock = {
type: 'text',
text: '',
citations: null
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitTextDelta(text: string): void {
if (!text) return
// Auto-start text block if not started
if (this.state.textBlockIndex === null) {
this.startTextBlock()
}
const index = this.state.textBlockIndex!
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: TextDelta = {
type: 'text_delta',
text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopTextBlock(): void {
if (this.state.textBlockIndex === null) return
const index = this.state.textBlockIndex
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.textBlockIndex = null
}
private startThinkingBlock(reasoningId: string): void {
// Check if this thinking block already exists
if (this.state.thinkingBlocks.has(reasoningId)) return
const index = this.state.currentBlockIndex++
this.state.thinkingBlocks.set(reasoningId, index)
this.state.currentThinkingId = reasoningId
this.state.blocks.set(index, {
type: 'thinking',
index,
started: true,
content: ''
})
const contentBlock: ThinkingBlock = {
type: 'thinking',
thinking: '',
signature: ''
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitThinkingDelta(text: string, reasoningId?: string): void {
if (!text) return
// Determine which thinking block to use
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) {
// Auto-start thinking block if not started
const newId = `reasoning_${Date.now()}`
this.startThinkingBlock(newId)
return this.emitThinkingDelta(text, newId)
}
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) {
// If the block doesn't exist, create it
this.startThinkingBlock(targetId)
return this.emitThinkingDelta(text, targetId)
}
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: ThinkingDelta = {
type: 'thinking_delta',
thinking: text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopThinkingBlock(reasoningId?: string): void {
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) return
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) return
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.thinkingBlocks.delete(targetId)
// Update currentThinkingId if we just closed the current one
if (this.state.currentThinkingId === targetId) {
// Set to the most recent remaining thinking block, or null if none
const remaining = Array.from(this.state.thinkingBlocks.keys())
this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
}
}
private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
const { toolCallId, toolName, args } = chunk
// Check if we already have this tool call
if (this.state.toolBlocks.has(toolCallId)) {
return
}
const index = this.state.currentBlockIndex++
this.state.toolBlocks.set(toolCallId, index)
const inputJson = JSON.stringify(args)
this.state.blocks.set(index, {
type: 'tool_use',
index,
started: true,
content: inputJson,
toolId: toolCallId,
toolName,
toolInput: inputJson
})
// Emit content_block_start for tool_use
const contentBlock: ToolUseBlock = {
type: 'tool_use',
id: toolCallId,
name: toolName,
input: {}
}
const startEvent: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(startEvent)
// Emit the full input as a delta (Anthropic streams JSON incrementally)
const delta: InputJSONDelta = {
type: 'input_json_delta',
partial_json: inputJson
}
const deltaEvent: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(deltaEvent)
// Emit content_block_stop
const stopEvent: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(stopEvent)
// Mark that we have tool use
this.state.stopReason = 'tool_use'
}
private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
// Update usage
if (chunk.totalUsage) {
this.state.inputTokens = chunk.totalUsage.inputTokens || 0
this.state.outputTokens = chunk.totalUsage.outputTokens || 0
this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
}
// Determine finish reason
if (!this.state.stopReason) {
switch (chunk.finishReason) {
case 'stop':
this.state.stopReason = 'end_turn'
break
case 'length':
this.state.stopReason = 'max_tokens'
break
case 'tool-calls':
this.state.stopReason = 'tool_use'
break
case 'content-filter':
this.state.stopReason = 'refusal'
break
default:
this.state.stopReason = 'end_turn'
}
}
}
private finalize(): void {
// Close any open blocks
if (this.state.textBlockIndex !== null) {
this.stopTextBlock()
}
// Close all open thinking blocks
for (const reasoningId of this.state.thinkingBlocks.keys()) {
this.stopThinkingBlock(reasoningId)
}
// Emit message_delta with final stop reason and usage
const usage: MessageDeltaUsage = {
output_tokens: this.state.outputTokens,
input_tokens: this.state.inputTokens,
cache_creation_input_tokens: this.state.cacheInputTokens,
cache_read_input_tokens: null,
server_tool_use: null
}
const messageDeltaEvent: RawMessageDeltaEvent = {
type: 'message_delta',
delta: {
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null
},
usage
}
this.onEvent(messageDeltaEvent)
// Emit message_stop
const messageStopEvent: RawMessageStopEvent = {
type: 'message_stop'
}
this.onEvent(messageStopEvent)
}
/**
* Set input token count (typically from prompt)
*/
setInputTokens(count: number): void {
this.state.inputTokens = count
}
/**
* Get the current message ID
*/
getMessageId(): string {
return this.state.messageId
}
/**
* Build a complete Message object for non-streaming responses
*/
buildNonStreamingResponse(): Message {
const content: ContentBlock[] = []
// Collect all content blocks in order
const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)
for (const block of sortedBlocks) {
switch (block.type) {
case 'text':
content.push({
type: 'text',
text: block.content,
citations: null
} as TextBlock)
break
case 'thinking':
content.push({
type: 'thinking',
thinking: block.content
} as ThinkingBlock)
break
case 'tool_use':
content.push({
type: 'tool_use',
id: block.toolId!,
name: block.toolName!,
input: JSON.parse(block.toolInput || '{}')
} as ToolUseBlock)
break
}
}
return {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content,
model: this.state.model,
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null,
usage: {
input_tokens: this.state.inputTokens,
output_tokens: this.state.outputTokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
}
}
}
/**
* Format an Anthropic SSE event for HTTP streaming
*/
export function formatSSEEvent(event: RawMessageStreamEvent): string {
return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}
/**
* Create a done marker for SSE stream
*/
export function formatSSEDone(): string {
return 'data: [DONE]\n\n'
}
export default AiSdkToAnthropicSSE

View File

@ -0,0 +1,536 @@
import type { RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages'
import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'
import { describe, expect, it, vi } from 'vitest'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '../AiSdkToAnthropicSSE'
const createTextDelta = (text: string, id = 'text_0'): TextStreamPart<ToolSet> => ({
type: 'text-delta',
id,
text
})
const createTextStart = (id = 'text_0'): TextStreamPart<ToolSet> => ({
type: 'text-start',
id
})
const createTextEnd = (id = 'text_0'): TextStreamPart<ToolSet> => ({
type: 'text-end',
id
})
const createFinish = (
finishReason: FinishReason | undefined = 'stop',
totalUsage?: Partial<LanguageModelUsage>
): TextStreamPart<ToolSet> => {
const defaultUsage: LanguageModelUsage = {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0
}
const event: TextStreamPart<ToolSet> = {
type: 'finish',
finishReason: finishReason || 'stop',
totalUsage: { ...defaultUsage, ...totalUsage }
}
return event
}
// Helper to create stream
function createMockStream(events: readonly TextStreamPart<ToolSet>[]) {
return new ReadableStream<TextStreamPart<ToolSet>>({
start(controller) {
for (const event of events) {
controller.enqueue(event)
}
controller.close()
}
})
}
describe('AiSdkToAnthropicSSE', () => {
describe('Text Processing', () => {
it('should emit message_start and process text-delta events', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
// Create a mock stream with text events
const stream = createMockStream([createTextDelta('Hello'), createTextDelta(' world'), createFinish('stop')])
await adapter.processStream(stream)
// Verify message_start
expect(events[0]).toMatchObject({
type: 'message_start',
message: {
role: 'assistant',
model: 'test:model'
}
})
// Verify content_block_start for text
expect(events[1]).toMatchObject({
type: 'content_block_start',
content_block: { type: 'text' }
})
// Verify text deltas
expect(events[2]).toMatchObject({
type: 'content_block_delta',
delta: { type: 'text_delta', text: 'Hello' }
})
expect(events[3]).toMatchObject({
type: 'content_block_delta',
delta: { type: 'text_delta', text: ' world' }
})
// Verify content_block_stop
expect(events[4]).toMatchObject({
type: 'content_block_stop'
})
// Verify message_delta with stop_reason
expect(events[5]).toMatchObject({
type: 'message_delta',
delta: { stop_reason: 'end_turn' }
})
// Verify message_stop
expect(events[6]).toMatchObject({
type: 'message_stop'
})
})
it('should handle text-start and text-end events', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = createMockStream([
createTextStart(),
createTextDelta('Test'),
createTextEnd(),
createFinish('stop')
])
await adapter.processStream(stream)
// Should have content_block_start, delta, and content_block_stop
const blockEvents = events.filter((e) => e.type.startsWith('content_block'))
expect(blockEvents.length).toBeGreaterThanOrEqual(3)
})
it('should auto-start text block if not explicitly started', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = createMockStream([createTextDelta('Auto-started'), createFinish('stop')])
await adapter.processStream(stream)
// Should automatically emit content_block_start
expect(events.some((e) => e.type === 'content_block_start')).toBe(true)
})
})
describe('Tool Call Processing', () => {
it('should emit tool_use block for tool-call events', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = createMockStream([
{
type: 'tool-call',
toolCallId: 'call_123',
toolName: 'get_weather',
input: { location: 'SF' }
},
createFinish('tool-calls')
])
await adapter.processStream(stream)
// Find tool_use block events
const blockStart = events.find((e) => {
if (e.type === 'content_block_start') {
return e.content_block.type === 'tool_use'
}
return false
})
expect(blockStart).toBeDefined()
if (blockStart && blockStart.type === 'content_block_start') {
expect(blockStart.content_block).toMatchObject({
type: 'tool_use',
id: 'call_123',
name: 'get_weather'
})
}
// Should emit input_json_delta
const delta = events.find((e) => {
if (e.type === 'content_block_delta') {
return e.delta.type === 'input_json_delta'
}
return false
})
expect(delta).toBeDefined()
// Should have stop_reason as tool_use
const messageDelta = events.find((e) => e.type === 'message_delta')
if (messageDelta && messageDelta.type === 'message_delta') {
expect(messageDelta.delta.stop_reason).toBe('tool_use')
}
})
it('should not create duplicate tool blocks', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const toolCallEvent: TextStreamPart<ToolSet> = {
type: 'tool-call',
toolCallId: 'call_123',
toolName: 'test_tool',
input: {}
}
const stream = createMockStream([toolCallEvent, toolCallEvent, createFinish()])
await adapter.processStream(stream)
// Should only have one tool_use block
const toolBlocks = events.filter((e) => {
if (e.type === 'content_block_start') {
return e.content_block.type === 'tool_use'
}
return false
})
expect(toolBlocks.length).toBe(1)
})
})
describe('Reasoning/Thinking Processing', () => {
it('should emit thinking block for reasoning events', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = createMockStream([
{ type: 'reasoning-start', id: 'reason_1' },
{ type: 'reasoning-delta', id: 'reason_1', text: 'Thinking...' },
{ type: 'reasoning-end', id: 'reason_1' },
createFinish()
])
await adapter.processStream(stream)
// Find thinking block events
const blockStart = events.find((e) => {
if (e.type === 'content_block_start') {
return e.content_block.type === 'thinking'
}
return false
})
expect(blockStart).toBeDefined()
// Should emit thinking_delta
const delta = events.find((e) => {
if (e.type === 'content_block_delta') {
return e.delta.type === 'thinking_delta'
}
return false
})
expect(delta).toBeDefined()
if (delta && delta.type === 'content_block_delta' && delta.delta.type === 'thinking_delta') {
expect(delta.delta.thinking).toBe('Thinking...')
}
})
it('should handle multiple thinking blocks', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = createMockStream([
{ type: 'reasoning-start', id: 'reason_1' },
{ type: 'reasoning-delta', id: 'reason_1', text: 'First thought' },
{ type: 'reasoning-start', id: 'reason_2' },
{ type: 'reasoning-delta', id: 'reason_2', text: 'Second thought' },
{ type: 'reasoning-end', id: 'reason_1' },
{ type: 'reasoning-end', id: 'reason_2' },
createFinish()
])
await adapter.processStream(stream)
// Should have two thinking blocks
const thinkingBlocks = events.filter((e) => {
if (e.type === 'content_block_start') {
return e.content_block.type === 'thinking'
}
return false
})
expect(thinkingBlocks.length).toBe(2)
})
})
describe('Finish Reasons', () => {
it('should map finish reasons correctly', async () => {
const testCases: Array<{
aiSdkReason: FinishReason
expectedReason: string
}> = [
{ aiSdkReason: 'stop', expectedReason: 'end_turn' },
{ aiSdkReason: 'length', expectedReason: 'max_tokens' },
{ aiSdkReason: 'tool-calls', expectedReason: 'tool_use' },
{ aiSdkReason: 'content-filter', expectedReason: 'refusal' }
]
for (const { aiSdkReason, expectedReason } of testCases) {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = createMockStream([createFinish(aiSdkReason)])
await adapter.processStream(stream)
const messageDelta = events.find((e) => e.type === 'message_delta')
if (messageDelta && messageDelta.type === 'message_delta') {
expect(messageDelta.delta.stop_reason).toBe(expectedReason)
}
}
})
})
describe('Usage Tracking', () => {
it('should track token usage', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
inputTokens: 100,
onEvent: (event) => events.push(event)
})
const stream = createMockStream([
createTextDelta('Hello'),
createFinish('stop', {
inputTokens: 100,
outputTokens: 50,
cachedInputTokens: 20
})
])
await adapter.processStream(stream)
const messageDelta = events.find((e) => e.type === 'message_delta')
if (messageDelta && messageDelta.type === 'message_delta') {
expect(messageDelta.usage).toMatchObject({
input_tokens: 100,
output_tokens: 50,
cache_creation_input_tokens: 20
})
}
})
})
describe('Non-Streaming Response', () => {
it('should build complete message for non-streaming', async () => {
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: vi.fn()
})
const stream = createMockStream([
createTextDelta('Hello world'),
{
type: 'tool-call',
toolCallId: 'call_1',
toolName: 'test',
input: { arg: 'value' }
},
createFinish('tool-calls', { inputTokens: 10, outputTokens: 20 })
])
await adapter.processStream(stream)
const response = adapter.buildNonStreamingResponse()
expect(response).toMatchObject({
type: 'message',
role: 'assistant',
model: 'test:model',
stop_reason: 'tool_use'
})
expect(response.content).toHaveLength(2)
expect(response.content[0]).toMatchObject({
type: 'text',
text: 'Hello world'
})
expect(response.content[1]).toMatchObject({
type: 'tool_use',
id: 'call_1',
name: 'test',
input: { arg: 'value' }
})
expect(response.usage).toMatchObject({
input_tokens: 10,
output_tokens: 20
})
})
})
describe('Error Handling', () => {
it('should throw on error events', async () => {
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: vi.fn()
})
const testError = new Error('Test error')
const stream = createMockStream([{ type: 'error', error: testError }])
await expect(adapter.processStream(stream)).rejects.toThrow('Test error')
})
})
describe('Edge Cases', () => {
it('should handle empty stream', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = new ReadableStream<TextStreamPart<ToolSet>>({
start(controller) {
controller.close()
}
})
await adapter.processStream(stream)
// Should still emit message_start, message_delta, and message_stop
expect(events.some((e) => e.type === 'message_start')).toBe(true)
expect(events.some((e) => e.type === 'message_delta')).toBe(true)
expect(events.some((e) => e.type === 'message_stop')).toBe(true)
})
it('should handle empty text deltas', async () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
const stream = createMockStream([createTextDelta(''), createTextDelta(''), createFinish()])
await adapter.processStream(stream)
// Should not emit deltas for empty text
const deltas = events.filter((e) => e.type === 'content_block_delta')
expect(deltas.length).toBe(0)
})
})
describe('Utility Functions', () => {
it('should format SSE events correctly', () => {
const event: RawMessageStreamEvent = {
type: 'message_start',
message: {
id: 'msg_123',
type: 'message',
role: 'assistant',
content: [],
model: 'test',
stop_reason: null,
stop_sequence: null,
usage: {
input_tokens: 10,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
}
}
const formatted = formatSSEEvent(event)
expect(formatted).toContain('event: message_start')
expect(formatted).toContain('data: ')
expect(formatted).toContain('"type":"message_start"')
expect(formatted.endsWith('\n\n')).toBe(true)
})
it('should format SSE done marker correctly', () => {
const done = formatSSEDone()
expect(done).toBe('data: [DONE]\n\n')
})
})
describe('Message ID', () => {
it('should use provided message ID', () => {
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
messageId: 'custom_msg_123',
onEvent: vi.fn()
})
expect(adapter.getMessageId()).toBe('custom_msg_123')
})
it('should generate message ID if not provided', () => {
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: vi.fn()
})
const messageId = adapter.getMessageId()
expect(messageId).toMatch(/^msg_/)
})
})
describe('Input Tokens', () => {
it('should allow setting input tokens', () => {
const events: RawMessageStreamEvent[] = []
const adapter = new AiSdkToAnthropicSSE({
model: 'test:model',
onEvent: (event) => events.push(event)
})
adapter.setInputTokens(500)
const stream = createMockStream([createFinish()])
return adapter.processStream(stream).then(() => {
const messageStart = events.find((e) => e.type === 'message_start')
if (messageStart && messageStart.type === 'message_start') {
expect(messageStart.message.usage.input_tokens).toBe(500)
}
})
})
})
})

View File

@ -0,0 +1,13 @@
/**
* Shared Adapters
*
* This module exports adapters for converting between different AI API formats.
*/
export {
AiSdkToAnthropicSSE,
type AiSdkToAnthropicSSEOptions,
formatSSEDone,
formatSSEEvent,
type SSEEventCallback
} from './AiSdkToAnthropicSSE'

View File

@ -0,0 +1,95 @@
import * as z from 'zod/v4'
enum ReasoningFormat {
Unknown = 'unknown',
OpenAIResponsesV1 = 'openai-responses-v1',
XAIResponsesV1 = 'xai-responses-v1',
AnthropicClaudeV1 = 'anthropic-claude-v1',
GoogleGeminiV1 = 'google-gemini-v1'
}
// Anthropic Claude was the first reasoning that we're
// passing back and forth
export const DEFAULT_REASONING_FORMAT = ReasoningFormat.AnthropicClaudeV1
function isDefinedOrNotNull<T>(value: T | null | undefined): value is T {
return value !== null && value !== undefined
}
export enum ReasoningDetailType {
Summary = 'reasoning.summary',
Encrypted = 'reasoning.encrypted',
Text = 'reasoning.text'
}
export const CommonReasoningDetailSchema = z
.object({
id: z.string().nullish(),
format: z.enum(ReasoningFormat).nullish(),
index: z.number().optional()
})
.loose()
export const ReasoningDetailSummarySchema = z
.object({
type: z.literal(ReasoningDetailType.Summary),
summary: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailSummary = z.infer<typeof ReasoningDetailSummarySchema>
export const ReasoningDetailEncryptedSchema = z
.object({
type: z.literal(ReasoningDetailType.Encrypted),
data: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailEncrypted = z.infer<typeof ReasoningDetailEncryptedSchema>
export const ReasoningDetailTextSchema = z
.object({
type: z.literal(ReasoningDetailType.Text),
text: z.string().nullish(),
signature: z.string().nullish()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailText = z.infer<typeof ReasoningDetailTextSchema>
export const ReasoningDetailUnionSchema = z.union([
ReasoningDetailSummarySchema,
ReasoningDetailEncryptedSchema,
ReasoningDetailTextSchema
])
export type ReasoningDetailUnion = z.infer<typeof ReasoningDetailUnionSchema>
const ReasoningDetailsWithUnknownSchema = z.union([ReasoningDetailUnionSchema, z.unknown().transform(() => null)])
export const ReasoningDetailArraySchema = z
.array(ReasoningDetailsWithUnknownSchema)
.transform((d) => d.filter((d): d is ReasoningDetailUnion => !!d))
export const OutputUnionToReasoningDetailsSchema = z.union([
z
.object({
delta: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.delta.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
message: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.message.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
text: z.string(),
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
.transform((data) => data.reasoning_details.filter(isDefinedOrNotNull))
])

View File

@ -0,0 +1,393 @@
import { describe, expect, it } from 'vitest'
import { estimateTokenCount } from '../messages'
describe('estimateTokenCount', () => {
describe('Text Content', () => {
it('should estimate tokens for simple string content', () => {
const input = {
messages: [
{
role: 'user' as const,
content: 'Hello, world!'
}
]
}
const tokens = estimateTokenCount(input)
// Should include text tokens + role overhead (3)
expect(tokens).toBeGreaterThan(3)
expect(tokens).toBeLessThan(20)
})
it('should estimate tokens for multiple messages', () => {
const input = {
messages: [
{ role: 'user' as const, content: 'First message' },
{ role: 'assistant' as const, content: 'Second message' },
{ role: 'user' as const, content: 'Third message' }
]
}
const tokens = estimateTokenCount(input)
// Should include text tokens + role overhead (3 per message = 9)
expect(tokens).toBeGreaterThan(9)
})
it('should estimate tokens for text content blocks', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'Hello' },
{ type: 'text' as const, text: 'World' }
]
}
]
}
const tokens = estimateTokenCount(input)
expect(tokens).toBeGreaterThan(3)
})
it('should handle empty messages array', () => {
const input = {
messages: []
}
const tokens = estimateTokenCount(input)
expect(tokens).toBe(0)
})
it('should handle messages with empty content', () => {
const input = {
messages: [{ role: 'user' as const, content: '' }]
}
const tokens = estimateTokenCount(input)
// Should only have role overhead (3)
expect(tokens).toBe(3)
})
})
describe('System Messages', () => {
it('should estimate tokens for string system message', () => {
const input = {
messages: [{ role: 'user' as const, content: 'Hello' }],
system: 'You are a helpful assistant.'
}
const tokens = estimateTokenCount(input)
// Should include system tokens + message tokens + role overhead
expect(tokens).toBeGreaterThan(3)
})
it('should estimate tokens for system content blocks', () => {
const input = {
messages: [{ role: 'user' as const, content: 'Hello' }],
system: [
{ type: 'text' as const, text: 'System instruction 1' },
{ type: 'text' as const, text: 'System instruction 2' }
]
}
const tokens = estimateTokenCount(input)
expect(tokens).toBeGreaterThan(3)
})
})
describe('Image Content', () => {
it('should estimate tokens for base64 images', () => {
// Create a fake base64 string (400 characters = ~300 bytes when decoded)
const fakeBase64 = 'A'.repeat(400)
const input = {
messages: [
{
role: 'user' as const,
content: [
{
type: 'image' as const,
source: {
type: 'base64' as const,
media_type: 'image/png' as const,
data: fakeBase64
}
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should estimate based on data size: 400 * 0.75 / 100 = 3 tokens + role overhead (3)
expect(tokens).toBeGreaterThan(3)
expect(tokens).toBeLessThan(10)
})
it('should estimate tokens for URL images', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [
{
type: 'image' as const,
source: {
type: 'url' as const,
url: 'https://example.com/image.png'
}
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should use default estimate: 1000 + role overhead (3)
expect(tokens).toBe(1003)
})
it('should estimate tokens for mixed text and image content', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
source: {
type: 'url' as const,
url: 'https://example.com/image.png'
}
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should include text tokens + 1000 (image) + role overhead (3)
expect(tokens).toBeGreaterThan(1003)
})
})
describe('Tool Content', () => {
it('should estimate tokens for tool_use blocks', () => {
const input = {
messages: [
{
role: 'assistant' as const,
content: [
{
type: 'tool_use' as const,
id: 'tool_123',
name: 'get_weather',
input: { location: 'San Francisco', unit: 'celsius' }
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should include: tool name tokens + input JSON tokens + 10 (overhead) + 3 (role)
expect(tokens).toBeGreaterThan(13)
})
it('should estimate tokens for tool_result blocks with string content', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [
{
type: 'tool_result' as const,
tool_use_id: 'tool_123',
content: 'The weather in San Francisco is 18°C and sunny.'
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should include: content tokens + 10 (overhead) + 3 (role)
expect(tokens).toBeGreaterThan(13)
})
it('should estimate tokens for tool_result blocks with array content', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [
{
type: 'tool_result' as const,
tool_use_id: 'tool_123',
content: [
{ type: 'text' as const, text: 'Result 1' },
{ type: 'text' as const, text: 'Result 2' }
]
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should include: text tokens + 10 (overhead) + 3 (role)
expect(tokens).toBeGreaterThan(13)
})
it('should handle tool_use without input', () => {
const input = {
messages: [
{
role: 'assistant' as const,
content: [
{
type: 'tool_use' as const,
id: 'tool_123',
name: 'no_input_tool',
input: {}
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should include: tool name tokens + 10 (overhead) + 3 (role)
expect(tokens).toBeGreaterThan(13)
})
})
describe('Complex Scenarios', () => {
it('should estimate tokens for multi-turn conversation with various content types', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'Analyze this image' },
{
type: 'image' as const,
source: {
type: 'url' as const,
url: 'https://example.com/chart.png'
}
}
]
},
{
role: 'assistant' as const,
content: [
{
type: 'tool_use' as const,
id: 'tool_1',
name: 'analyze_image',
input: { url: 'https://example.com/chart.png' }
}
]
},
{
role: 'user' as const,
content: [
{
type: 'tool_result' as const,
tool_use_id: 'tool_1',
content: 'The chart shows sales data for Q4 2024.'
}
]
},
{
role: 'assistant' as const,
content: 'Based on the analysis, the sales trend is positive.'
}
],
system: 'You are a data analyst assistant.'
}
const tokens = estimateTokenCount(input)
// Should include:
// - System message tokens
// - Message 1: text + image (1000) + 3
// - Message 2: tool_use + 10 + 3
// - Message 3: tool_result + 10 + 3
// - Message 4: text + 3
expect(tokens).toBeGreaterThan(1032) // At least 1000 (image) + 32 (overhead)
})
it('should handle very long text content', () => {
const longText = 'word '.repeat(1000) // ~5000 characters
const input = {
messages: [{ role: 'user' as const, content: longText }]
}
const tokens = estimateTokenCount(input)
// Should estimate based on text length using tokenx
expect(tokens).toBeGreaterThan(1000)
})
it('should handle multiple images in single message', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [
{
type: 'image' as const,
source: { type: 'url' as const, url: 'https://example.com/1.png' }
},
{
type: 'image' as const,
source: { type: 'url' as const, url: 'https://example.com/2.png' }
},
{
type: 'image' as const,
source: { type: 'url' as const, url: 'https://example.com/3.png' }
}
]
}
]
}
const tokens = estimateTokenCount(input)
// Should estimate: 3 * 1000 (images) + 3 (role)
expect(tokens).toBe(3003)
})
})
describe('Edge Cases', () => {
it('should handle undefined system message', () => {
const input = {
messages: [{ role: 'user' as const, content: 'Hello' }],
system: undefined
}
const tokens = estimateTokenCount(input)
expect(tokens).toBeGreaterThan(0)
})
it('should handle empty system message', () => {
const input = {
messages: [{ role: 'user' as const, content: 'Hello' }],
system: ''
}
const tokens = estimateTokenCount(input)
expect(tokens).toBeGreaterThan(0)
})
it('should handle content blocks with missing text', () => {
const input = {
messages: [
{
role: 'user' as const,
content: [{ type: 'text' as const, text: undefined as any }]
}
]
}
const tokens = estimateTokenCount(input)
// Should only have role overhead
expect(tokens).toBe(3)
})
it('should handle empty content array', () => {
const input = {
messages: [
{
role: 'user' as const,
content: []
}
]
}
const tokens = estimateTokenCount(input)
// Should only have role overhead
expect(tokens).toBe(3)
})
})
})

View File

@ -1,17 +1,129 @@
import type { MessageCreateParams } from '@anthropic-ai/sdk/resources' import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/ai-sdk-middlewares'
import { getAiSdkProviderId } from '@shared/provider'
import type { Provider } from '@types' import type { Provider } from '@types'
import type { Request, Response } from 'express' import type { Request, Response } from 'express'
import express from 'express' import express from 'express'
import { approximateTokenSize } from 'tokenx'
import { messagesService } from '../services/messages' import { messagesService } from '../services/messages'
import { getProviderById, validateModelId } from '../utils' import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'
/**
* Check if a specific model on a provider should use direct Anthropic SDK
*
* A provider+model combination is considered "Anthropic-compatible" if:
* 1. It's a native Anthropic provider (type === 'anthropic'), OR
* 2. It has anthropicApiHost configured AND the specific model supports Anthropic API
* (for aggregated providers like Silicon, only certain models support Anthropic endpoint)
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if should use direct Anthropic SDK, false for unified SDK
*/
function shouldUseDirectAnthropic(provider: Provider, modelId: string): boolean {
// Native Anthropic provider - always use direct SDK
if (provider.type === 'anthropic') {
return true
}
// No anthropicApiHost configured - use unified SDK
if (!provider.anthropicApiHost?.trim()) {
return false
}
// Has anthropicApiHost - check model-level compatibility
// For aggregated providers, only specific models support Anthropic API
return isModelAnthropicCompatible(provider, modelId)
}
const logger = loggerService.withContext('ApiServerMessagesRoutes') const logger = loggerService.withContext('ApiServerMessagesRoutes')
const router = express.Router() const router = express.Router()
const providerRouter = express.Router({ mergeParams: true }) const providerRouter = express.Router({ mergeParams: true })
/**
* Estimate token count from messages
* Uses tokenx library for accurate token estimation and supports images, tools
*/
export interface CountTokensInput {
messages: MessageCreateParams['messages']
system?: MessageCreateParams['system']
}
export function estimateTokenCount(input: CountTokensInput): number {
const { messages, system } = input
let totalTokens = 0
// Count system message tokens using tokenx
if (system) {
if (typeof system === 'string') {
totalTokens += approximateTokenSize(system)
} else if (Array.isArray(system)) {
for (const block of system) {
if (block.type === 'text' && block.text) {
totalTokens += approximateTokenSize(block.text)
}
}
}
}
// Count message tokens
for (const msg of messages) {
if (typeof msg.content === 'string') {
totalTokens += approximateTokenSize(msg.content)
} else if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'text' && block.text) {
totalTokens += approximateTokenSize(block.text)
} else if (block.type === 'image') {
// Image token estimation (consistent with TokenService)
if (block.source.type === 'base64') {
// Base64 images: estimate from data length
const dataSize = block.source.data.length * 0.75 // base64 to bytes
totalTokens += Math.floor(dataSize / 100)
} else {
// URL images: use default estimate
totalTokens += 1000
}
} else if (block.type === 'tool_use') {
// Tool use token estimation: name + input JSON
if (block.name) {
totalTokens += approximateTokenSize(block.name)
}
if (block.input) {
const inputJson = JSON.stringify(block.input)
totalTokens += approximateTokenSize(inputJson)
}
// Add overhead for tool use structure
totalTokens += 10
} else if (block.type === 'tool_result') {
// Tool result token estimation
if (typeof block.content === 'string') {
totalTokens += approximateTokenSize(block.content)
} else if (Array.isArray(block.content)) {
for (const item of block.content) {
if (typeof item === 'string') {
totalTokens += approximateTokenSize(item)
} else if (item.type === 'text' && item.text) {
totalTokens += approximateTokenSize(item.text)
}
}
}
// Add overhead for tool result structure
totalTokens += 10
}
}
}
// Add role overhead
totalTokens += 3
}
return totalTokens
}
// Helper function for basic request validation // Helper function for basic request validation
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> { async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
const request: MessageCreateParams = req.body const request: MessageCreateParams = req.body
@ -32,22 +144,101 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
return { valid: true } return { valid: true }
} }
/**
* Shared handler for count_tokens endpoint
* Validates request and returns token count estimation
*/
async function handleCountTokens(
req: Request,
res: Response,
options: {
requireModel?: boolean
logContext?: Record<string, any>
} = {}
): Promise<Response> {
try {
const { model, messages, system } = req.body
const { requireModel = false, logContext = {} } = options
// Validate model parameter if required
if (requireModel && !model) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'model parameter is required'
}
})
}
// Validate messages parameter
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
// Estimate token count
const estimatedTokens = estimateTokenCount({ messages, system })
// Log with context
logger.debug('Token count estimated', {
model,
messageCount: messages.length,
estimatedTokens,
...logContext
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
}
interface HandleMessageProcessingOptions { interface HandleMessageProcessingOptions {
req: Request
res: Response res: Response
provider: Provider provider: Provider
request: MessageCreateParams request: MessageCreateParams
modelId?: string modelId?: string
} }
async function handleMessageProcessing({ /**
req, * Handle message processing using direct Anthropic SDK
* Used for providers with anthropicApiHost or native Anthropic providers
* This bypasses AI SDK conversion and uses native Anthropic protocol
*/
async function handleDirectAnthropicProcessing({
res, res,
provider, provider,
request, request,
modelId modelId,
}: HandleMessageProcessingOptions): Promise<void> { extraHeaders
}: HandleMessageProcessingOptions & { extraHeaders?: Record<string, string | string[]> }): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via direct Anthropic SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream,
anthropicApiHost: provider.anthropicApiHost
})
try { try {
// Validate request
const validation = messagesService.validateRequest(request) const validation = messagesService.validateRequest(request)
if (!validation.isValid) { if (!validation.isValid) {
res.status(400).json({ res.status(400).json({
@ -60,28 +251,126 @@ async function handleMessageProcessing({
return return
} }
const extraHeaders = messagesService.prepareHeaders(req.headers) // Process message using messagesService (native Anthropic SDK)
const { client, anthropicRequest } = await messagesService.processMessage({ const { client, anthropicRequest } = await messagesService.processMessage({
provider, provider,
request, request,
extraHeaders, extraHeaders,
modelId modelId: actualModelId
}) })
if (request.stream) { if (request.stream) {
// Use native Anthropic streaming
await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider) await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
return } else {
} // Use native Anthropic non-streaming
const response = await client.messages.create(anthropicRequest) const response = await client.messages.create(anthropicRequest)
res.json(response) res.json(response)
}
} catch (error: any) { } catch (error: any) {
logger.error('Message processing error', { error }) logger.error('Direct Anthropic processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error) const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse) res.status(statusCode).json(errorResponse)
} }
} }
/**
* Handle message processing using unified AI SDK
* Used for non-Anthropic providers that need format conversion
* - Uses AI SDK adapters with output converted to Anthropic SSE format
*/
async function handleUnifiedProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via unified AI SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream
})
try {
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: validation.errors.join('; ')
}
})
return
}
const middlewareConfig: SharedMiddlewareConfig = {
modelId: actualModelId,
providerId: provider.id,
aiSdkProviderId: getAiSdkProviderId(provider)
}
const middlewares = buildSharedMiddlewares(middlewareConfig)
logger.debug('Built middlewares for unified processing', {
middlewareCount: middlewares.length,
modelId: actualModelId,
providerId: provider.id
})
if (request.stream) {
await streamUnifiedMessages({
response: res,
provider,
modelId: actualModelId,
params: request,
middlewares,
onError: (error) => {
logger.error('Stream error', error as Error)
},
onComplete: () => {
logger.debug('Stream completed')
}
})
} else {
const response = await generateUnifiedMessage({
provider,
modelId: actualModelId,
params: request,
middlewares
})
res.json(response)
}
} catch (error: any) {
const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse)
}
}
/**
* Handle message processing - routes to appropriate handler based on provider and model
*
* Routing logic:
* - Native Anthropic providers (type === 'anthropic'): Direct Anthropic SDK
* - Providers with anthropicApiHost AND model supports Anthropic API: Direct Anthropic SDK
* - Other providers/models: Unified AI SDK with Anthropic SSE conversion
*/
async function handleMessageProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
if (shouldUseDirectAnthropic(provider, actualModelId)) {
return handleDirectAnthropicProcessing({ res, provider, request, modelId })
}
return handleUnifiedProcessing({ res, provider, request, modelId })
}
/** /**
* @swagger * @swagger
* /v1/messages: * /v1/messages:
@ -235,7 +524,7 @@ router.post('/', async (req: Request, res: Response) => {
const provider = modelValidation.provider! const provider = modelValidation.provider!
const modelId = modelValidation.modelId! const modelId = modelValidation.modelId!
return handleMessageProcessing({ req, res, provider, request, modelId }) return handleMessageProcessing({ res, provider, request, modelId })
} catch (error: any) { } catch (error: any) {
logger.error('Message processing error', { error }) logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error) const { statusCode, errorResponse } = messagesService.transformError(error)
@ -393,7 +682,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
const request: MessageCreateParams = req.body const request: MessageCreateParams = req.body
return handleMessageProcessing({ req, res, provider, request }) return handleMessageProcessing({ res, provider, request })
} catch (error: any) { } catch (error: any) {
logger.error('Message processing error', { error }) logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error) const { statusCode, errorResponse } = messagesService.transformError(error)
@ -401,4 +690,58 @@ providerRouter.post('/', async (req: Request, res: Response) => {
} }
}) })
/**
* @swagger
* /v1/messages/count_tokens:
* post:
* summary: Count tokens for messages
* description: Count tokens for Anthropic Messages API format (required by Claude Code SDK)
* tags: [Messages]
* requestBody:
* required: true
* content:
* application/json:
* schema:
* type: object
* required:
* - model
* - messages
* properties:
* model:
* type: string
* description: Model ID
* messages:
* type: array
* items:
* type: object
* system:
* type: string
* description: System message
* responses:
* 200:
* description: Token count response
* content:
* application/json:
* schema:
* type: object
* properties:
* input_tokens:
* type: integer
* 400:
* description: Bad request
*/
router.post('/count_tokens', async (req: Request, res: Response) => {
return handleCountTokens(req, res, { requireModel: true })
})
/**
* Provider-specific count_tokens endpoint
*/
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
return handleCountTokens(req, res, {
requireModel: false,
logContext: { providerId: req.params.provider }
})
})
export { providerRouter as messagesProviderRoutes, router as messagesRoutes } export { providerRouter as messagesProviderRoutes, router as messagesRoutes }

View File

@ -0,0 +1,340 @@
import { describe, expect, it } from 'vitest'
import * as z from 'zod'
import { type JsonSchemaLike, jsonSchemaToZod } from '../unified-messages'
describe('jsonSchemaToZod', () => {
describe('Basic Types', () => {
it('should convert string type', () => {
const schema: JsonSchemaLike = { type: 'string' }
const result = jsonSchemaToZod(schema)
expect(result).toBeInstanceOf(z.ZodString)
expect(result.safeParse('hello').success).toBe(true)
expect(result.safeParse(123).success).toBe(false)
})
it('should convert string with minLength', () => {
const schema: JsonSchemaLike = { type: 'string', minLength: 3 }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('ab').success).toBe(false)
expect(result.safeParse('abc').success).toBe(true)
})
it('should convert string with maxLength', () => {
const schema: JsonSchemaLike = { type: 'string', maxLength: 5 }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('hello').success).toBe(true)
expect(result.safeParse('hello world').success).toBe(false)
})
it('should convert string with pattern', () => {
const schema: JsonSchemaLike = { type: 'string', pattern: '^[0-9]+$' }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('123').success).toBe(true)
expect(result.safeParse('abc').success).toBe(false)
})
it('should convert number type', () => {
const schema: JsonSchemaLike = { type: 'number' }
const result = jsonSchemaToZod(schema)
expect(result).toBeInstanceOf(z.ZodNumber)
expect(result.safeParse(42).success).toBe(true)
expect(result.safeParse(3.14).success).toBe(true)
expect(result.safeParse('42').success).toBe(false)
})
it('should convert integer type', () => {
const schema: JsonSchemaLike = { type: 'integer' }
const result = jsonSchemaToZod(schema)
expect(result.safeParse(42).success).toBe(true)
expect(result.safeParse(3.14).success).toBe(false)
})
it('should convert number with minimum', () => {
const schema: JsonSchemaLike = { type: 'number', minimum: 10 }
const result = jsonSchemaToZod(schema)
expect(result.safeParse(5).success).toBe(false)
expect(result.safeParse(10).success).toBe(true)
expect(result.safeParse(15).success).toBe(true)
})
it('should convert number with maximum', () => {
const schema: JsonSchemaLike = { type: 'number', maximum: 100 }
const result = jsonSchemaToZod(schema)
expect(result.safeParse(50).success).toBe(true)
expect(result.safeParse(100).success).toBe(true)
expect(result.safeParse(150).success).toBe(false)
})
it('should convert boolean type', () => {
const schema: JsonSchemaLike = { type: 'boolean' }
const result = jsonSchemaToZod(schema)
expect(result).toBeInstanceOf(z.ZodBoolean)
expect(result.safeParse(true).success).toBe(true)
expect(result.safeParse(false).success).toBe(true)
expect(result.safeParse('true').success).toBe(false)
})
it('should convert null type', () => {
const schema: JsonSchemaLike = { type: 'null' }
const result = jsonSchemaToZod(schema)
expect(result).toBeInstanceOf(z.ZodNull)
expect(result.safeParse(null).success).toBe(true)
expect(result.safeParse(undefined).success).toBe(false)
})
})
describe('Enum Types', () => {
it('should convert string enum', () => {
const schema: JsonSchemaLike = { enum: ['red', 'green', 'blue'] }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('red').success).toBe(true)
expect(result.safeParse('green').success).toBe(true)
expect(result.safeParse('yellow').success).toBe(false)
})
it('should convert non-string enum with literals', () => {
const schema: JsonSchemaLike = { enum: [1, 2, 3] }
const result = jsonSchemaToZod(schema)
expect(result.safeParse(1).success).toBe(true)
expect(result.safeParse(2).success).toBe(true)
expect(result.safeParse(4).success).toBe(false)
})
it('should convert single value enum', () => {
const schema: JsonSchemaLike = { enum: ['only'] }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('only').success).toBe(true)
expect(result.safeParse('other').success).toBe(false)
})
it('should convert mixed enum', () => {
const schema: JsonSchemaLike = { enum: ['text', 1, true] }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('text').success).toBe(true)
expect(result.safeParse(1).success).toBe(true)
expect(result.safeParse(true).success).toBe(true)
expect(result.safeParse(false).success).toBe(false)
})
})
describe('Array Types', () => {
it('should convert array of strings', () => {
const schema: JsonSchemaLike = {
type: 'array',
items: { type: 'string' }
}
const result = jsonSchemaToZod(schema)
expect(result.safeParse(['a', 'b']).success).toBe(true)
expect(result.safeParse([1, 2]).success).toBe(false)
})
it('should convert array without items (unknown)', () => {
const schema: JsonSchemaLike = { type: 'array' }
const result = jsonSchemaToZod(schema)
expect(result.safeParse([]).success).toBe(true)
expect(result.safeParse(['a', 1, true]).success).toBe(true)
})
it('should convert array with minItems', () => {
const schema: JsonSchemaLike = {
type: 'array',
items: { type: 'number' },
minItems: 2
}
const result = jsonSchemaToZod(schema)
expect(result.safeParse([1]).success).toBe(false)
expect(result.safeParse([1, 2]).success).toBe(true)
})
it('should convert array with maxItems', () => {
const schema: JsonSchemaLike = {
type: 'array',
items: { type: 'number' },
maxItems: 3
}
const result = jsonSchemaToZod(schema)
expect(result.safeParse([1, 2, 3]).success).toBe(true)
expect(result.safeParse([1, 2, 3, 4]).success).toBe(false)
})
})
describe('Object Types', () => {
it('should convert simple object', () => {
const schema: JsonSchemaLike = {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' }
}
}
const result = jsonSchemaToZod(schema)
expect(result.safeParse({ name: 'John', age: 30 }).success).toBe(true)
expect(result.safeParse({ name: 'John', age: '30' }).success).toBe(false)
})
it('should handle required fields', () => {
const schema: JsonSchemaLike = {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' }
},
required: ['name']
}
const result = jsonSchemaToZod(schema)
expect(result.safeParse({ name: 'John', age: 30 }).success).toBe(true)
expect(result.safeParse({ age: 30 }).success).toBe(false)
expect(result.safeParse({ name: 'John' }).success).toBe(true)
})
it('should convert empty object', () => {
const schema: JsonSchemaLike = { type: 'object' }
const result = jsonSchemaToZod(schema)
expect(result.safeParse({}).success).toBe(true)
})
it('should convert nested objects', () => {
const schema: JsonSchemaLike = {
type: 'object',
properties: {
user: {
type: 'object',
properties: {
name: { type: 'string' },
email: { type: 'string' }
}
}
}
}
const result = jsonSchemaToZod(schema)
expect(result.safeParse({ user: { name: 'John', email: 'john@example.com' } }).success).toBe(true)
expect(result.safeParse({ user: { name: 'John' } }).success).toBe(true)
})
})
describe('Union Types', () => {
it('should convert union type (type array)', () => {
const schema: JsonSchemaLike = { type: ['string', 'null'] }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('hello').success).toBe(true)
expect(result.safeParse(null).success).toBe(true)
expect(result.safeParse(123).success).toBe(false)
})
it('should convert single type array', () => {
const schema: JsonSchemaLike = { type: ['string'] }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('hello').success).toBe(true)
expect(result.safeParse(123).success).toBe(false)
})
it('should convert multiple union types', () => {
const schema: JsonSchemaLike = { type: ['string', 'number', 'boolean'] }
const result = jsonSchemaToZod(schema)
expect(result.safeParse('text').success).toBe(true)
expect(result.safeParse(42).success).toBe(true)
expect(result.safeParse(true).success).toBe(true)
expect(result.safeParse(null).success).toBe(false)
})
})
describe('Description Handling', () => {
it('should preserve description for string', () => {
const schema: JsonSchemaLike = {
type: 'string',
description: 'A user name'
}
const result = jsonSchemaToZod(schema)
expect(result.description).toBe('A user name')
})
it('should preserve description for enum', () => {
const schema: JsonSchemaLike = {
enum: ['red', 'green', 'blue'],
description: 'Available colors'
}
const result = jsonSchemaToZod(schema)
expect(result.description).toBe('Available colors')
})
it('should preserve description for object', () => {
const schema: JsonSchemaLike = {
type: 'object',
description: 'User object',
properties: {
name: { type: 'string' }
}
}
const result = jsonSchemaToZod(schema)
expect(result.description).toBe('User object')
})
})
describe('Edge Cases', () => {
it('should handle unknown type', () => {
const schema: JsonSchemaLike = { type: 'unknown-type' as any }
const result = jsonSchemaToZod(schema)
expect(result).toBeInstanceOf(z.ZodType)
expect(result.safeParse(anything).success).toBe(true)
})
it('should handle schema without type', () => {
const schema: JsonSchemaLike = {}
const result = jsonSchemaToZod(schema)
expect(result).toBeInstanceOf(z.ZodType)
expect(result.safeParse(anything).success).toBe(true)
})
it('should handle complex nested schema', () => {
const schema: JsonSchemaLike = {
type: 'object',
properties: {
items: {
type: 'array',
items: {
type: 'object',
properties: {
id: { type: 'integer' },
name: { type: 'string' },
tags: {
type: 'array',
items: { type: 'string' }
}
},
required: ['id']
}
}
}
}
const result = jsonSchemaToZod(schema)
const validData = {
items: [
{ id: 1, name: 'Item 1', tags: ['tag1', 'tag2'] },
{ id: 2, tags: [] }
]
}
expect(result.safeParse(validData).success).toBe(true)
const invalidData = {
items: [{ name: 'No ID' }]
}
expect(result.safeParse(invalidData).success).toBe(false)
})
})
describe('OpenRouter Model IDs', () => {
it('should handle model identifier format with colons', () => {
const schema: JsonSchemaLike = {
type: 'string',
enum: ['openrouter:anthropic/claude-3.5-sonnet:free', 'openrouter:gpt-4:paid']
}
const result = jsonSchemaToZod(schema)
expect(result.safeParse('openrouter:anthropic/claude-3.5-sonnet:free').success).toBe(true)
expect(result.safeParse('openrouter:gpt-4:paid').success).toBe(true)
expect(result.safeParse('other').success).toBe(false)
})
})
})
const anything = Math.random() > 0.5 ? 'string' : Math.random() > 0.5 ? 123 : { a: true }

View File

@ -0,0 +1,795 @@
import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages'
import { describe, expect, it } from 'vitest'
import { convertAnthropicToAiMessages, convertAnthropicToolsToAiSdk } from '../unified-messages'
describe('unified-messages', () => {
describe('convertAnthropicToolsToAiSdk', () => {
it('should return undefined for empty tools array', () => {
const result = convertAnthropicToolsToAiSdk([])
expect(result).toBeUndefined()
})
it('should return undefined for undefined tools', () => {
const result = convertAnthropicToolsToAiSdk(undefined)
expect(result).toBeUndefined()
})
it('should convert simple tool with string schema', () => {
const anthropicTools: MessageCreateParams['tools'] = [
{
type: 'custom',
name: 'get_weather',
description: 'Get current weather',
input_schema: {
type: 'object',
properties: {
location: { type: 'string' }
},
required: ['location']
}
}
]
const result = convertAnthropicToolsToAiSdk(anthropicTools)
expect(result).toBeDefined()
expect(result).toHaveProperty('get_weather')
expect(result!.get_weather).toHaveProperty('description', 'Get current weather')
})
it('should convert multiple tools', () => {
const anthropicTools: MessageCreateParams['tools'] = [
{
type: 'custom',
name: 'tool1',
description: 'First tool',
input_schema: {
type: 'object',
properties: {}
}
},
{
type: 'custom',
name: 'tool2',
description: 'Second tool',
input_schema: {
type: 'object',
properties: {}
}
}
]
const result = convertAnthropicToolsToAiSdk(anthropicTools)
expect(result).toBeDefined()
expect(Object.keys(result!)).toHaveLength(2)
expect(result).toHaveProperty('tool1')
expect(result).toHaveProperty('tool2')
})
it('should convert tool with complex schema', () => {
const anthropicTools: MessageCreateParams['tools'] = [
{
type: 'custom',
name: 'search',
description: 'Search for information',
input_schema: {
type: 'object',
properties: {
query: { type: 'string', minLength: 1 },
limit: { type: 'integer', minimum: 1, maximum: 100 },
filters: {
type: 'array',
items: { type: 'string' }
}
},
required: ['query']
}
}
]
const result = convertAnthropicToolsToAiSdk(anthropicTools)
expect(result).toBeDefined()
expect(result).toHaveProperty('search')
})
it('should skip bash_20250124 tool type', () => {
const anthropicTools: MessageCreateParams['tools'] = [
{
type: 'bash_20250124',
name: 'bash'
},
{
type: 'custom',
name: 'regular_tool',
description: 'A regular tool',
input_schema: {
type: 'object',
properties: {}
}
}
]
const result = convertAnthropicToolsToAiSdk(anthropicTools)
expect(result).toBeDefined()
expect(Object.keys(result!)).toHaveLength(1)
expect(result).toHaveProperty('regular_tool')
expect(result).not.toHaveProperty('bash')
})
it('should handle tool with no description', () => {
const anthropicTools: MessageCreateParams['tools'] = [
{
type: 'custom',
name: 'no_desc_tool',
input_schema: {
type: 'object',
properties: {}
}
}
]
const result = convertAnthropicToolsToAiSdk(anthropicTools)
expect(result).toBeDefined()
expect(result).toHaveProperty('no_desc_tool')
expect(result!.no_desc_tool).toHaveProperty('description', '')
})
})
describe('convertAnthropicToAiMessages', () => {
describe('System Messages', () => {
it('should convert string system message', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
system: 'You are a helpful assistant.',
messages: [
{
role: 'user',
content: 'Hello'
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(2)
expect(result[0]).toEqual({
role: 'system',
content: 'You are a helpful assistant.'
})
})
it('should convert array system message', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
system: [
{ type: 'text', text: 'Instruction 1' },
{ type: 'text', text: 'Instruction 2' }
],
messages: [
{
role: 'user',
content: 'Hello'
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result[0]).toEqual({
role: 'system',
content: 'Instruction 1\nInstruction 2'
})
})
it('should handle no system message', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: 'Hello'
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result[0].role).toBe('user')
})
})
describe('Text Messages', () => {
it('should convert simple string message', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: 'Hello, world!'
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(1)
expect(result[0]).toEqual({
role: 'user',
content: 'Hello, world!'
})
})
it('should convert text block array', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: [
{ type: 'text', text: 'First part' },
{ type: 'text', text: 'Second part' }
]
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(1)
expect(result[0].role).toBe('user')
expect(Array.isArray(result[0].content)).toBe(true)
if (Array.isArray(result[0].content)) {
expect(result[0].content).toHaveLength(2)
expect(result[0].content[0]).toEqual({ type: 'text', text: 'First part' })
expect(result[0].content[1]).toEqual({ type: 'text', text: 'Second part' })
}
})
it('should convert assistant message', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: 'Hello'
},
{
role: 'assistant',
content: 'Hi there!'
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(2)
expect(result[1]).toEqual({
role: 'assistant',
content: 'Hi there!'
})
})
})
describe('Image Messages', () => {
it('should convert base64 image', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: [
{
type: 'image',
source: {
type: 'base64',
media_type: 'image/png',
data: 'iVBORw0KGgo='
}
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(1)
expect(Array.isArray(result[0].content)).toBe(true)
if (Array.isArray(result[0].content)) {
expect(result[0].content).toHaveLength(1)
const imagePart = result[0].content[0]
if (imagePart.type === 'image') {
expect(imagePart.image).toBe('')
}
}
})
it('should convert URL image', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: [
{
type: 'image',
source: {
type: 'url',
url: 'https://example.com/image.png'
}
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
if (Array.isArray(result[0].content)) {
const imagePart = result[0].content[0]
if (imagePart.type === 'image') {
expect(imagePart.image).toBe('https://example.com/image.png')
}
}
})
it('should convert mixed text and image content', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: [
{ type: 'text', text: 'Look at this:' },
{
type: 'image',
source: {
type: 'url',
url: 'https://example.com/pic.jpg'
}
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
if (Array.isArray(result[0].content)) {
expect(result[0].content).toHaveLength(2)
expect(result[0].content[0].type).toBe('text')
expect(result[0].content[1].type).toBe('image')
}
})
})
describe('Tool Messages', () => {
it('should convert tool_use block', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: 'What is the weather?'
},
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'call_123',
name: 'get_weather',
input: { location: 'San Francisco' }
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(2)
const assistantMsg = result[1]
expect(assistantMsg.role).toBe('assistant')
if (Array.isArray(assistantMsg.content)) {
expect(assistantMsg.content).toHaveLength(1)
const toolCall = assistantMsg.content[0]
if (toolCall.type === 'tool-call') {
expect(toolCall.toolName).toBe('get_weather')
expect(toolCall.toolCallId).toBe('call_123')
expect(toolCall.input).toEqual({ location: 'San Francisco' })
}
}
})
it('should convert tool_result with string content', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'call_123',
name: 'get_weather',
input: {}
}
]
},
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'call_123',
content: 'Temperature is 72°F'
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
const toolMsg = result[1]
expect(toolMsg.role).toBe('tool')
if (Array.isArray(toolMsg.content)) {
expect(toolMsg.content).toHaveLength(1)
const toolResult = toolMsg.content[0]
if (toolResult.type === 'tool-result') {
expect(toolResult.toolCallId).toBe('call_123')
expect(toolResult.toolName).toBe('get_weather')
if (toolResult.output.type === 'text') {
expect(toolResult.output.value).toBe('Temperature is 72°F')
}
}
}
})
it('should convert tool_result with array content', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'call_456',
name: 'analyze',
input: {}
}
]
},
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'call_456',
content: [
{ type: 'text', text: 'Result part 1' },
{ type: 'text', text: 'Result part 2' }
]
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
const toolMsg = result[1]
if (Array.isArray(toolMsg.content)) {
const toolResult = toolMsg.content[0]
if (toolResult.type === 'tool-result' && toolResult.output.type === 'content') {
expect(toolResult.output.value).toHaveLength(2)
expect(toolResult.output.value[0]).toEqual({ type: 'text', text: 'Result part 1' })
expect(toolResult.output.value[1]).toEqual({ type: 'text', text: 'Result part 2' })
}
}
})
it('should convert tool_result with image content', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'call_789',
name: 'screenshot',
input: {}
}
]
},
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'call_789',
content: [
{
type: 'image',
source: {
type: 'base64',
media_type: 'image/png',
data: 'abc123'
}
}
]
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
const toolMsg = result[1]
if (Array.isArray(toolMsg.content)) {
const toolResult = toolMsg.content[0]
if (toolResult.type === 'tool-result' && toolResult.output.type === 'content') {
expect(toolResult.output.value).toHaveLength(1)
const media = toolResult.output.value[0]
if (media.type === 'media') {
expect(media.data).toBe('abc123')
expect(media.mediaType).toBe('image/png')
}
}
}
})
it('should handle multiple tool calls', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'call_1',
name: 'tool1',
input: {}
},
{
type: 'tool_use',
id: 'call_2',
name: 'tool2',
input: {}
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
if (Array.isArray(result[0].content)) {
expect(result[0].content).toHaveLength(2)
expect(result[0].content[0].type).toBe('tool-call')
expect(result[0].content[1].type).toBe('tool-call')
}
})
})
describe('Thinking Content', () => {
it('should convert thinking block to reasoning', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: [
{
type: 'thinking',
thinking: 'Let me analyze this...',
signature: 'sig123'
},
{
type: 'text',
text: 'Here is my answer'
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
if (Array.isArray(result[0].content)) {
expect(result[0].content).toHaveLength(2)
const reasoning = result[0].content[0]
if (reasoning.type === 'reasoning') {
expect(reasoning.text).toBe('Let me analyze this...')
}
const text = result[0].content[1]
if (text.type === 'text') {
expect(text.text).toBe('Here is my answer')
}
}
})
it('should convert redacted_thinking to reasoning', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: [
{
type: 'redacted_thinking',
data: '[Redacted]'
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
if (Array.isArray(result[0].content)) {
expect(result[0].content).toHaveLength(1)
const reasoning = result[0].content[0]
if (reasoning.type === 'reasoning') {
expect(reasoning.text).toBe('[Redacted]')
}
}
})
})
describe('Multi-turn Conversations', () => {
it('should handle complete conversation flow', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
system: 'You are a helpful assistant.',
messages: [
{
role: 'user',
content: 'What is the weather in SF?'
},
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'weather_call',
name: 'get_weather',
input: { location: 'SF' }
}
]
},
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'weather_call',
content: '72°F and sunny'
}
]
},
{
role: 'assistant',
content: 'The weather in San Francisco is 72°F and sunny.'
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(5)
expect(result[0].role).toBe('system')
expect(result[1].role).toBe('user')
expect(result[2].role).toBe('assistant')
expect(result[3].role).toBe('tool')
expect(result[4].role).toBe('assistant')
})
})
describe('Edge Cases', () => {
it('should handle empty content array for user', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: []
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(0)
})
it('should handle empty content array for assistant', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: []
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(0)
})
it('should handle tool_result without matching tool_use', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'unknown_call',
content: 'Some result'
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
expect(result).toHaveLength(1)
if (Array.isArray(result[0].content)) {
const toolResult = result[0].content[0]
if (toolResult.type === 'tool-result') {
expect(toolResult.toolName).toBe('unknown')
}
}
})
it('should handle tool_result with empty content', () => {
const params: MessageCreateParams = {
model: 'claude-3-5-sonnet-20241022',
max_tokens: 1024,
messages: [
{
role: 'assistant',
content: [
{
type: 'tool_use',
id: 'call_empty',
name: 'empty_tool',
input: {}
}
]
},
{
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'call_empty'
}
]
}
]
}
const result = convertAnthropicToAiMessages(params)
const toolMsg = result[1]
if (Array.isArray(toolMsg.content)) {
const toolResult = toolMsg.content[0]
if (toolResult.type === 'tool-result' && toolResult.output.type === 'text') {
expect(toolResult.output.value).toBe('')
}
}
})
})
})
})

View File

@ -2,8 +2,10 @@ import type Anthropic from '@anthropic-ai/sdk'
import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources' import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import anthropicService from '@main/services/AnthropicService' import anthropicService from '@main/services/AnthropicService'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' import { buildClaudeCodeSystemMessage, getSdkClient, sanitizeToolsForAnthropic } from '@shared/anthropic'
import type { Provider } from '@types' import type { Provider } from '@types'
import { APICallError, RetryError } from 'ai'
import { net } from 'electron'
import type { Response } from 'express' import type { Response } from 'express'
const logger = loggerService.withContext('MessagesService') const logger = loggerService.withContext('MessagesService')
@ -98,11 +100,30 @@ export class MessagesService {
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> { async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
// Create Anthropic client for the provider // Create Anthropic client for the provider
// Wrap net.fetch to handle compatibility issues:
// 1. net.fetch expects string URLs, not Request objects
// 2. net.fetch doesn't support 'agent' option from Node.js http module
const electronFetch: typeof globalThis.fetch = async (input: URL | RequestInfo, init?: RequestInit) => {
const url = typeof input === 'string' ? input : input instanceof URL ? input.toString() : input.url
// Remove unsupported options for Electron's net.fetch
if (init) {
const initWithAgent = init as RequestInit & { agent?: unknown }
delete initWithAgent.agent
const headers = new Headers(initWithAgent.headers)
if (headers.has('content-length')) {
headers.delete('content-length')
}
initWithAgent.headers = headers
return net.fetch(url, initWithAgent)
}
return net.fetch(url)
}
const context = { fetch: electronFetch }
if (provider.authType === 'oauth') { if (provider.authType === 'oauth') {
const oauthToken = await anthropicService.getValidAccessToken() const oauthToken = await anthropicService.getValidAccessToken()
return getSdkClient(provider, oauthToken, extraHeaders) return getSdkClient(provider, oauthToken, extraHeaders, context)
} }
return getSdkClient(provider, null, extraHeaders) return getSdkClient(provider, null, extraHeaders, context)
} }
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> { prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
@ -127,7 +148,8 @@ export class MessagesService {
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams { createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
const anthropicRequest: MessageCreateParams = { const anthropicRequest: MessageCreateParams = {
...request, ...request,
stream: !!request.stream stream: !!request.stream,
tools: sanitizeToolsForAnthropic(request.tools)
} }
// Override model if provided // Override model if provided
@ -233,9 +255,71 @@ export class MessagesService {
} }
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } { transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
let statusCode = 500 let statusCode: number | undefined = undefined
let errorType = 'api_error' let errorType: string | undefined = undefined
let errorMessage = 'Internal server error' let errorMessage: string | undefined = undefined
const errorMap: Record<number, string> = {
400: 'invalid_request_error',
401: 'authentication_error',
403: 'forbidden_error',
404: 'not_found_error',
429: 'rate_limit_error',
500: 'internal_server_error'
}
// Handle AI SDK RetryError - extract the last error for better error messages
if (RetryError.isInstance(error)) {
const lastError = error.lastError
// If the last error is an APICallError, extract its details
if (APICallError.isInstance(lastError)) {
statusCode = lastError.statusCode || 502
errorMessage = lastError.message
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: `${error.reason}: ${errorMessage}`,
requestId: lastError.name
}
}
}
}
// Fallback for other retry errors
errorMessage = error.message
statusCode = 502
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
if (APICallError.isInstance(error)) {
statusCode = error.statusCode
errorMessage = error.message
if (statusCode) {
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
}
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
const anthropicError = error?.error const anthropicError = error?.error
@ -277,11 +361,11 @@ export class MessagesService {
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error' typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
return { return {
statusCode, statusCode: statusCode ?? 500,
errorResponse: { errorResponse: {
type: 'error', type: 'error',
error: { error: {
type: errorType, type: errorType || 'api_error',
message: safeErrorMessage, message: safeErrorMessage,
requestId: error?.request_id requestId: error?.request_id
} }

View File

@ -1,13 +1,6 @@
import { isEmpty } from 'lodash'
import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels' import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
import { loggerService } from '../../services/LoggerService' import { loggerService } from '../../services/LoggerService'
import { import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
getAvailableProviders,
getProviderAnthropicModelChecker,
listAllAvailableModels,
transformModelToOpenAI
} from '../utils'
const logger = loggerService.withContext('ModelsService') const logger = loggerService.withContext('ModelsService')
@ -20,11 +13,12 @@ export class ModelsService {
try { try {
logger.debug('Getting available models from providers', { filter }) logger.debug('Getting available models from providers', { filter })
let providers = await getAvailableProviders() const providers = await getAvailableProviders()
if (filter.providerType === 'anthropic') { // Note: When providerType === 'anthropic', we now return ALL available models
providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim())) // because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
} // any provider's response to Anthropic SSE format. This enables Claude Code Agent
// to work with OpenAI, Gemini, and other providers transparently.
const models = await listAllAvailableModels(providers) const models = await listAllAvailableModels(providers)
// Use Map to deduplicate models by their full ID (provider:model_id) // Use Map to deduplicate models by their full ID (provider:model_id)
@ -32,20 +26,11 @@ export class ModelsService {
for (const model of models) { for (const model of models) {
const provider = providers.find((p) => p.id === model.provider) const provider = providers.find((p) => p.id === model.provider)
// logger.debug(`Processing model ${model.id}`)
if (!provider) { if (!provider) {
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`) logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
continue continue
} }
if (filter.providerType === 'anthropic') {
const checker = getProviderAnthropicModelChecker(provider.id)
if (!checker(model)) {
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
continue
}
}
const openAIModel = transformModelToOpenAI(model, provider) const openAIModel = transformModelToOpenAI(model, provider)
const fullModelId = openAIModel.id // This is already in format "provider:model_id" const fullModelId = openAIModel.id // This is already in format "provider:model_id"

View File

@ -0,0 +1,45 @@
/**
* Reasoning Cache Service
*
* Manages reasoning-related caching for AI providers that support thinking/reasoning modes.
* This includes Google Gemini's thought signatures and OpenRouter's reasoning details.
*/
import type { ReasoningDetailUnion } from '@main/apiServer/adapters/openrouter'
import { CacheService } from '@main/services/CacheService'
/**
* Interface for reasoning cache
*/
export interface IReasoningCache<T> {
set(key: string, value: T): void
get(key: string): T | undefined
}
/**
* Cache duration: 30 minutes
* Reasoning data is typically only needed within a short conversation context
*/
const REASONING_CACHE_DURATION = 30 * 60 * 1000
/**
* Google Gemini reasoning cache
*
* Stores thought signatures for Gemini 3 models to handle multi-turn conversations
* where the model needs to maintain thinking context across tool calls.
*/
export const googleReasoningCache: IReasoningCache<string> = {
set: (key, value) => CacheService.set(`google-reasoning:${key}`, value, REASONING_CACHE_DURATION),
get: (key) => CacheService.get(`google-reasoning:${key}`) || undefined
}
/**
* OpenRouter reasoning cache
*
* Stores reasoning details from OpenRouter responses to preserve thinking tokens
* and reasoning metadata across the conversation flow.
*/
export const openRouterReasoningCache: IReasoningCache<ReasoningDetailUnion[]> = {
set: (key, value) => CacheService.set(`openrouter-reasoning:${key}`, value, REASONING_CACHE_DURATION),
get: (key) => CacheService.get(`openrouter-reasoning:${key}`) || undefined
}

View File

@ -0,0 +1,764 @@
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { JSONSchema7, LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
ImageBlockParam,
MessageCreateParams,
TextBlockParam,
Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters'
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
import anthropicService from '@main/services/AnthropicService'
import copilotService from '@main/services/CopilotService'
import { reduxService } from '@main/services/ReduxService'
import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
import { isGemini3ModelId } from '@shared/ai-sdk-middlewares'
import {
type AiSdkConfig,
type AiSdkConfigContext,
formatProviderApiHost,
initializeSharedProviders,
isAnthropicProvider,
isGeminiProvider,
isOpenAIProvider,
type MinimalProvider,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider,
SystemProviderIds
} from '@shared/provider'
import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
import * as z from 'zod'
import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache'
const logger = loggerService.withContext('UnifiedMessagesService')
const MAGIC_STRING = 'skip_thought_signature_validator'
function sanitizeJson(value: unknown): JSONValue {
return JSON.parse(JSON.stringify(value))
}
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
/**
* Configuration for unified message streaming
*/
export interface UnifiedStreamConfig {
response: Response
provider: Provider
modelId: string
params: MessageCreateParams
onError?: (error: unknown) => void
onComplete?: () => void
/**
* Optional AI SDK middlewares to apply
*/
middlewares?: LanguageModelV2Middleware[]
/**
* Optional AI Core plugins to use with the executor
*/
plugins?: AiPlugin[]
}
/**
* Configuration for non-streaming message generation
*/
export interface GenerateUnifiedMessageConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
}
function getMainProcessFormatContext(): ProviderFormatContext {
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
return {
vertex: {
project: vertexSettings?.projectId || 'default-project',
location: vertexSettings?.location || 'us-central1'
}
}
}
function isSupportStreamOptionsProvider(provider: MinimalProvider): boolean {
const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const
return !NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id)
}
const mainProcessSdkContext: AiSdkConfigContext = {
isSupportStreamOptionsProvider,
getIncludeUsageSetting: () =>
reduxService.selectSync<boolean | undefined>('state.settings.openAI?.streamOptions?.includeUsage'),
fetch: net.fetch as typeof globalThis.fetch
}
function getActualProvider(provider: Provider, modelId: string): Provider {
const model = provider.models?.find((m) => m.id === modelId)
if (!model) return provider
return resolveActualProvider(provider, model)
}
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
const actualProvider = getActualProvider(provider, modelId)
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}
function convertAnthropicToolResultToAiSdk(
content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
if (typeof content === 'string') {
return { type: 'text', value: content }
}
const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
for (const block of content) {
if (block.type === 'text') {
values.push({ type: 'text', text: block.text })
} else if (block.type === 'image') {
values.push({
type: 'media',
data: block.source.type === 'base64' ? block.source.data : block.source.url,
mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
})
}
}
return { type: 'content', value: values }
}
/**
* JSON Schema type for tool input schemas
*/
export type JsonSchemaLike = JSONSchema7
/**
* Convert JSON Schema to Zod schema
* This avoids non-standard fields like input_examples that Anthropic doesn't support
* TODO: Anthropic/beta support input_examples
*/
export function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
const schemaType = schema.type
const enumValues = schema.enum
const description = schema.description
// Handle enum first
if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
if (enumValues.every((v) => typeof v === 'string')) {
const zodEnum = z.enum(enumValues as [string, ...string[]])
return description ? zodEnum.describe(description) : zodEnum
}
// For non-string enums, use union of literals
const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
if (literals.length === 1) {
return description ? literals[0].describe(description) : literals[0]
}
const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
return description ? zodUnion.describe(description) : zodUnion
}
// Handle union types (type: ["string", "null"])
if (Array.isArray(schemaType)) {
const schemas = schemaType.map((t) =>
jsonSchemaToZod({
...schema,
type: t,
enum: undefined
})
)
if (schemas.length === 1) {
return schemas[0]
}
return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
}
// Handle by type
switch (schemaType) {
case 'string': {
let zodString = z.string()
if (typeof schema.minLength === 'number') zodString = zodString.min(schema.minLength)
if (typeof schema.maxLength === 'number') zodString = zodString.max(schema.maxLength)
if (typeof schema.pattern === 'string') zodString = zodString.regex(new RegExp(schema.pattern))
return description ? zodString.describe(description) : zodString
}
case 'number':
case 'integer': {
let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
if (typeof schema.minimum === 'number') zodNumber = zodNumber.min(schema.minimum)
if (typeof schema.maximum === 'number') zodNumber = zodNumber.max(schema.maximum)
return description ? zodNumber.describe(description) : zodNumber
}
case 'boolean': {
const zodBoolean = z.boolean()
return description ? zodBoolean.describe(description) : zodBoolean
}
case 'null':
return z.null()
case 'array': {
const items = schema.items
let zodArray: z.ZodArray<z.ZodTypeAny>
if (items && typeof items === 'object' && !Array.isArray(items)) {
zodArray = z.array(jsonSchemaToZod(items as JsonSchemaLike))
} else {
zodArray = z.array(z.unknown())
}
if (typeof schema.minItems === 'number') zodArray = zodArray.min(schema.minItems)
if (typeof schema.maxItems === 'number') zodArray = zodArray.max(schema.maxItems)
return description ? zodArray.describe(description) : zodArray
}
case 'object': {
const properties = schema.properties
const required = schema.required || []
// Always use z.object() to ensure "properties" field is present in output schema
// OpenAI requires explicit properties field even for empty objects
const shape: Record<string, z.ZodTypeAny> = {}
if (properties && typeof properties === 'object') {
for (const [key, propSchema] of Object.entries(properties)) {
if (typeof propSchema === 'boolean') {
shape[key] = propSchema ? z.unknown() : z.never()
} else {
const zodProp = jsonSchemaToZod(propSchema as JsonSchemaLike)
shape[key] = required.includes(key) ? zodProp : zodProp.optional()
}
}
}
const zodObject = z.object(shape)
return description ? zodObject.describe(description) : zodObject
}
default:
// Unknown type, use z.unknown()
return z.unknown()
}
}
export function convertAnthropicToolsToAiSdk(
tools: MessageCreateParams['tools']
): Record<string, AiSdkTool> | undefined {
if (!tools || tools.length === 0) return undefined
const aiSdkTools: Record<string, AiSdkTool> = {}
for (const anthropicTool of tools) {
if (anthropicTool.type === 'bash_20250124') continue
const toolDef = anthropicTool as AnthropicTool
const rawSchema = toolDef.input_schema
// Convert Anthropic's InputSchema to JSONSchema7-compatible format
const schema = jsonSchemaToZod(rawSchema as JsonSchemaLike)
// Use tool() with inputSchema (AI SDK v5 API)
const aiTool = tool({
description: toolDef.description || '',
inputSchema: zodSchema(schema)
})
aiSdkTools[toolDef.name] = aiTool
}
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
export function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
const messages: ModelMessage[] = []
// System message
if (params.system) {
if (typeof params.system === 'string') {
messages.push({ role: 'system', content: params.system })
} else if (Array.isArray(params.system)) {
const systemText = params.system
.filter((block) => block.type === 'text')
.map((block) => block.text)
.join('\n')
if (systemText) {
messages.push({ role: 'system', content: systemText })
}
}
}
const toolCallIdToName = new Map<string, string>()
for (const msg of params.messages) {
if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'tool_use') {
toolCallIdToName.set(block.id, block.name)
}
}
}
}
// User/assistant messages
for (const msg of params.messages) {
if (typeof msg.content === 'string') {
messages.push({
role: msg.role === 'user' ? 'user' : 'assistant',
content: msg.content
})
} else if (Array.isArray(msg.content)) {
const textParts: TextPart[] = []
const imageParts: ImagePart[] = []
const reasoningParts: ReasoningPart[] = []
const toolCallParts: ToolCallPart[] = []
const toolResultParts: ToolResultPart[] = []
for (const block of msg.content) {
if (block.type === 'text') {
textParts.push({ type: 'text', text: block.text })
} else if (block.type === 'thinking') {
reasoningParts.push({ type: 'reasoning', text: block.thinking })
} else if (block.type === 'redacted_thinking') {
reasoningParts.push({ type: 'reasoning', text: block.data })
} else if (block.type === 'image') {
const source = block.source
if (source.type === 'base64') {
imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
} else if (source.type === 'url') {
imageParts.push({ type: 'image', image: source.url })
}
} else if (block.type === 'tool_use') {
const options: ProviderOptions = {}
logger.debug('Processing tool call block', { block, msgRole: msg.role, model: params.model })
if (isGemini3ModelId(params.model)) {
if (googleReasoningCache.get(`google-${block.name}`)) {
options.google = {
thoughtSignature: MAGIC_STRING
}
}
}
if (openRouterReasoningCache.get(`openrouter-${block.id}`)) {
options.openrouter = {
reasoning_details:
(sanitizeJson(openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || []
}
}
toolCallParts.push({
type: 'tool-call',
toolName: block.name,
toolCallId: block.id,
input: block.input,
providerOptions: options
})
} else if (block.type === 'tool_result') {
// Look up toolName from the pre-built map (covers cross-message references)
const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
toolResultParts.push({
type: 'tool-result',
toolCallId: block.tool_use_id,
toolName,
output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
})
}
}
if (toolResultParts.length > 0) {
messages.push({ role: 'tool', content: [...toolResultParts] })
}
if (msg.role === 'user') {
const userContent = [...textParts, ...imageParts]
if (userContent.length > 0) {
messages.push({ role: 'user', content: userContent })
}
} else {
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
if (assistantContent.length > 0) {
let providerOptions: ProviderOptions | undefined = undefined
if (openRouterReasoningCache.get('openrouter')) {
providerOptions = {
openrouter: {
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
}
}
} else if (isGemini3ModelId(params.model)) {
providerOptions = {
google: {
thoughtSignature: MAGIC_STRING
}
}
}
messages.push({ role: 'assistant', content: assistantContent, providerOptions })
}
}
}
}
return messages
}
interface ExecuteStreamConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
}
/**
* Create AI SDK provider instance from config
* Similar to renderer's createAiSdkProvider
*/
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
let providerId = config.providerId
// Handle special provider modes (same as renderer)
if (providerId === 'openai' && config.options?.mode === 'chat') {
providerId = 'openai-chat'
} else if (providerId === 'azure' && config.options?.mode === 'responses') {
providerId = 'azure-responses'
} else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
providerId = 'cherryin-chat'
}
const provider = await createProviderCore(providerId, config.options)
return provider
}
/**
* Prepare special provider configuration for providers that need dynamic tokens
* Similar to renderer's prepareSpecialProviderConfig
*/
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
switch (provider.id) {
case 'copilot': {
const storedHeaders =
((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
const headers: Record<string, string> = {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders
}
try {
const { token } = await copilotService.getToken(null as any, headers)
config.options.apiKey = token
const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
config.options.headers = {
...headers,
...existingHeaders
}
} catch (error) {
logger.error('Failed to get Copilot token', error as Error)
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
}
break
}
case 'anthropic': {
if (provider.authType === 'oauth') {
try {
const oauthToken = await anthropicService.getValidAccessToken()
if (!oauthToken) {
throw new Error('Anthropic OAuth token not available. Please re-authorize.')
}
config.options = {
...config.options,
headers: {
...(config.options.headers ? config.options.headers : {}),
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
'anthropic-beta': 'oauth-2025-04-20',
Authorization: `Bearer ${oauthToken}`
},
baseURL: 'https://api.anthropic.com/v1',
apiKey: ''
}
} catch (error) {
logger.error('Failed to get Anthropic OAuth token', error as Error)
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
}
}
break
}
case 'cherryai': {
// Create a signed fetch wrapper for cherryai
const baseFetch = net.fetch as typeof globalThis.fetch
config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
if (!options?.body) {
return baseFetch(url, options)
}
const signature = cherryaiGenerateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: JSON.parse(options.body as string)
})
return baseFetch(url, {
...options,
headers: {
...(options.headers as Record<string, string>),
...signature
}
})
}
break
}
}
return config
}
function mapAnthropicThinkToAISdkProviderOptions(
provider: Provider,
config: MessageCreateParams['thinking']
): ProviderOptions | undefined {
if (!config) return undefined
if (isAnthropicProvider(provider)) {
return {
anthropic: {
...mapToAnthropicProviderOptions(config)
}
}
}
if (isGeminiProvider(provider)) {
return {
google: {
...mapToGeminiProviderOptions(config)
}
}
}
if (isOpenAIProvider(provider)) {
return {
openai: {
...mapToOpenAIProviderOptions(config)
}
}
}
if (provider.id === SystemProviderIds.openrouter) {
return {
openrouter: {
...mapToOpenRouterProviderOptions(config)
}
}
}
return undefined
}
function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
return {
thinking: {
type: config.type,
budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
}
}
}
function mapToGeminiProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): GoogleGenerativeAIProviderOptions {
return {
thinkingConfig: {
thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
includeThoughts: config.type === 'enabled'
}
}
}
function mapToOpenAIProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): OpenAIResponsesProviderOptions {
return {
reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
}
}
function mapToOpenRouterProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): OpenRouterProviderOptions {
return {
reasoning: {
enabled: config.type === 'enabled',
effort: 'high'
}
}
}
/**
* Core stream execution function - single source of truth for AI SDK calls
*/
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
// Convert provider config to AI SDK config
let sdkConfig = providerToAiSdkConfig(provider, modelId)
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
// Create provider instance and get language model
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
const baseModel = aiSdkProvider.languageModel(modelId)
// Apply middlewares if present
const model =
middlewares.length > 0 && typeof baseModel === 'object'
? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
: baseModel
// Create executor with plugins
const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
// Convert messages and tools
const coreMessages = convertAnthropicToAiMessages(params)
const tools = convertAnthropicToolsToAiSdk(params.tools)
// Create the adapter
const adapter = new AiSdkToAnthropicSSE({
model: `${provider.id}:${modelId}`,
onEvent: onEvent || (() => {})
})
const result = await executor.streamText({
model,
messages: coreMessages,
// FIXME: Claude Code传入的maxToken会超出有些模型限制需做特殊处理可能在v2好修复一点现在维护的成本有点高
// 已知: 豆包
maxOutputTokens: params.max_tokens,
temperature: params.temperature,
topP: params.top_p,
topK: params.top_k,
stopSequences: params.stop_sequences,
stopWhen: stepCountIs(100),
headers: defaultAppHeaders(),
tools,
providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking)
})
// Process the stream through the adapter
await adapter.processStream(result.fullStream)
return adapter
}
/**
* Stream a message request using AI SDK executor and convert to Anthropic SSE format
*/
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
logger.info('Starting unified message stream', {
providerId: provider.id,
providerType: provider.type,
modelId,
stream: params.stream,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
response.setHeader('Content-Type', 'text/event-stream')
response.setHeader('Cache-Control', 'no-cache')
response.setHeader('Connection', 'keep-alive')
response.setHeader('X-Accel-Buffering', 'no')
await executeStream({
provider,
modelId,
params,
middlewares,
plugins,
onEvent: (event) => {
logger.silly('Streaming event', { eventType: event.type })
const sseData = formatSSEEvent(event)
response.write(sseData)
}
})
// Send done marker
response.write(formatSSEDone())
response.end()
logger.info('Unified message stream completed', { providerId: provider.id, modelId })
onComplete?.()
} catch (error) {
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
onError?.(error)
throw error
}
}
/**
* Generate a non-streaming message response
*
* Uses simulateStreamingMiddleware to reuse the same streaming logic,
* similar to renderer's ModernAiProvider pattern.
*/
export async function generateUnifiedMessage(
providerOrConfig: Provider | GenerateUnifiedMessageConfig,
modelId?: string,
params?: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
// Support both old signature and new config-based signature
let config: GenerateUnifiedMessageConfig
if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
config = providerOrConfig
} else {
config = {
provider: providerOrConfig as Provider,
modelId: modelId!,
params: params!
}
}
const { provider, middlewares = [], plugins = [] } = config
logger.info('Starting unified message generation', {
providerId: provider.id,
providerType: provider.type,
modelId: config.modelId,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
// Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
const adapter = await executeStream({
provider,
modelId: config.modelId,
params: config.params,
middlewares: allMiddlewares,
plugins
})
const finalResponse = adapter.buildNonStreamingResponse()
logger.info('Unified message generation completed', {
providerId: provider.id,
modelId: config.modelId
})
return finalResponse
} catch (error) {
logger.error('Error in unified message generation', error as Error, {
providerId: provider.id,
modelId: config.modelId
})
throw error
}
}
export default {
streamUnifiedMessages,
generateUnifiedMessage
}

View File

@ -1,7 +1,7 @@
import { CacheService } from '@main/services/CacheService' import { CacheService } from '@main/services/CacheService'
import { loggerService } from '@main/services/LoggerService' import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService' import { reduxService } from '@main/services/ReduxService'
import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import type { ApiModel, Model, Provider } from '@types' import type { ApiModel, Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils') const logger = loggerService.withContext('ApiServerUtils')
@ -28,10 +28,9 @@ export async function getAvailableProviders(): Promise<Provider[]> {
return [] return []
} }
// Support OpenAI and Anthropic type providers for API server // Support all provider types that AI SDK can handle
const supportedProviders = providers.filter( // The unified-messages service uses AI SDK which supports many providers
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic') const supportedProviders = providers.filter((p: Provider) => p.enabled)
)
// Cache the filtered results // Cache the filtered results
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL) CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
@ -160,7 +159,7 @@ export async function validateModelId(model: string): Promise<{
valid: false, valid: false,
error: { error: {
type: 'provider_not_found', type: 'provider_not_found',
message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI providers are currently supported.`, message: `Provider '${providerId}' not found or not enabled.`,
code: 'provider_not_found' code: 'provider_not_found'
} }
} }
@ -262,14 +261,8 @@ export function validateProvider(provider: Provider): boolean {
return false return false
} }
// Support OpenAI and Anthropic type providers // AI SDK supports many provider types, no longer need to filter by type
if (provider.type !== 'openai' && provider.type !== 'anthropic') { // The unified-messages service handles all supported types
logger.debug('Provider type not supported', {
providerId: provider.id,
providerType: provider.type
})
return false
}
return true return true
} catch (error: any) { } catch (error: any) {
@ -290,8 +283,39 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model
return (m: Model) => m.id.includes('claude') return (m: Model) => m.id.includes('claude')
case 'silicon': case 'silicon':
return (m: Model) => isSiliconAnthropicCompatibleModel(m.id) return (m: Model) => isSiliconAnthropicCompatibleModel(m.id)
case 'ppio':
return (m: Model) => isPpioAnthropicCompatibleModel(m.id)
default: default:
// allow all models when checker not configured // allow all models when checker not configured
return () => true return () => true
} }
} }
/**
* Check if a specific model is compatible with Anthropic API for a given provider.
*
* This is used for fine-grained routing decisions at the model level.
* For aggregated providers (like Silicon), only certain models support the Anthropic API endpoint.
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if the model supports Anthropic API endpoint
*/
export function isModelAnthropicCompatible(provider: Provider, modelId: string): boolean {
const checker = getProviderAnthropicModelChecker(provider.id)
const model = provider.models?.find((m) => m.id === modelId)
if (model) {
return checker(model)
}
const minimalModel: Model = {
id: modelId,
name: modelId,
provider: provider.id,
group: ''
}
return checker(minimalModel)
}

View File

@ -87,6 +87,7 @@ export class ClaudeStreamState {
private pendingUsage: PendingUsageState = {} private pendingUsage: PendingUsageState = {}
private pendingToolCalls = new Map<string, PendingToolCall>() private pendingToolCalls = new Map<string, PendingToolCall>()
private stepActive = false private stepActive = false
private _streamFinished = false
constructor(options: ClaudeStreamStateOptions) { constructor(options: ClaudeStreamStateOptions) {
this.logger = loggerService.withContext('ClaudeStreamState') this.logger = loggerService.withContext('ClaudeStreamState')
@ -289,6 +290,16 @@ export class ClaudeStreamState {
getNamespacedToolCallId(rawToolCallId: string): string { getNamespacedToolCallId(rawToolCallId: string): string {
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId) return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
} }
/** Marks the stream as finished (either completed or errored). */
markFinished(): void {
this._streamFinished = true
}
/** Returns true if the stream has already emitted a terminal event. */
isFinished(): boolean {
return this._streamFinished
}
} }
export type { PendingToolCall } export type { PendingToolCall }

View File

@ -87,18 +87,14 @@ class ClaudeCodeService implements AgentServiceInterface {
}) })
return aiStream return aiStream
} }
if ( // Validate provider has required configuration
(modelInfo.provider?.type !== 'anthropic' && // Note: We no longer restrict to anthropic type only - the API Server's unified adapter
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) || // handles format conversion for any provider type (OpenAI, Gemini, etc.)
modelInfo.provider.apiKey === '' if (!modelInfo.provider?.apiKey) {
) { logger.error('Provider API key is missing', { modelInfo })
logger.error('Anthropic provider configuration is missing', {
modelInfo
})
aiStream.emit('data', { aiStream.emit('data', {
type: 'error', type: 'error',
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`) error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`)
}) })
return aiStream return aiStream
} }
@ -112,15 +108,14 @@ class ClaudeCodeService implements AgentServiceInterface {
// Auto-discover Git Bash path on Windows (already logs internally) // Auto-discover Git Bash path on Windows (already logs internally)
const customGitBashPath = isWin ? autoDiscoverGitBash() : null const customGitBashPath = isWin ? autoDiscoverGitBash() : null
// Route through local API Server which handles format conversion via unified adapter
// This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
// The API Server converts AI SDK responses to Anthropic SSE format transparently
const env = { const env = {
...loginShellEnvWithoutProxies, ...loginShellEnvWithoutProxies,
// TODO: fix the proxy api server ANTHROPIC_API_KEY: apiConfig.apiKey,
// ANTHROPIC_API_KEY: apiConfig.apiKey, ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
// ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey, ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
// ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
ANTHROPIC_MODEL: modelInfo.modelId, ANTHROPIC_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,
@ -545,6 +540,19 @@ class ClaudeCodeService implements AgentServiceInterface {
return return
} }
// Skip emitting error if stream already finished (error was handled via result message)
if (streamState.isFinished()) {
logger.debug('SDK process exited after stream finished, skipping duplicate error event', {
duration,
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj)
})
// Still emit complete to signal stream end
stream.emit('data', {
type: 'complete'
})
return
}
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj)) errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
const errorMessage = errorChunks.join('\n\n') const errorMessage = errorChunks.join('\n\n')
logger.error('SDK query failed', { logger.error('SDK query failed', {

View File

@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state:
case 'system': case 'system':
return handleSystemMessage(sdkMessage) return handleSystemMessage(sdkMessage)
case 'result': case 'result':
return handleResultMessage(sdkMessage) return handleResultMessage(sdkMessage, state)
default: default:
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type }) logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
return [] return []
@ -193,6 +193,30 @@ function handleAssistantMessage(
} }
break break
} }
case 'thinking':
case 'redacted_thinking': {
const thinkingText = block.type === 'thinking' ? block.thinking : block.data
if (thinkingText) {
const id = generateMessageId()
chunks.push({
type: 'reasoning-start',
id,
providerMetadata
})
chunks.push({
type: 'reasoning-delta',
id,
text: thinkingText,
providerMetadata
})
chunks.push({
type: 'reasoning-end',
id,
providerMetadata
})
}
break
}
case 'tool_use': case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks) handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break break
@ -445,7 +469,11 @@ function handleStreamEvent(
case 'content_block_stop': { case 'content_block_stop': {
const block = state.closeBlock(event.index) const block = state.closeBlock(event.index)
if (!block) { if (!block) {
logger.warn('Received content_block_stop for unknown index', { index: event.index }) // Some providers (e.g., Gemini) send content via assistant message before stream events,
// so the block may not exist in state. This is expected behavior, not an error.
logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', {
index: event.index
})
break break
} }
@ -679,7 +707,13 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
* Successful runs yield a `finish` frame with aggregated usage metrics, while * Successful runs yield a `finish` frame with aggregated usage metrics, while
* failures are surfaced as `error` frames. * failures are surfaced as `error` frames.
*/ */
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] { function handleResultMessage(
message: Extract<SDKMessage, { type: 'result' }>,
state: ClaudeStreamState
): AgentStreamPart[] {
// Mark stream as finished to prevent duplicate error events when SDK process exits
state.markFinished()
const chunks: AgentStreamPart[] = [] const chunks: AgentStreamPart[] = []
let usage: LanguageModelUsage | undefined let usage: LanguageModelUsage | undefined
@ -691,7 +725,6 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
} }
} }
if (message.subtype === 'success') {
chunks.push({ chunks.push({
type: 'finish', type: 'finish',
totalUsage: usage ?? emptyUsage, totalUsage: usage ?? emptyUsage,
@ -704,13 +737,21 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
raw: message raw: message
} }
} as AgentStreamPart) } as AgentStreamPart)
} else { if (message.subtype !== 'success') {
chunks.push({ chunks.push({
type: 'error', type: 'error',
error: { error: {
message: `${message.subtype}: Process failed after ${message.num_turns} turns` message: `${message.subtype}: Process failed after ${message.num_turns} turns`
} }
} as AgentStreamPart) } as AgentStreamPart)
} else {
if (message.is_error) {
const errorMatch = message.result.match(/\{.*\}/)
if (errorMatch) {
const errorDetail = JSON.parse(errorMatch[0])
chunks.push(errorDetail)
}
}
} }
return chunks return chunks
} }

View File

@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient {
this.anthropicVertexClient = new AnthropicVertexClient(provider) this.anthropicVertexClient = new AnthropicVertexClient(provider)
// 如果传入的是普通 Provider转换为 VertexProvider // 如果传入的是普通 Provider转换为 VertexProvider
if (isVertexProvider(provider)) { if (isVertexProvider(provider)) {
this.vertexProvider = provider this.vertexProvider = provider as VertexProvider
} else { } else {
this.vertexProvider = createVertexProvider(provider) this.vertexProvider = createVertexProvider(provider)
} }

View File

@ -5,6 +5,7 @@ import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types' import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk' import type { Chunk } from '@renderer/types/chunk'
import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider' import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/ai-sdk-middlewares'
import type { LanguageModelMiddleware } from 'ai' import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
@ -12,9 +13,7 @@ import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware' import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware' import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder') const logger = loggerService.withContext('AiSdkMiddlewareBuilder')

View File

@ -1,50 +0,0 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware filter redacted block
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}

View File

@ -1,36 +0,0 @@
import type { LanguageModelMiddleware } from 'ai'
/**
* skip Gemini Thought Signature Middleware
* Gemini3
* Due to the complexity of multi-model client requests (which can switch to other models mid-process),
* it was decided to add a skip for all Gemini3 thinking signatures via middleware.
* @param aiSdkId AI SDK Provider ID
* @returns LanguageModelMiddleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}

View File

@ -37,7 +37,7 @@ vi.mock('@renderer/utils/api', () => ({
if (isSupportedAPIVersion === false) { if (isSupportedAPIVersion === false) {
return host // Return host as-is when isSupportedAPIVersion is false return host // Return host as-is when isSupportedAPIVersion is false
} }
return `${host}/v1` // Default behavior when isSupportedAPIVersion is true return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
}), }),
routeToEndpoint: vi.fn((host) => ({ routeToEndpoint: vi.fn((host) => ({
baseURL: host, baseURL: host,
@ -46,6 +46,20 @@ vi.mock('@renderer/utils/api', () => ({
isWithTrailingSharp: vi.fn((host) => host?.endsWith('#') || false) isWithTrailingSharp: vi.fn((host) => host?.endsWith('#') || false)
})) }))
// Also mock @shared/utils/url since formatProviderApiHost uses it directly
vi.mock('@shared/utils/url', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => {
if (isSupportedAPIVersion === false) {
return host || '' // Return host as-is when isSupportedAPIVersion is false
}
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
})
}
})
vi.mock('@renderer/utils/provider', async (importOriginal) => { vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any const actual = (await importOriginal()) as any
return { return {
@ -78,8 +92,8 @@ vi.mock('@renderer/services/AssistantService', () => ({
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types' import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { formatApiHost } from '@shared/utils/url'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
@ -96,6 +110,31 @@ const createWindowKeyv = () => {
} }
} }
/**
* mock state
*/
const createDefaultMockState = (overrides?: {
includeUsage?: boolean | undefined
copilotHeaders?: Record<string, string>
}) => ({
copilot: { defaultHeaders: overrides?.copilotHeaders ?? {} },
settings: {
openAI: {
streamOptions: {
includeUsage: overrides?.includeUsage
}
}
},
llm: {
settings: {
vertexai: {
projectId: '',
location: ''
}
}
}
})
const createCopilotProvider = (): Provider => ({ const createCopilotProvider = (): Provider => ({
id: 'copilot', id: 'copilot',
type: 'openai', type: 'openai',
@ -150,16 +189,7 @@ describe('Copilot responses routing', () => {
...(globalThis as any).window, ...(globalThis as any).window,
keyv: createWindowKeyv() keyv: createWindowKeyv()
} }
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState())
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: undefined
}
}
}
})
}) })
it('detects official GPT-5 Codex identifiers case-insensitively', () => { it('detects official GPT-5 Codex identifiers case-insensitively', () => {
@ -195,16 +225,7 @@ describe('CherryAI provider configuration', () => {
...(globalThis as any).window, ...(globalThis as any).window,
keyv: createWindowKeyv() keyv: createWindowKeyv()
} }
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState())
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: undefined
}
}
}
})
vi.clearAllMocks() vi.clearAllMocks()
}) })
@ -276,16 +297,7 @@ describe('Perplexity provider configuration', () => {
...(globalThis as any).window, ...(globalThis as any).window,
keyv: createWindowKeyv() keyv: createWindowKeyv()
} }
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState())
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: undefined
}
}
}
})
vi.clearAllMocks() vi.clearAllMocks()
}) })
@ -360,6 +372,7 @@ describe('Stream options includeUsage configuration', () => {
...(globalThis as any).window, ...(globalThis as any).window,
keyv: createWindowKeyv() keyv: createWindowKeyv()
} }
mockGetState.mockReturnValue(createDefaultMockState())
vi.clearAllMocks() vi.clearAllMocks()
}) })
@ -374,16 +387,7 @@ describe('Stream options includeUsage configuration', () => {
}) })
it('uses includeUsage from settings when undefined', () => { it('uses includeUsage from settings when undefined', () => {
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: undefined }))
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: undefined
}
}
}
})
const provider = createOpenAIProvider() const provider = createOpenAIProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
@ -392,16 +396,7 @@ describe('Stream options includeUsage configuration', () => {
}) })
it('uses includeUsage from settings when set to true', () => { it('uses includeUsage from settings when set to true', () => {
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true }))
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: true
}
}
}
})
const provider = createOpenAIProvider() const provider = createOpenAIProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
@ -410,16 +405,7 @@ describe('Stream options includeUsage configuration', () => {
}) })
it('uses includeUsage from settings when set to false', () => { it('uses includeUsage from settings when set to false', () => {
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: false }))
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: false
}
}
}
})
const provider = createOpenAIProvider() const provider = createOpenAIProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai'))
@ -428,16 +414,7 @@ describe('Stream options includeUsage configuration', () => {
}) })
it('respects includeUsage setting for non-supporting providers', () => { it('respects includeUsage setting for non-supporting providers', () => {
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true }))
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: true
}
}
}
})
const testProvider: Provider = { const testProvider: Provider = {
id: 'test', id: 'test',
@ -459,16 +436,7 @@ describe('Stream options includeUsage configuration', () => {
}) })
it('uses includeUsage from settings for Copilot provider when set to false', () => { it('uses includeUsage from settings for Copilot provider when set to false', () => {
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: false }))
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: false
}
}
}
})
const provider = createCopilotProvider() const provider = createCopilotProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
@ -478,16 +446,7 @@ describe('Stream options includeUsage configuration', () => {
}) })
it('uses includeUsage from settings for Copilot provider when set to true', () => { it('uses includeUsage from settings for Copilot provider when set to true', () => {
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true }))
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: true
}
}
}
})
const provider = createCopilotProvider() const provider = createCopilotProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))
@ -497,16 +456,7 @@ describe('Stream options includeUsage configuration', () => {
}) })
it('uses includeUsage from settings for Copilot provider when undefined', () => { it('uses includeUsage from settings for Copilot provider when undefined', () => {
mockGetState.mockReturnValue({ mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: undefined }))
copilot: { defaultHeaders: {} },
settings: {
openAI: {
streamOptions: {
includeUsage: undefined
}
}
}
})
const provider = createCopilotProvider() const provider = createCopilotProvider()
const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot'))

View File

@ -1,22 +0,0 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)

View File

@ -1,22 +0,0 @@
import type { Model, Provider } from '@renderer/types'
import type { RuleSet } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* Provider
* @param ruleSet
* @param model
* @param provider provider对象
* @returns provider对象
*/
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return ruleSet.fallbackRule(provider)
}

View File

@ -1,3 +1,7 @@
export { aihubmixProviderCreator } from './aihubmix' // Re-export from shared config
export { newApiResolverCreator } from './newApi' export {
export { vertexAnthropicProviderCreator } from './vertext-anthropic' aihubmixProviderCreator,
azureAnthropicProviderCreator,
newApiResolverCreator,
vertexAnthropicProviderCreator
} from '@shared/provider/config'

View File

@ -1,9 +0,0 @@
import type { Model, Provider } from '@renderer/types'
export interface RuleSet {
rules: Array<{
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}>
fallbackRule: (provider: Provider) => Provider
}

View File

@ -1,19 +0,0 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)

View File

@ -1,25 +1 @@
import type { Model } from '@renderer/types' export { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '@shared/provider/constant'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex']
export function isCopilotResponsesModel(model: Model): boolean {
const normalizedId = model.id?.trim().toLowerCase()
const normalizedName = model.name?.trim().toLowerCase()
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target || normalizedName === target)
}

View File

@ -1,8 +1,7 @@
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import type { Provider } from '@renderer/types' import type { Provider } from '@renderer/types'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider' import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider'
import type { Provider as AiSdkProvider } from 'ai' import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types' import type { AiSdkConfig } from '../types'
@ -22,69 +21,12 @@ const logger = loggerService.withContext('ProviderFactory')
} }
})() })()
/**
* Provider映射表
* Cherry Studio特有的provider ID到AI SDK标准ID的映射
*/
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* provider标识符
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查AiCore是否支持包括别名支持
if (hasProviderConfigByAlias(identifier)) {
// 解析为真实的Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/** /**
* AI SDK Provider ID * AI SDK Provider ID
* * Uses shared implementation with renderer-specific config checker
* TODO: 整理函数逻辑
*/ */
export function getAiSdkProviderId(provider: Provider): string { export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id return sharedGetAiSdkProviderId(provider)
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
} else {
return 'azure'
}
}
if (resolvedFromId) {
return resolvedFromId
}
// 2. 尝试解析provider.type
// 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. 最后的fallback使用provider本身的id
return provider.id
} }
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> { export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {

View File

@ -1,4 +1,4 @@
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider' import { hasProviderConfig } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import { import {
getAwsBedrockAccessKeyId, getAwsBedrockAccessKeyId,
@ -9,58 +9,65 @@ import {
} from '@renderer/hooks/useAwsBedrock' } from '@renderer/hooks/useAwsBedrock'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import { getProviderById } from '@renderer/services/ProviderService'
import store from '@renderer/store' import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types' import { isSystemProvider, type Model, type Provider } from '@renderer/types'
import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes' import { isSupportStreamOptionsProvider } from '@renderer/utils/provider'
import { import {
formatApiHost, type AiSdkConfigContext,
formatAzureOpenAIApiHost, formatProviderApiHost as sharedFormatProviderApiHost,
formatOllamaApiHost, type ProviderFormatContext,
formatVertexApiHost, providerToAiSdkConfig as sharedProviderToAiSdkConfig,
isWithTrailingSharp, resolveActualProvider
routeToEndpoint } from '@shared/provider'
} from '@renderer/utils/api' import { cloneDeep } from 'lodash'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOllamaProvider,
isPerplexityProvider,
isSupportStreamOptionsProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { defaultAppHeaders } from '@shared/utils'
import { cloneDeep, isEmpty } from 'lodash'
import type { AiSdkConfig } from '../types' import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants' import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
/** /**
* provider的转换逻辑 * Renderer-specific context for providerToAiSdkConfig
* Provides implementations using browser APIs, store, and hooks
*/ */
function handleSpecialProviders(model: Model, provider: Provider): Provider { function createRendererSdkContext(model: Model): AiSdkConfigContext {
if (isNewApiProvider(provider)) { return {
return newApiResolverCreator(model, provider) isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model),
isSupportStreamOptionsProvider: (provider) => isSupportStreamOptionsProvider(provider as Provider),
getIncludeUsageSetting: () => store.getState().settings.openAI?.streamOptions?.includeUsage,
getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS,
getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {},
getAwsBedrockConfig: () => {
const authType = getAwsBedrockAuthType()
return {
authType,
region: getAwsBedrockRegion(),
apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined,
accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined,
secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined
}
},
getVertexConfig: (provider) => {
if (!isVertexAIConfigured()) {
return undefined
}
return createVertexProvider(provider as Provider)
},
getEndpointType: () => model.endpoint_type
}
} }
if (isSystemProvider(provider)) { /**
if (provider.id === 'aihubmix') { * AISdk的BaseURL格式
return aihubmixProviderCreator(model, provider) * Uses shared implementation with renderer-specific context
} */
if (provider.id === 'vertexai') { function getRendererFormatContext(): ProviderFormatContext {
return vertexAnthropicProviderCreator(model, provider) const vertexSettings = store.getState().llm.settings.vertexai
return {
vertex: {
project: vertexSettings.projectId || 'default-project',
location: vertexSettings.location || 'us-central1'
} }
} }
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
} }
/** /**
@ -70,38 +77,8 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
* @param provider - The provider whose API host is to be formatted. * @param provider - The provider whose API host is to be formatted.
* @returns A new provider instance with the formatted API host. * @returns A new provider instance with the formatted API host.
*/ */
export function formatProviderApiHost(provider: Provider): Provider { function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider } return sharedFormatProviderApiHost(provider, getRendererFormatContext())
const appendApiVersion = !isWithTrailingSharp(provider.apiHost)
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion)
}
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
formatted.apiHost = formatApiHost(baseHost, appendApiVersion)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isOllamaProvider(formatted)) {
formatted.apiHost = formatOllamaApiHost(formatted.apiHost)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion)
}
return formatted
} }
/** /**
@ -132,7 +109,9 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?:
// Apply transformations in order // Apply transformations in order
if (model) { if (model) {
adaptedProvider = handleSpecialProviders(model, adaptedProvider) adaptedProvider = resolveActualProvider(adaptedProvider, model, {
isSystemProvider
})
} }
adaptedProvider = formatProviderApiHost(adaptedProvider) adaptedProvider = formatProviderApiHost(adaptedProvider)
@ -141,148 +120,11 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?:
/** /**
* Provider AI SDK * Provider AI SDK
* * Uses shared implementation with renderer-specific context
*/ */
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig { export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider) const context = createRendererSdkContext(model)
return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig
// 构建基础配置
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
apiKey: actualProvider.apiKey
}
let includeUsage: OpenAICompletionsStreamOptions['include_usage'] = undefined
if (isSupportStreamOptionsProvider(actualProvider)) {
includeUsage = store.getState().settings.openAI?.streamOptions?.includeUsage
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
if (isCopilotProvider) {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
headers: {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders,
...actualProvider.extra_headers
},
name: actualProvider.id,
includeUsage
})
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
if (isOllamaProvider(actualProvider)) {
return {
providerId: 'ollama',
options: {
...baseConfig,
headers: {
...actualProvider.extra_headers,
Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined
}
}
}
}
// 处理OpenAI模式
const extraOptions: any = {}
extraOptions.endpoint = endpoint
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
extraOptions.mode = 'chat'
}
extraOptions.headers = {
...defaultAppHeaders(),
...actualProvider.extra_headers
}
if (aiSdkProviderId === 'openai') {
extraOptions.headers['X-Api-Key'] = baseConfig.apiKey
}
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
if (isAzureOpenAIProvider(actualProvider)) {
const apiVersion = actualProvider.apiVersion?.trim()
if (apiVersion) {
extraOptions.apiVersion = apiVersion
if (!['preview', 'v1'].includes(apiVersion)) {
extraOptions.useDeploymentBasedUrls = true
}
}
}
// bedrock
if (aiSdkProviderId === 'bedrock') {
const authType = getAwsBedrockAuthType()
extraOptions.region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions.apiKey = getAwsBedrockApiKey()
} else {
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
}
// google-vertex
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions.project = project
extraOptions.location = location
extraOptions.googleCredentials = {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
// CherryIN API Host
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
if (cherryinProvider) {
extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1'
extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models'
}
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// 否则fallback到openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions,
includeUsage
}
}
} }
/** /**
@ -325,13 +167,13 @@ export async function prepareSpecialProviderConfig(
break break
} }
case 'cherryai': { case 'cherryai': {
config.options.fetch = async (url, options) => { config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => {
// 在这里对最终参数进行签名 // 在这里对最终参数进行签名
const signature = await window.api.cherryai.generateSignature({ const signature = await window.api.cherryai.generateSignature({
method: 'POST', method: 'POST',
path: '/chat/completions', path: '/chat/completions',
query: '', query: '',
body: JSON.parse(options.body) body: JSON.parse(options.body as string)
}) })
return fetch(url, { return fetch(url, {
...options, ...options,

View File

@ -1,124 +1,13 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import * as z from 'zod' import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider'
const logger = loggerService.withContext('ProviderConfigs') const logger = loggerService.withContext('ProviderConfigs')
/** export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS
* Provider配置定义
* AI Providers
*/
export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'gateway',
name: 'Vercel AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['ai-gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
},
{
id: 'ollama',
name: 'Ollama',
import: () => import('ollama-ai-provider-v2'),
creatorFunctionName: 'createOllama',
supportsImageGeneration: false
}
] as const
export const registeredNewProviderIds = NEW_PROVIDER_CONFIGS.map((config) => config.id)
export const registeredNewProviderIdSchema = z.enum(registeredNewProviderIds)
/**
* Providers
* 使aiCore的动态注册功能
*/
export async function initializeNewProviders(): Promise<void> { export async function initializeNewProviders(): Promise<void> {
try { initializeSharedProviders({
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS) warn: (message) => logger.warn(message),
if (successCount < NEW_PROVIDER_CONFIGS.length) { error: (message, error) => logger.error(message, error)
logger.warn('Some providers failed to register. Check previous error logs.') })
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
} }

View File

@ -6,6 +6,29 @@ import { isDeepSeekHybridInferenceModel } from '../reasoning'
import { isFunctionCallingModel } from '../tooluse' import { isFunctionCallingModel } from '../tooluse'
import { isPureGenerateImageModel, isTextToImageModel } from '../vision' import { isPureGenerateImageModel, isTextToImageModel } from '../vision'
vi.mock('@renderer/i18n', () => ({
__esModule: true,
default: {
t: vi.fn((key: string) => key)
}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({
id: 'openai',
type: 'openai',
name: 'OpenAI',
models: []
}),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/hooks/useStore', () => ({ vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => []) getStoreProviders: vi.fn(() => [])
})) }))

View File

@ -15,6 +15,7 @@ import {
isSupportVerbosityModel isSupportVerbosityModel
} from '../openai' } from '../openai'
import { isQwenMTModel } from '../qwen' import { isQwenMTModel } from '../qwen'
import { isFunctionCallingModel } from '../tooluse'
import { import {
agentModelFilter, agentModelFilter,
getModelSupportedVerbosity, getModelSupportedVerbosity,
@ -71,6 +72,29 @@ vi.mock('@renderer/store/settings', () => {
) )
}) })
vi.mock('@renderer/i18n', () => ({
__esModule: true,
default: {
t: vi.fn((key: string) => key)
}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({
id: 'openai',
type: 'openai',
name: 'OpenAI',
models: []
}),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/hooks/useSettings', () => ({ vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})), useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })), useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
@ -101,6 +125,10 @@ vi.mock('../websearch', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn() isOpenAIWebSearchChatCompletionOnlyModel: vi.fn()
})) }))
vi.mock('../tooluse', () => ({
isFunctionCallingModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({ const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o', id: 'gpt-4o',
name: 'gpt-4o', name: 'gpt-4o',
@ -116,6 +144,7 @@ const textToImageMock = vi.mocked(isTextToImageModel)
const generateImageMock = vi.mocked(isGenerateImageModel) const generateImageMock = vi.mocked(isGenerateImageModel)
const reasoningMock = vi.mocked(isOpenAIReasoningModel) const reasoningMock = vi.mocked(isOpenAIReasoningModel)
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel) const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
const isFunctionCallingModelMock = vi.mocked(isFunctionCallingModel)
describe('model utils', () => { describe('model utils', () => {
beforeEach(() => { beforeEach(() => {
@ -124,9 +153,10 @@ describe('model utils', () => {
rerankMock.mockReturnValue(false) rerankMock.mockReturnValue(false)
visionMock.mockReturnValue(true) visionMock.mockReturnValue(true)
textToImageMock.mockReturnValue(false) textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(true) generateImageMock.mockReturnValue(false)
reasoningMock.mockReturnValue(false) reasoningMock.mockReturnValue(false)
openAIWebSearchOnlyMock.mockReturnValue(false) openAIWebSearchOnlyMock.mockReturnValue(false)
isFunctionCallingModelMock.mockReturnValue(true)
}) })
describe('OpenAI model detection', () => { describe('OpenAI model detection', () => {
@ -598,6 +628,7 @@ describe('model utils', () => {
describe('isGenerateImageModels', () => { describe('isGenerateImageModels', () => {
it('returns true when all models support image generation', () => { it('returns true when all models support image generation', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })] const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
generateImageMock.mockReturnValue(true)
expect(isGenerateImageModels(models)).toBe(true) expect(isGenerateImageModels(models)).toBe(true)
}) })
@ -636,12 +667,22 @@ describe('model utils', () => {
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false) expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
}) })
it('filters out non-function-call models', () => {
rerankMock.mockReturnValue(false)
isFunctionCallingModelMock.mockReturnValueOnce(false)
expect(agentModelFilter(createModel({ id: 'DeepSeek R1' }))).toBe(false)
})
it('filters out text-to-image models', () => { it('filters out text-to-image models', () => {
rerankMock.mockReturnValue(false) rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true) textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false) expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
}) })
}) })
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'dall-e-3' }))).toBe(false)
}) })
describe('Temperature limits', () => { describe('Temperature limits', () => {

View File

@ -1,6 +1,8 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types' import type { Model } from '@renderer/types'
import { isSystemProviderId } from '@renderer/types' import { isSystemProviderId } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils' import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isAzureOpenAIProvider } from '@shared/provider'
import { isEmbeddingModel, isRerankModel } from './embedding' import { isEmbeddingModel, isRerankModel } from './embedding'
import { isDeepSeekHybridInferenceModel } from './reasoning' import { isDeepSeekHybridInferenceModel } from './reasoning'
@ -55,6 +57,13 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
'i' 'i'
) )
const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
'(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
'DeepSeek-(?:R1|V3)',
'Codestral-2501'
]
export function isFunctionCallingModel(model?: Model): boolean { export function isFunctionCallingModel(model?: Model): boolean {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) { if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false return false
@ -70,6 +79,15 @@ export function isFunctionCallingModel(model?: Model): boolean {
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name) return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
} }
const provider = getProviderByModel(model)
if (isAzureOpenAIProvider(provider)) {
const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')
if (azureExcludedRegex.test(modelId)) {
return false
}
}
// 2025/08/26 百炼与火山引擎均不支持 v3.1 函数调用 // 2025/08/26 百炼与火山引擎均不支持 v3.1 函数调用
// 先默认支持 // 先默认支持
if (isDeepSeekHybridInferenceModel(model)) { if (isDeepSeekHybridInferenceModel(model)) {

View File

@ -1,5 +1,6 @@
import type OpenAI from '@cherrystudio/openai' import type OpenAI from '@cherrystudio/openai'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant } from '@renderer/types' import type { Assistant } from '@renderer/types'
import { type Model, SystemProviderIds } from '@renderer/types' import { type Model, SystemProviderIds } from '@renderer/types'
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes' import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
@ -17,6 +18,7 @@ import {
} from './openai' } from './openai'
import { isQwenMTModel } from './qwen' import { isQwenMTModel } from './qwen'
import { isClaude45ReasoningModel } from './reasoning' import { isClaude45ReasoningModel } from './reasoning'
import { isFunctionCallingModel } from './tooluse'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision' import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i') export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
@ -247,8 +249,21 @@ export const isGrokModel = (model: Model) => {
// zhipu 视觉推理模型用这组 special token 标记推理结果 // zhipu 视觉推理模型用这组 special token 标记推理结果
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
// TODO: 支持提示词模式的工具调用
export const agentModelFilter = (model: Model): boolean => { export const agentModelFilter = (model: Model): boolean => {
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) const provider = getProviderByModel(model)
// 需要适配,且容易超出限额
if (provider.id === SystemProviderIds.copilot) {
return false
}
return (
!isEmbeddingModel(model) &&
!isRerankModel(model) &&
!isTextToImageModel(model) &&
!isGenerateImageModel(model) &&
isFunctionCallingModel(model)
)
} }
export const isMaxTemperatureOneModel = (model: Model): boolean => { export const isMaxTemperatureOneModel = (model: Model): boolean => {

View File

@ -17,7 +17,7 @@ import type { EndpointType, Model } from '@renderer/types'
import { getClaudeSupportedProviders } from '@renderer/utils/provider' import { getClaudeSupportedProviders } from '@renderer/utils/provider'
import type { TerminalConfig } from '@shared/config/constant' import type { TerminalConfig } from '@shared/config/constant'
import { codeTools, terminalApps } from '@shared/config/constant' import { codeTools, terminalApps } from '@shared/config/constant'
import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import { Alert, Avatar, Button, Checkbox, Input, Popover, Select, Space, Tooltip } from 'antd' import { Alert, Avatar, Button, Checkbox, Input, Popover, Select, Space, Tooltip } from 'antd'
import { ArrowUpRight, Download, FolderOpen, HelpCircle, Terminal, X } from 'lucide-react' import { ArrowUpRight, Download, FolderOpen, HelpCircle, Terminal, X } from 'lucide-react'
import type { FC } from 'react' import type { FC } from 'react'
@ -82,10 +82,12 @@ const CodeToolsPage: FC = () => {
if (m.supported_endpoint_types) { if (m.supported_endpoint_types) {
return m.supported_endpoint_types.includes('anthropic') return m.supported_endpoint_types.includes('anthropic')
} }
// Special handling for silicon provider: only specific models support Anthropic API
if (m.provider === 'silicon') { if (m.provider === 'silicon') {
return isSiliconAnthropicCompatibleModel(m.id) return isSiliconAnthropicCompatibleModel(m.id)
} }
if (m.provider === 'ppio') {
return isPpioAnthropicCompatibleModel(m.id)
}
return m.id.includes('claude') || CLAUDE_OFFICIAL_SUPPORTED_PROVIDERS.includes(m.provider) return m.id.includes('claude') || CLAUDE_OFFICIAL_SUPPORTED_PROVIDERS.includes(m.provider)
} }

View File

@ -23,6 +23,7 @@ import { abortCompletion } from '@renderer/utils/abortController'
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession' import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
import { getSendMessageShortcutLabel } from '@renderer/utils/input' import { getSendMessageShortcutLabel } from '@renderer/utils/input'
import { createMainTextBlock, createMessage } from '@renderer/utils/messageUtils/create' import { createMainTextBlock, createMessage } from '@renderer/utils/messageUtils/create'
import { parseModelId } from '@renderer/utils/model'
import { documentExts, imageExts, textExts } from '@shared/config/constant' import { documentExts, imageExts, textExts } from '@shared/config/constant'
import type { FC } from 'react' import type { FC } from 'react'
import React, { useCallback, useEffect, useMemo, useRef } from 'react' import React, { useCallback, useEffect, useMemo, useRef } from 'react'
@ -67,8 +68,9 @@ const AgentSessionInputbar: FC<Props> = ({ agentId, sessionId }) => {
if (!session) return null if (!session) return null
// Extract model info // Extract model info
const [providerId, actualModelId] = session.model?.split(':') ?? [undefined, undefined] // Use parseModelId to handle model IDs with colons (e.g., "openrouter:anthropic/claude:free")
const actualModel = actualModelId ? getModel(actualModelId, providerId) : undefined const parsed = parseModelId(session.model)
const actualModel = parsed ? getModel(parsed.modelId, parsed.providerId) : undefined
const model: Model | undefined = actualModel const model: Model | undefined = actualModel
? { ? {

View File

@ -81,7 +81,8 @@ const ANTHROPIC_COMPATIBLE_PROVIDER_IDS = [
SystemProviderIds.silicon, SystemProviderIds.silicon,
SystemProviderIds.qiniu, SystemProviderIds.qiniu,
SystemProviderIds.dmxapi, SystemProviderIds.dmxapi,
SystemProviderIds.mimo SystemProviderIds.mimo,
SystemProviderIds.ppio
] as const ] as const
type AnthropicCompatibleProviderId = (typeof ANTHROPIC_COMPATIBLE_PROVIDER_IDS)[number] type AnthropicCompatibleProviderId = (typeof ANTHROPIC_COMPATIBLE_PROVIDER_IDS)[number]

View File

@ -2945,6 +2945,11 @@ const migrateConfig = {
includeUsage: DEFAULT_STREAM_OPTIONS_INCLUDE_USAGE includeUsage: DEFAULT_STREAM_OPTIONS_INCLUDE_USAGE
} }
} }
state.llm.providers.forEach((provider) => {
if (provider.id === SystemProviderIds.ppio) {
provider.anthropicApiHost = 'https://api.ppinfra.com/anthropic'
}
})
logger.info('migrate 182 success') logger.info('migrate 182 success')
return state return state
} catch (error) { } catch (error) {

View File

@ -7,6 +7,7 @@ import type { CSSProperties } from 'react'
export * from './file' export * from './file'
export * from './note' export * from './note'
import type { MinimalModel } from '@shared/provider/types'
import * as z from 'zod' import * as z from 'zod'
import type { StreamTextParams } from './aiCoreTypes' import type { StreamTextParams } from './aiCoreTypes'
@ -274,7 +275,7 @@ export type ModelCapability = {
isUserSelected?: boolean isUserSelected?: boolean
} }
export type Model = { export type Model = MinimalModel & {
id: string id: string
provider: string provider: string
name: string name: string

View File

@ -1,25 +1,14 @@
import type OpenAI from '@cherrystudio/openai' import type OpenAI from '@cherrystudio/openai'
import type { MinimalProvider } from '@shared/provider'
import type { ProviderType, SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types'
import { isSystemProviderId, SystemProviderIds } from '@shared/provider/types'
import type { Model } from '@types' import type { Model } from '@types'
import * as z from 'zod'
import type { OpenAIVerbosity } from './aiCoreTypes' import type { OpenAIVerbosity } from './aiCoreTypes'
export const ProviderTypeSchema = z.enum([ export type { ProviderType } from '@shared/provider'
'openai', export type { SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types'
'openai-response', export { isSystemProviderId, ProviderTypeSchema, SystemProviderIds } from '@shared/provider/types'
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'gateway',
'ollama'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
// undefined is treated as supported, enabled by default // undefined is treated as supported, enabled by default
export type ProviderApiOptions = { export type ProviderApiOptions = {
@ -94,7 +83,7 @@ export function isAwsBedrockAuthType(type: string): type is AwsBedrockAuthType {
return Object.hasOwn(AwsBedrockAuthTypes, type) return Object.hasOwn(AwsBedrockAuthTypes, type)
} }
export type Provider = { export type Provider = MinimalProvider & {
id: string id: string
type: ProviderType type: ProviderType
name: string name: string
@ -129,142 +118,6 @@ export type Provider = {
extra_headers?: Record<string, string> extra_headers?: Record<string, string>
} }
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'gateway',
'cerebras',
'mimo'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
gateway: 'gateway',
cerebras: 'cerebras',
mimo: 'mimo'
} as const satisfies Record<SystemProviderId, SystemProviderId>
type SystemProviderIdTypeMap = typeof SystemProviderIds
export type SystemProvider = Provider & { export type SystemProvider = Provider & {
id: SystemProviderId id: SystemProviderId
isSystem: true isSystem: true

View File

@ -326,18 +326,7 @@ describe('api', () => {
}) })
it('uses global endpoint when location equals global', () => { it('uses global endpoint when location equals global', () => {
getStateMock.mockReturnValueOnce({ expect(formatVertexApiHost(createVertexProvider(''), 'global-project', 'global')).toBe(
llm: {
settings: {
vertexai: {
projectId: 'global-project',
location: 'global'
}
}
}
})
expect(formatVertexApiHost(createVertexProvider(''))).toBe(
'https://aiplatform.googleapis.com/v1/projects/global-project/locations/global' 'https://aiplatform.googleapis.com/v1/projects/global-project/locations/global'
) )
}) })

View File

@ -1,7 +1,7 @@
import type { Model, ModelTag } from '@renderer/types' import type { Model, ModelTag } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest' import { describe, expect, it, vi } from 'vitest'
import { getModelTags, isFreeModel } from '../model' import { getModelTags, isFreeModel, parseModelId } from '../model'
// Mock the model checking functions from @renderer/config/models // Mock the model checking functions from @renderer/config/models
vi.mock('@renderer/config/models', () => ({ vi.mock('@renderer/config/models', () => ({
@ -92,4 +92,85 @@ describe('model', () => {
expect(getModelTags(models_2)).toStrictEqual(expected_2) expect(getModelTags(models_2)).toStrictEqual(expected_2)
}) })
}) })
describe('parseModelId', () => {
it('should parse model identifiers with single colon', () => {
expect(parseModelId('anthropic:claude-3-sonnet')).toEqual({
providerId: 'anthropic',
modelId: 'claude-3-sonnet'
})
expect(parseModelId('openai:gpt-4')).toEqual({
providerId: 'openai',
modelId: 'gpt-4'
})
})
it('should parse model identifiers with multiple colons', () => {
expect(parseModelId('openrouter:anthropic/claude-3.5-sonnet:free')).toEqual({
providerId: 'openrouter',
modelId: 'anthropic/claude-3.5-sonnet:free'
})
expect(parseModelId('provider:model:suffix:extra')).toEqual({
providerId: 'provider',
modelId: 'model:suffix:extra'
})
})
it('should handle model identifiers without provider prefix', () => {
expect(parseModelId('claude-3-sonnet')).toEqual({
providerId: undefined,
modelId: 'claude-3-sonnet'
})
expect(parseModelId('gpt-4')).toEqual({
providerId: undefined,
modelId: 'gpt-4'
})
})
it('should return undefined for invalid inputs', () => {
expect(parseModelId(undefined)).toBeUndefined()
expect(parseModelId('')).toBeUndefined()
expect(parseModelId(' ')).toBeUndefined()
})
it('should handle edge cases with colons', () => {
// Colon at start - treat as modelId without provider
expect(parseModelId(':missing-provider')).toEqual({
providerId: undefined,
modelId: ':missing-provider'
})
// Colon at end - treat everything before as modelId
expect(parseModelId('missing-model:')).toEqual({
providerId: undefined,
modelId: 'missing-model'
})
// Only colon - treat as modelId without provider
expect(parseModelId(':')).toEqual({
providerId: undefined,
modelId: ':'
})
})
it('should handle edge cases', () => {
expect(parseModelId('a:b')).toEqual({
providerId: 'a',
modelId: 'b'
})
expect(parseModelId('provider:model-with-dashes')).toEqual({
providerId: 'provider',
modelId: 'model-with-dashes'
})
expect(parseModelId('provider:model/with/slashes')).toEqual({
providerId: 'provider',
modelId: 'model/with/slashes'
})
})
})
}) })

View File

@ -1,6 +1,20 @@
import store from '@renderer/store' export {
import type { VertexProvider } from '@renderer/types' formatApiHost,
import { trim } from 'lodash' formatAzureOpenAIApiHost,
formatOllamaApiHost,
formatVertexApiHost,
getAiSdkBaseUrl,
getTrailingApiVersion,
hasAPIVersion,
isWithTrailingSharp,
routeToEndpoint,
SUPPORTED_ENDPOINT_LIST,
SUPPORTED_IMAGE_ENDPOINT_LIST,
validateApiHost,
withoutTrailingApiVersion,
withoutTrailingSharp,
withoutTrailingSlash
} from '@shared/utils/url'
/** /**
* API key * API key
@ -12,228 +26,6 @@ export function formatApiKeys(value: string): string {
return value.replaceAll('', ',').replaceAll('\n', ',') return value.replaceAll('', ',').replaceAll('\n', ',')
} }
/**
* Matches a version segment in a path that starts with `/v<number>` and optionally
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
* of the string (useful for cases like `/v3alpha/resources`).
*/
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
/**
* Matches an API version at the end of a URL (with optional trailing slash).
* Used to detect and extract versions only from the trailing position.
*/
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
/**
* host path /v1/v2beta
*
* @param host - host path
* @returns path true false
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
try {
const url = new URL(host)
return regex.test(url.pathname)
} catch {
// 若无法作为完整 URL 解析,则当作路径直接检测
return regex.test(host)
}
}
/**
* Removes the trailing slash from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing slash
*
* @example
* ```ts
* withoutTrailingSlash('https://example.com/') // 'https://example.com'
* withoutTrailingSlash('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Checks if a URL string ends with a trailing '#' character.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to check
* @returns {boolean} True if the URL ends with '#', false otherwise
*
* @example
* ```ts
* isWithTrailingSharp('https://example.com#') // true
* isWithTrailingSharp('https://example.com') // false
* ```
*/
export function isWithTrailingSharp<T extends string>(url: T): boolean {
return url.endsWith('#')
}
/**
* Removes the trailing '#' from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing '#'
*
* @example
* ```ts
* withoutTrailingSharp('https://example.com#') // 'https://example.com'
* withoutTrailingSharp('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSharp<T extends string>(url: T): T {
return url.replace(/#$/, '') as T
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host with trailing '#' removed.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
const shouldAppendApiVersion = !(normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost))
if (shouldAppendApiVersion) {
return `${normalizedHost}/${apiVersion}`
} else {
return withoutTrailingSharp(normalizedHost)
}
}
/**
* Ollama API
*/
export function formatOllamaApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
?.replace(/\/api$/, '')
?.replace(/\/chat$/, '')
return formatApiHost(normalizedHost + '/api', false)
}
/**
* Azure OpenAI API
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: AISDK会添加上`v1`
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(provider: VertexProvider): string {
const { apiHost } = provider
const { projectId: project, location } = store.getState().llm.settings.vertexai
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location == 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
// 目前对话界面只支持这些端点
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* @param apiHost - The API host string to parse. Expected to be a trimmed URL that may end with '#' followed by an endpoint identifier.
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @description
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
* The '#' delimiter is removed before processing.
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = trim(apiHost)
// 前面已经确保apiHost合法
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// 去掉结尾的 #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // 去掉结尾可能存在的冒号(gemini的特殊情况)
return { baseURL, endpoint: endpointMatch }
}
/**
* API
*
* @param {string} apiHost - API
* @returns {boolean} URL true false
*/
export function validateApiHost(apiHost: string): boolean {
// 允许apiHost为空
if (!apiHost || !trim(apiHost)) {
return true
}
try {
const url = new URL(trim(apiHost))
// 验证协议是否为 http 或 https
if (url.protocol !== 'http:' && url.protocol !== 'https:') {
return false
}
return true
} catch {
return false
}
}
/** /**
* API key * API key
* *
@ -272,50 +64,3 @@ export function splitApiKeyString(keyStr: string): string[] {
.map((k) => k.replace(/\\,/g, ',')) .map((k) => k.replace(/\\,/g, ','))
.filter((k) => k) .filter((k) => k)
} }
/**
* Extracts the trailing API version segment from a URL path.
*
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
* Only versions at the end of the path are extracted, not versions in the middle.
* The returned version string does not include leading or trailing slashes.
*
* @param {string} url - The URL string to parse.
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
*
* @example
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
* getTrailingApiVersion('https://api.example.com') // undefined
*/
export function getTrailingApiVersion(url: string): string | undefined {
const match = url.match(TRAILING_VERSION_REGEX)
if (match) {
// Extract version without leading slash and trailing slash
return match[0].replace(/^\//, '').replace(/\/$/, '')
}
return undefined
}
/**
* Removes the trailing API version segment from a URL path.
*
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
* Only versions at the end of the path are removed, not versions in the middle.
*
* @param {string} url - The URL string to process.
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
*
* @example
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
*/
export function withoutTrailingApiVersion(url: string): string {
return url.replace(TRAILING_VERSION_REGEX, '')
}

View File

@ -81,3 +81,57 @@ export const apiModelAdapter = (model: ApiModel): AdaptedApiModel => {
origin: model origin: model
} }
} }
/**
* Parse a model identifier in the format "provider:modelId"
* where modelId may contain additional colons (e.g., "openrouter:anthropic/claude-3.5-sonnet:free")
*
* @param modelIdentifier - The full model identifier string
* @returns Object with providerId and modelId. If no provider prefix found, providerId will be undefined
*
* @example
* parseModelId("openrouter:anthropic/claude-3.5-sonnet:free")
* // => { providerId: "openrouter", modelId: "anthropic/claude-3.5-sonnet:free" }
*
* @example
* parseModelId("anthropic:claude-3-sonnet")
* // => { providerId: "anthropic", modelId: "claude-3-sonnet" }
*
* @example
* parseModelId("claude-3-sonnet")
* // => { providerId: undefined, modelId: "claude-3-sonnet" }
*
* @example
* parseModelId("") // => undefined
*/
export function parseModelId(
modelIdentifier: string | undefined
): { providerId: string | undefined; modelId: string } | undefined {
if (!modelIdentifier || typeof modelIdentifier !== 'string' || modelIdentifier.trim() === '') {
return undefined
}
const colonIndex = modelIdentifier.indexOf(':')
// No colon found or colon at the start - treat entire string as modelId
if (colonIndex <= 0) {
return {
providerId: undefined,
modelId: modelIdentifier
}
}
// Colon at the end - treat everything before as modelId
if (colonIndex >= modelIdentifier.length - 1) {
return {
providerId: undefined,
modelId: modelIdentifier.substring(0, colonIndex)
}
}
// Standard format: "provider:modelId"
return {
providerId: modelIdentifier.substring(0, colonIndex),
modelId: modelIdentifier.substring(colonIndex + 1)
}
}

View File

@ -2,6 +2,8 @@ import { getProviderLabel } from '@renderer/i18n/label'
import type { Provider } from '@renderer/types' import type { Provider } from '@renderer/types'
import { isSystemProvider } from '@renderer/types' import { isSystemProvider } from '@renderer/types'
export { getBaseModelName, getLowerBaseModelName } from '@shared/utils/naming'
/** /**
* ID * ID
* *
@ -50,42 +52,6 @@ export const getDefaultGroupName = (id: string, provider?: string): string => {
return str return str
} }
/**
* ID
*
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* ID
*
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
// for cherryin
if (baseModelName.endsWith('(free)')) {
return baseModelName.replace('(free)', '')
}
return baseModelName
}
/** /**
* *
* @param provider * @param provider

View File

@ -1,10 +1,21 @@
import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code' import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code'
import type { AzureOpenAIProvider, ProviderType, VertexProvider } from '@renderer/types' import type { ProviderType } from '@renderer/types'
import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types' import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types'
export {
export const isAzureResponsesEndpoint = (provider: AzureOpenAIProvider) => { isAIGatewayProvider,
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1' isAnthropicProvider,
} isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOllamaProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from '@shared/provider'
export const getClaudeSupportedProviders = (providers: Provider[]) => { export const getClaudeSupportedProviders = (providers: Provider[]) => {
return providers.filter( return providers.filter(
@ -127,59 +138,6 @@ export const isGeminiWebSearchProvider = (provider: Provider) => {
return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id) return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id)
} }
export const isNewApiProvider = (provider: Provider) => {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
}
export function isCherryAIProvider(provider: Provider): boolean {
return provider.id === 'cherryai'
}
export function isPerplexityProvider(provider: Provider): boolean {
return provider.id === 'perplexity'
}
/**
* OpenAI
* @param {Provider} provider
* @returns {boolean} OpenAI
*/
export function isOpenAICompatibleProvider(provider: Provider): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider {
return provider.type === 'azure-openai'
}
export function isOpenAIProvider(provider: Provider): boolean {
return provider.type === 'openai-response'
}
export function isVertexProvider(provider: Provider): provider is VertexProvider {
return provider.type === 'vertexai'
}
export function isAwsBedrockProvider(provider: Provider): boolean {
return provider.type === 'aws-bedrock'
}
export function isAnthropicProvider(provider: Provider): boolean {
return provider.type === 'anthropic'
}
export function isGeminiProvider(provider: Provider): boolean {
return provider.type === 'gemini'
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'gateway'
}
export function isOllamaProvider(provider: Provider): boolean {
return provider.type === 'ollama'
}
const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[] const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]
export const isSupportAPIVersionProvider = (provider: Provider) => { export const isSupportAPIVersionProvider = (provider: Provider) => {

View File

@ -61,7 +61,19 @@ vi.mock('electron', () => ({
getPrimaryDisplay: vi.fn(), getPrimaryDisplay: vi.fn(),
getAllDisplays: vi.fn() getAllDisplays: vi.fn()
}, },
Notification: vi.fn() Notification: vi.fn(),
net: {
fetch: vi.fn(() =>
Promise.resolve({
ok: true,
status: 200,
statusText: 'OK',
json: vi.fn(() => Promise.resolve({})),
text: vi.fn(() => Promise.resolve('')),
headers: new Headers()
})
)
}
})) }))
// Mock Winston for LoggerService dependencies // Mock Winston for LoggerService dependencies
@ -97,14 +109,39 @@ vi.mock('winston-daily-rotate-file', () => {
})) }))
}) })
// Mock main process services
vi.mock('@main/services/AnthropicService', () => ({
default: {}
}))
vi.mock('@main/services/CopilotService', () => ({
default: {}
}))
vi.mock('@main/services/ReduxService', () => ({
reduxService: {
selectSync: vi.fn()
}
}))
vi.mock('@main/integration/cherryai', () => ({
generateSignature: vi.fn()
}))
// Mock Node.js modules // Mock Node.js modules
vi.mock('node:os', () => ({ vi.mock('node:os', async () => {
const actual = await vi.importActual<typeof import('node:os')>('node:os')
return {
...actual,
default: actual,
platform: vi.fn(() => 'darwin'), platform: vi.fn(() => 'darwin'),
arch: vi.fn(() => 'x64'), arch: vi.fn(() => 'x64'),
version: vi.fn(() => '20.0.0'), version: vi.fn(() => '20.0.0'),
cpus: vi.fn(() => [{ model: 'Mock CPU' }]), cpus: vi.fn(() => [{ model: 'Mock CPU' }]),
totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024), // 8GB
})) homedir: vi.fn(() => '/tmp')
}
})
vi.mock('node:path', async () => { vi.mock('node:path', async () => {
const actual = await vi.importActual('node:path') const actual = await vi.importActual('node:path')

View File

@ -8,8 +8,10 @@
"src/preload/**/*", "src/preload/**/*",
"src/renderer/src/services/traceApi.ts", "src/renderer/src/services/traceApi.ts",
"src/renderer/src/types/*", "src/renderer/src/types/*",
"packages/aiCore/src/**/*",
"packages/mcp-trace/**/*", "packages/mcp-trace/**/*",
"packages/shared/**/*", "packages/shared/**/*",
"packages/ai-sdk-provider/**/*"
], ],
"compilerOptions": { "compilerOptions": {
"composite": true, "composite": true,
@ -26,7 +28,12 @@
"@types": ["./src/renderer/src/types/index.ts"], "@types": ["./src/renderer/src/types/index.ts"],
"@shared/*": ["./packages/shared/*"], "@shared/*": ["./packages/shared/*"],
"@mcp-trace/*": ["./packages/mcp-trace/*"], "@mcp-trace/*": ["./packages/mcp-trace/*"],
"@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"] "@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"],
"@cherrystudio/ai-core/provider": ["./packages/aiCore/src/core/providers/index.ts"],
"@cherrystudio/ai-core/built-in/plugins": ["./packages/aiCore/src/core/plugins/built-in/index.ts"],
"@cherrystudio/ai-core/*": ["./packages/aiCore/src/*"],
"@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"],
"@cherrystudio/ai-sdk-provider": ["./packages/ai-sdk-provider/src/index.ts"]
}, },
"experimentalDecorators": true, "experimentalDecorators": true,
"emitDecoratorMetadata": true, "emitDecoratorMetadata": true,