feat: Implement shared provider utilities and API host formatting

- Added provider API host formatting utilities to handle differences between Cherry Studio and AI SDK.
- Introduced functions for formatting provider API hosts, including support for Azure OpenAI and Vertex AI.
- Created a simple API key rotator for managing API key rotation.
- Developed shared provider initialization and mapping utilities for resolving provider IDs.
- Implemented AI SDK configuration utilities for converting Cherry Studio providers to AI SDK configurations.
- Added support for various providers including OpenRouter, Google Vertex AI, and Amazon Bedrock.
- Enhanced error handling and logging in the unified messages service for better debugging.
- Introduced functions for streaming and generating unified messages using AI SDK.
This commit is contained in:
suyao 2025-11-27 15:30:19 +08:00
parent eb4670c22c
commit a5e7aa1342
No known key found for this signature in database
38 changed files with 2681 additions and 896 deletions

View File

@ -0,0 +1,593 @@
/**
* AI SDK to Anthropic SSE Adapter
*
* Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format.
* This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API.
*
* Anthropic SSE Event Flow:
* 1. message_start - Initial message with metadata
* 2. content_block_start - Begin a content block (text, tool_use, thinking)
* 3. content_block_delta - Incremental content updates
* 4. content_block_stop - End a content block
* 5. message_delta - Updates to overall message (stop_reason, usage)
* 6. message_stop - Stream complete
*
* @see https://docs.anthropic.com/en/api/messages-streaming
*/
import type {
ContentBlock,
InputJSONDelta,
Message,
MessageDeltaUsage,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
StopReason,
TextBlock,
TextDelta,
ThinkingBlock,
ThinkingDelta,
ToolUseBlock,
Usage
} from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger'
import type { TextStreamPart, ToolSet } from 'ai'
const logger = loggerService.withContext('AiSdkToAnthropicSSE')
interface ContentBlockState {
type: 'text' | 'tool_use' | 'thinking'
index: number
started: boolean
content: string
// For tool_use blocks
toolId?: string
toolName?: string
toolInput?: string
}
interface AdapterState {
messageId: string
model: string
inputTokens: number
outputTokens: number
currentBlockIndex: number
blocks: Map<number, ContentBlockState>
textBlockIndex: number | null
thinkingBlockIndex: number | null
toolBlocks: Map<string, number> // toolCallId -> blockIndex
stopReason: StopReason | null
hasEmittedMessageStart: boolean
}
// ============================================================================
// Adapter Class
// ============================================================================
export type SSEEventCallback = (event: RawMessageStreamEvent) => void
export interface AiSdkToAnthropicSSEOptions {
model: string
messageId?: string
inputTokens?: number
onEvent: SSEEventCallback
}
/**
* Adapter that converts AI SDK fullStream events to Anthropic SSE events
*/
export class AiSdkToAnthropicSSE {
private state: AdapterState
private onEvent: SSEEventCallback
constructor(options: AiSdkToAnthropicSSEOptions) {
this.onEvent = options.onEvent
this.state = {
messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
model: options.model,
inputTokens: options.inputTokens || 0,
outputTokens: 0,
currentBlockIndex: 0,
blocks: new Map(),
textBlockIndex: null,
thinkingBlockIndex: null,
toolBlocks: new Map(),
stopReason: null,
hasEmittedMessageStart: false
}
}
/**
* Process the AI SDK stream and emit Anthropic SSE events
*/
async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
const reader = fullStream.getReader()
try {
// Emit message_start at the beginning
this.emitMessageStart()
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
this.processChunk(value)
}
// Ensure all blocks are closed and emit final events
this.finalize()
} finally {
reader.releaseLock()
}
}
/**
* Process a single AI SDK chunk and emit corresponding Anthropic events
*/
private processChunk(chunk: TextStreamPart<ToolSet>): void {
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', chunk)
switch (chunk.type) {
// === Text Events ===
case 'text-start':
this.startTextBlock()
break
case 'text-delta':
this.emitTextDelta(chunk.text || '')
break
case 'text-end':
this.stopTextBlock()
break
// === Reasoning/Thinking Events ===
case 'reasoning-start':
this.startThinkingBlock()
break
case 'reasoning-delta':
this.emitThinkingDelta(chunk.text || '')
break
case 'reasoning-end':
this.stopThinkingBlock()
break
// === Tool Events ===
case 'tool-call':
this.handleToolCall({
type: 'tool-call',
toolCallId: chunk.toolCallId,
toolName: chunk.toolName,
// AI SDK uses 'args' in some versions and 'input' in others
args: 'args' in chunk ? chunk.args : (chunk as any).input
})
break
case 'tool-result':
// Tool results are handled separately in Anthropic API
// They come from user messages, not assistant stream
break
// === Completion Events ===
case 'finish-step':
if (chunk.finishReason === 'tool-calls') {
this.state.stopReason = 'tool_use'
}
break
case 'finish':
this.handleFinish(chunk)
break
// === Error Events ===
case 'error':
// Anthropic doesn't have a standard error event in the stream
// Errors are typically sent as separate HTTP responses
// For now, we'll just log and continue
break
// Ignore other event types
default:
break
}
}
private emitMessageStart(): void {
if (this.state.hasEmittedMessageStart) return
this.state.hasEmittedMessageStart = true
const usage: Usage = {
input_tokens: this.state.inputTokens,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
const message: Message = {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content: [],
model: this.state.model,
stop_reason: null,
stop_sequence: null,
usage
}
const event: RawMessageStartEvent = {
type: 'message_start',
message
}
this.onEvent(event)
}
private startTextBlock(): void {
// If we already have a text block, don't create another
if (this.state.textBlockIndex !== null) return
const index = this.state.currentBlockIndex++
this.state.textBlockIndex = index
this.state.blocks.set(index, {
type: 'text',
index,
started: true,
content: ''
})
const contentBlock: TextBlock = {
type: 'text',
text: '',
citations: null
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitTextDelta(text: string): void {
if (!text) return
// Auto-start text block if not started
if (this.state.textBlockIndex === null) {
this.startTextBlock()
}
const index = this.state.textBlockIndex!
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: TextDelta = {
type: 'text_delta',
text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopTextBlock(): void {
if (this.state.textBlockIndex === null) return
const index = this.state.textBlockIndex
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.textBlockIndex = null
}
private startThinkingBlock(): void {
if (this.state.thinkingBlockIndex !== null) return
const index = this.state.currentBlockIndex++
this.state.thinkingBlockIndex = index
this.state.blocks.set(index, {
type: 'thinking',
index,
started: true,
content: ''
})
const contentBlock: ThinkingBlock = {
type: 'thinking',
thinking: '',
signature: ''
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitThinkingDelta(text: string): void {
if (!text) return
// Auto-start thinking block if not started
if (this.state.thinkingBlockIndex === null) {
this.startThinkingBlock()
}
const index = this.state.thinkingBlockIndex!
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: ThinkingDelta = {
type: 'thinking_delta',
thinking: text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopThinkingBlock(): void {
if (this.state.thinkingBlockIndex === null) return
const index = this.state.thinkingBlockIndex
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.thinkingBlockIndex = null
}
private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
const { toolCallId, toolName, args } = chunk
// Check if we already have this tool call
if (this.state.toolBlocks.has(toolCallId)) {
return
}
const index = this.state.currentBlockIndex++
this.state.toolBlocks.set(toolCallId, index)
const inputJson = JSON.stringify(args)
this.state.blocks.set(index, {
type: 'tool_use',
index,
started: true,
content: inputJson,
toolId: toolCallId,
toolName,
toolInput: inputJson
})
// Emit content_block_start for tool_use
const contentBlock: ToolUseBlock = {
type: 'tool_use',
id: toolCallId,
name: toolName,
input: {}
}
const startEvent: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(startEvent)
// Emit the full input as a delta (Anthropic streams JSON incrementally)
const delta: InputJSONDelta = {
type: 'input_json_delta',
partial_json: inputJson
}
const deltaEvent: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(deltaEvent)
// Emit content_block_stop
const stopEvent: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(stopEvent)
// Mark that we have tool use
this.state.stopReason = 'tool_use'
}
private handleFinish(chunk: {
type: 'finish'
finishReason?: string
totalUsage?: {
inputTokens?: number
outputTokens?: number
}
}): void {
// Update usage
if (chunk.totalUsage) {
this.state.inputTokens = chunk.totalUsage.inputTokens || 0
this.state.outputTokens = chunk.totalUsage.outputTokens || 0
}
// Determine finish reason
if (!this.state.stopReason) {
switch (chunk.finishReason) {
case 'stop':
case 'end_turn':
this.state.stopReason = 'end_turn'
break
case 'length':
case 'max_tokens':
this.state.stopReason = 'max_tokens'
break
case 'tool-calls':
this.state.stopReason = 'tool_use'
break
default:
this.state.stopReason = 'end_turn'
}
}
}
private finalize(): void {
// Close any open blocks
if (this.state.textBlockIndex !== null) {
this.stopTextBlock()
}
if (this.state.thinkingBlockIndex !== null) {
this.stopThinkingBlock()
}
// Emit message_delta with final stop reason and usage
const usage: MessageDeltaUsage = {
output_tokens: this.state.outputTokens,
input_tokens: null,
cache_creation_input_tokens: null,
cache_read_input_tokens: null,
server_tool_use: null
}
const messageDeltaEvent: RawMessageDeltaEvent = {
type: 'message_delta',
delta: {
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null
},
usage
}
this.onEvent(messageDeltaEvent)
// Emit message_stop
const messageStopEvent: RawMessageStopEvent = {
type: 'message_stop'
}
this.onEvent(messageStopEvent)
}
/**
* Set input token count (typically from prompt)
*/
setInputTokens(count: number): void {
this.state.inputTokens = count
}
/**
* Get the current message ID
*/
getMessageId(): string {
return this.state.messageId
}
/**
* Build a complete Message object for non-streaming responses
*/
buildNonStreamingResponse(): Message {
const content: ContentBlock[] = []
// Collect all content blocks in order
const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)
for (const block of sortedBlocks) {
switch (block.type) {
case 'text':
content.push({
type: 'text',
text: block.content,
citations: null
} as TextBlock)
break
case 'thinking':
content.push({
type: 'thinking',
thinking: block.content
} as ThinkingBlock)
break
case 'tool_use':
content.push({
type: 'tool_use',
id: block.toolId!,
name: block.toolName!,
input: JSON.parse(block.toolInput || '{}')
} as ToolUseBlock)
break
}
}
return {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content,
model: this.state.model,
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null,
usage: {
input_tokens: this.state.inputTokens,
output_tokens: this.state.outputTokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
}
}
}
/**
* Format an Anthropic SSE event for HTTP streaming
*/
export function formatSSEEvent(event: RawMessageStreamEvent): string {
return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}
/**
* Create a done marker for SSE stream
*/
export function formatSSEDone(): string {
return 'data: [DONE]\n\n'
}
export default AiSdkToAnthropicSSE

View File

@ -0,0 +1,13 @@
/**
* Shared Adapters
*
* This module exports adapters for converting between different AI API formats.
*/
export {
AiSdkToAnthropicSSE,
type AiSdkToAnthropicSSEOptions,
formatSSEDone,
formatSSEEvent,
type SSEEventCallback
} from './AiSdkToAnthropicSSE'

View File

@ -0,0 +1,173 @@
/**
* Shared API Utilities
*
* Common utilities for API URL formatting and validation.
* Used by both main process (API Server) and renderer.
*/
import type { MinimalProvider } from '@shared/provider'
import { trim } from 'lodash'
// Supported endpoints for routing
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Removes the trailing slash from a URL string if it exists.
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Checks if the host path contains a version string (e.g., /v1, /v2beta).
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const versionRegex = /\/v\d+(?:alpha|beta)?(?=\/|$)/i
try {
const url = new URL(host)
return versionRegex.test(url.pathname)
} catch {
return versionRegex.test(host)
}
}
/**
* Azure OpenAI API
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: AISDK会添加上`v1`
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(provider: MinimalProvider, project: string, location: string): string {
const { apiHost } = provider
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location == 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format
* @param isSupportedAPIVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, isSupportedAPIVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash((host || '').trim())
if (!normalizedHost) {
return ''
}
if (normalizedHost.endsWith('#') || !isSupportedAPIVersion || hasAPIVersion(normalizedHost)) {
return normalizedHost
}
return `${normalizedHost}/${apiVersion}`
}
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
*
* @param apiHost - The API host string to parse
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = (apiHost || '').trim()
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// Remove trailing #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case)
return { baseURL, endpoint: endpointMatch }
}
/**
* Gets the AI SDK compatible base URL from a provider's apiHost.
*
* AI SDK expects baseURL WITH version suffix (e.g., /v1).
* This function:
* 1. Handles '#' endpoint routing format
* 2. Ensures the URL has a version suffix (adds /v1 if missing)
*
* @param apiHost - The provider's apiHost value (may or may not have /v1)
* @param apiVersion - The API version to use if missing. Defaults to 'v1'.
* @returns The baseURL suitable for AI SDK (with version suffix)
*
* @example
* getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com'
*/
export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string {
// First handle '#' endpoint routing format
const { baseURL } = routeToEndpoint(apiHost)
// If already has version, return as-is
if (hasAPIVersion(baseURL)) {
return withoutTrailingSlash(baseURL)
}
// Add version suffix
return `${withoutTrailingSlash(baseURL)}/${apiVersion}`
}
/**
* Validates an API host address.
*
* @param apiHost - The API host address to validate
* @returns true if valid URL with http/https protocol, false otherwise
*/
export function validateApiHost(apiHost: string): boolean {
if (!apiHost || !apiHost.trim()) {
return true // Allow empty
}
try {
const url = new URL(apiHost.trim())
return url.protocol === 'http:' || url.protocol === 'https:'
} catch {
return false
}
}

View File

@ -1,13 +1,13 @@
/**
* AiHubMix规则集
*/
import { isOpenAILLMModel } from '@renderer/config/models'
import type { Provider } from '@renderer/types'
import { getLowerBaseModelName } from '@shared/utils/naming'
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const extraProviderConfig = (provider: Provider) => {
const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
return {
...provider,
extra_headers: {
@ -17,11 +17,23 @@ const extraProviderConfig = (provider: Provider) => {
}
}
function isOpenAILLMModel<M extends MinimalModel>(model: M): boolean {
const modelId = getLowerBaseModelName(model.id)
const reasonings = ['o1', 'o3', 'o4', 'gpt-oss']
if (reasonings.some((r) => modelId.includes(r))) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
const AIHUBMIX_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => {
provider: (provider) => {
return extraProviderConfig({
...provider,
type: 'anthropic'
@ -34,7 +46,7 @@ const AIHUBMIX_RULES: RuleSet = {
!model.id.endsWith('-nothink') &&
!model.id.endsWith('-search') &&
!model.id.includes('embedding'),
provider: (provider: Provider) => {
provider: (provider) => {
return extraProviderConfig({
...provider,
type: 'gemini',
@ -44,7 +56,7 @@ const AIHUBMIX_RULES: RuleSet = {
},
{
match: isOpenAILLMModel,
provider: (provider: Provider) => {
provider: (provider) => {
return extraProviderConfig({
...provider,
type: 'openai-response'
@ -52,7 +64,8 @@ const AIHUBMIX_RULES: RuleSet = {
}
}
],
fallbackRule: (provider: Provider) => extraProviderConfig(provider)
fallbackRule: (provider) => extraProviderConfig(provider)
}
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)
export const aihubmixProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AIHUBMIX_RULES, model, provider)

View File

@ -0,0 +1,22 @@
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
type: 'anthropic' as ProviderType,
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const azureAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AZURE_ANTHROPIC_RULES, model, provider)

View File

@ -0,0 +1,32 @@
import type { MinimalModel, MinimalProvider } from '../types'
import type { RuleSet } from './types'
export const startsWith =
(prefix: string) =>
<M extends MinimalModel>(model: M) =>
model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs =
(type: string) =>
<M extends MinimalModel>(model: M) =>
model.endpoint_type === type
/**
* Provider
* @param ruleSet
* @param model
* @param provider provider对象
* @returns provider对象
*/
export function provider2Provider<
M extends MinimalModel,
R extends MinimalProvider,
P extends R = R
>(ruleSet: RuleSet<M, R>, model: M, provider: P): P {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider) as P
}
}
return ruleSet.fallbackRule(provider) as P
}

View File

@ -0,0 +1,6 @@
export { aihubmixProviderCreator } from './aihubmix'
export { azureAnthropicProviderCreator } from './azure-anthropic'
export { endpointIs, provider2Provider, startsWith } from './helper'
export { newApiResolverCreator } from './newApi'
export type { RuleSet } from './types'
export { vertexAnthropicProviderCreator } from './vertex-anthropic'

View File

@ -1,8 +1,7 @@
/**
* NewAPI规则集
*/
import type { Provider } from '@renderer/types'
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import { endpointIs, provider2Provider } from './helper'
import type { RuleSet } from './types'
@ -10,42 +9,43 @@ const NEWAPI_RULES: RuleSet = {
rules: [
{
match: endpointIs('anthropic'),
provider: (provider: Provider) => {
provider: (provider) => {
return {
...provider,
type: 'anthropic'
type: 'anthropic' as ProviderType
}
}
},
{
match: endpointIs('gemini'),
provider: (provider: Provider) => {
provider: (provider) => {
return {
...provider,
type: 'gemini'
type: 'gemini' as ProviderType
}
}
},
{
match: endpointIs('openai-response'),
provider: (provider: Provider) => {
provider: (provider) => {
return {
...provider,
type: 'openai-response'
type: 'openai-response' as ProviderType
}
}
},
{
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider: Provider) => {
provider: (provider) => {
return {
...provider,
type: 'openai'
type: 'openai' as ProviderType
}
}
}
],
fallbackRule: (provider: Provider) => provider
fallbackRule: (provider) => provider
}
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)
export const newApiResolverCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(NEWAPI_RULES, model, provider)

View File

@ -0,0 +1,9 @@
import type { MinimalModel, MinimalProvider } from '../types'
export interface RuleSet<M extends MinimalModel = MinimalModel, P extends MinimalProvider = MinimalProvider> {
rules: Array<{
match: (model: M) => boolean
provider: (provider: P) => P
}>
fallbackRule: (provider: P) => P
}

View File

@ -0,0 +1,19 @@
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const vertexAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(VERTEX_ANTHROPIC_RULES, model, provider)

View File

@ -0,0 +1,100 @@
/**
* Provider Type Detection Utilities
*
* Functions to detect provider types based on provider configuration.
* These are pure functions that only depend on provider.type and provider.id.
*
* NOTE: These functions should match the logic in @renderer/utils/provider.ts
*/
import type { MinimalProvider } from './types'
/**
* Check if provider is Anthropic type
*/
export function isAnthropicProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'anthropic'
}
/**
* Check if provider is OpenAI Response type (openai-response)
* NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts
*/
export function isOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'openai-response'
}
/**
* Check if provider is Gemini type
*/
export function isGeminiProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'gemini'
}
/**
* Check if provider is Azure OpenAI type
*/
export function isAzureOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'azure-openai'
}
/**
* Check if provider is Vertex AI type
*/
export function isVertexProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'vertexai'
}
/**
* Check if provider is AWS Bedrock type
*/
export function isAwsBedrockProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'aws-bedrock'
}
/**
* Check if provider is AI Gateway type
*/
export function isAIGatewayProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'ai-gateway'
}
/**
* Check if Azure OpenAI provider uses responses endpoint
* Matches isAzureResponsesEndpoint in renderer/utils/provider.ts
*/
export function isAzureResponsesEndpoint<P extends MinimalProvider>(provider: P): boolean {
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
}
/**
* Check if provider is Cherry AI type
* Matches isCherryAIProvider in renderer/utils/provider.ts
*/
export function isCherryAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'cherryai'
}
/**
* Check if provider is Perplexity type
* Matches isPerplexityProvider in renderer/utils/provider.ts
*/
export function isPerplexityProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'perplexity'
}
/**
* Check if provider is new-api type (supports multiple backends)
* Matches isNewApiProvider in renderer/utils/provider.ts
*/
export function isNewApiProvider<P extends MinimalProvider>(provider: P): boolean {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string)
}
/**
* Check if provider is OpenAI compatible
* Matches isOpenAICompatibleProvider in renderer/utils/provider.ts
*/
export function isOpenAICompatibleProvider<P extends MinimalProvider>(provider: P): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}

View File

@ -0,0 +1,136 @@
/**
* Provider API Host Formatting
*
* Utilities for formatting provider API hosts to work with AI SDK.
* These handle the differences between how Cherry Studio stores API hosts
* and how AI SDK expects them.
*/
import {
formatApiHost,
formatAzureOpenAIApiHost,
formatVertexApiHost,
routeToEndpoint,
withoutTrailingSlash
} from '../api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* Interface for environment-specific implementations
* Renderer and Main process can provide their own implementations
*/
export interface ProviderFormatContext {
vertex: {
project: string
location: string
}
}
/**
* Default Azure OpenAI API host formatter
*/
export function defaultFormatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// AI SDK will add /v1
return formatApiHost(normalizedHost + '/openai', false)
}
/**
* Format provider API host for AI SDK
*
* This function normalizes the apiHost to work with AI SDK.
* Different providers have different requirements:
* - Most providers: add /v1 suffix
* - Gemini: add /v1beta suffix
* - Some providers: no suffix needed
*
* @param provider - The provider to format
* @param context - Optional context with environment-specific implementations
* @returns Provider with formatted apiHost (and anthropicApiHost if applicable)
*/
export function formatProviderApiHost<T extends MinimalProvider>(provider: T, context: ProviderFormatContext): T {
const formatted = { ...provider }
// Format anthropicApiHost if present
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
// Format based on provider type
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
/**
* Get the base URL for AI SDK from a formatted provider
*
* This extracts the baseURL that AI SDK expects, handling
* the '#' endpoint routing format if present.
*
* @param formattedApiHost - The formatted apiHost (after formatProviderApiHost)
* @returns The baseURL for AI SDK
*/
export function getBaseUrlForAiSdk(formattedApiHost: string): string {
const { baseURL } = routeToEndpoint(formattedApiHost)
return baseURL
}
/**
* Get rotated API key from comma-separated keys
*
* This is the interface for API key rotation. The actual implementation
* depends on the environment (renderer uses window.keyv, main uses its own storage).
*/
export interface ApiKeyRotator {
/**
* Get the next API key in rotation
* @param providerId - The provider ID for tracking rotation
* @param keys - Comma-separated API keys
* @returns The next API key to use
*/
getRotatedKey(providerId: string, keys: string): string
}
/**
* Simple API key rotator that always returns the first key
* Use this when rotation is not needed
*/
export const simpleKeyRotator: ApiKeyRotator = {
getRotatedKey(_providerId: string, keys: string): string {
const keyList = keys.split(',').map((k) => k.trim())
return keyList[0] || keys
}
}

View File

@ -0,0 +1,48 @@
/**
* Shared Provider Utilities
*
* This module exports utilities for working with AI providers
* that can be shared between main process and renderer process.
*/
// Type definitions
export type { MinimalProvider, ProviderType, SystemProviderId } from './types'
export { SystemProviderIds } from './types'
// Provider type detection
export {
isAIGatewayProvider,
isAnthropicProvider,
isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
// API host formatting
export type { ApiKeyRotator, ProviderFormatContext } from './format'
export {
defaultFormatAzureOpenAIApiHost,
formatProviderApiHost,
getBaseUrlForAiSdk,
simpleKeyRotator
} from './format'
// Provider ID mapping
export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping'
// AI SDK configuration
export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config'
export { providerToAiSdkConfig } from './sdk-config'
// Provider resolution
export { resolveActualProvider } from './resolve'
// Provider initialization
export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization'

View File

@ -0,0 +1,107 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
type ProviderInitializationLogger = {
warn?: (message: string) => void
error?: (message: string, error: Error) => void
}
export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
export function initializeSharedProviders(logger?: ProviderInitializationLogger): void {
try {
const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS)
if (successCount < SHARED_PROVIDER_CONFIGS.length) {
logger?.warn?.('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger?.error?.('Failed to initialize shared providers', error as Error)
}
}

View File

@ -0,0 +1,95 @@
/**
* Provider ID Mapping
*
* Maps Cherry Studio provider IDs/types to AI SDK provider IDs.
* This logic should match @renderer/aiCore/provider/factory.ts
*/
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection'
import type { MinimalProvider } from './types'
/**
* Static mapping from Cherry Studio provider ID/type to AI SDK provider ID
* Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts
*/
export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* Try to resolve a provider identifier to an AI SDK provider ID
* Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts
*
* @param identifier - The provider ID or type to resolve
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The resolved AI SDK provider ID, or null if not found
*/
export function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查AiCore是否支持包括别名支持
if (hasProviderConfigByAlias(identifier)) {
// 解析为真实的Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* Get the AI SDK Provider ID for a Cherry Studio provider
* Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts
*
* Logic:
* 1. Handle Azure OpenAI specially (check responses endpoint)
* 2. Try to resolve from provider.id
* 3. Try to resolve from provider.type (but not for generic 'openai' type)
* 4. Check for OpenAI API host pattern
* 5. Fallback to provider's own ID
*
* @param provider - The Cherry Studio provider
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The AI SDK provider ID to use
*/
export function getAiSdkProviderId(provider: MinimalProvider): ProviderId {
// 1. Handle Azure OpenAI specially - check this FIRST before other resolution
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
return 'azure'
}
// 2. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) {
return resolvedFromId
}
// 3. 尝试解析provider.type
// 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
// 4. Check for OpenAI API host pattern
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 5. 最后的fallback使用provider本身的id
return provider.id
}

View File

@ -0,0 +1,44 @@
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { isAzureOpenAIProvider, isNewApiProvider } from './detection'
import type { MinimalModel, MinimalProvider } from './types'
export interface ResolveActualProviderOptions<P extends MinimalProvider> {
isSystemProvider?: (provider: P) => boolean
}
const defaultIsSystemProvider = <P extends MinimalProvider>(provider: P): boolean => {
if ('isSystem' in provider) {
return Boolean((provider as unknown as { isSystem?: boolean }).isSystem)
}
return false
}
export function resolveActualProvider<M extends MinimalModel, P extends MinimalProvider>(
provider: P,
model: M,
options: ResolveActualProviderOptions<P> = {}
): P {
let resolvedProvider = provider
if (isNewApiProvider(resolvedProvider)) {
resolvedProvider = newApiResolverCreator(model, resolvedProvider)
}
const isSystemProvider =
options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider)
if (isSystemProvider && resolvedProvider.id === 'aihubmix') {
resolvedProvider = aihubmixProviderCreator(model, resolvedProvider)
}
if (isSystemProvider && resolvedProvider.id === 'vertexai') {
resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider)
}
if (isAzureOpenAIProvider(resolvedProvider)) {
resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider)
}
return resolvedProvider
}

View File

@ -0,0 +1,240 @@
/**
* AI SDK Configuration
*
* Shared utilities for converting Cherry Studio Provider to AI SDK configuration.
* Environment-specific logic (renderer/main) is injected via context interfaces.
*/
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { routeToEndpoint } from '../api'
import { getAiSdkProviderId } from './mapping'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* AI SDK configuration result
*/
export interface AiSdkConfig {
providerId: string
options: Record<string, unknown>
}
/**
* Context for environment-specific implementations
*/
export interface AiSdkConfigContext {
/**
* Get the rotated API key (for multi-key support)
* Default: returns first key
*/
getRotatedApiKey?: (provider: MinimalProvider) => string
/**
* Check if a model uses chat completion only (for OpenAI response mode)
* Default: returns false
*/
isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean
/**
* Get Copilot default headers (constants)
* Default: returns empty object
*/
getCopilotDefaultHeaders?: () => Record<string, string>
/**
* Get Copilot stored headers from state
* Default: returns empty object
*/
getCopilotStoredHeaders?: () => Record<string, string>
/**
* Get AWS Bedrock configuration
* Default: returns undefined (not configured)
*/
getAwsBedrockConfig?: () =>
| {
authType: 'apiKey' | 'iam'
region: string
apiKey?: string
accessKeyId?: string
secretAccessKey?: string
}
| undefined
/**
* Get Vertex AI configuration
* Default: returns undefined (not configured)
*/
getVertexConfig?: (provider: MinimalProvider) =>
| {
project: string
location: string
googleCredentials: {
privateKey: string
clientEmail: string
}
}
| undefined
/**
* Get endpoint type for cherryin provider
*/
getEndpointType?: (modelId: string) => string | undefined
/**
* Custom fetch implementation
* Main process: use Electron net.fetch
* Renderer process: use browser fetch (default)
*/
fetch?: typeof globalThis.fetch
}
/**
* Default simple key rotator - returns first key
*/
function defaultGetRotatedApiKey(provider: MinimalProvider): string {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
}
/**
* Convert Cherry Studio Provider to AI SDK configuration
*
* @param provider - The formatted provider (after formatProviderApiHost)
* @param modelId - The model ID to use
* @param context - Environment-specific implementations
* @returns AI SDK configuration
*/
export function providerToAiSdkConfig(
provider: MinimalProvider,
modelId: string,
context: AiSdkConfigContext = {}
): AiSdkConfig {
const getRotatedApiKey = context.getRotatedApiKey || defaultGetRotatedApiKey
const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false)
const aiSdkProviderId = getAiSdkProviderId(provider)
// Build base config
const { baseURL, endpoint } = routeToEndpoint(provider.apiHost)
const baseConfig = {
baseURL,
apiKey: getRotatedApiKey(provider)
}
// Handle Copilot specially
if (provider.id === SystemProviderIds.copilot) {
const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
headers: {
...defaultHeaders,
...storedHeaders,
...provider.extra_headers
},
name: provider.id,
includeUsage: true
})
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// Build extra options
const extraOptions: Record<string, unknown> = {}
if (endpoint) {
extraOptions.endpoint = endpoint
}
// Handle OpenAI mode
if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// Add extra headers
if (provider.extra_headers) {
extraOptions.headers = provider.extra_headers
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...(extraOptions.headers as Record<string, string>),
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// Handle Azure modes
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// Handle AWS Bedrock
if (aiSdkProviderId === 'bedrock') {
const bedrockConfig = context.getAwsBedrockConfig?.()
if (bedrockConfig) {
extraOptions.region = bedrockConfig.region
if (bedrockConfig.authType === 'apiKey') {
extraOptions.apiKey = bedrockConfig.apiKey
} else {
extraOptions.accessKeyId = bedrockConfig.accessKeyId
extraOptions.secretAccessKey = bedrockConfig.secretAccessKey
}
}
}
// Handle Vertex AI
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
const vertexConfig = context.getVertexConfig?.(provider)
if (vertexConfig) {
extraOptions.project = vertexConfig.project
extraOptions.location = vertexConfig.location
extraOptions.googleCredentials = {
...vertexConfig.googleCredentials,
privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
}
// Handle cherryin endpoint type
if (aiSdkProviderId === 'cherryin') {
const endpointType = context.getEndpointType?.(modelId)
if (endpointType) {
extraOptions.endpointType = endpointType
}
}
// Inject custom fetch if provided
if (context.fetch) {
extraOptions.fetch = context.fetch
}
// Check if AI SDK supports this provider natively
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// Fallback to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: provider.id,
...extraOptions,
includeUsage: true
}
}
}

View File

@ -0,0 +1,174 @@
import * as z from 'zod'
export const ProviderTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
/**
* Minimal provider interface for shared utilities
* This is the subset of Provider that shared code needs
*/
export type MinimalProvider = {
id: string
type: ProviderType
apiKey: string
apiHost: string
anthropicApiHost?: string
apiVersion?: string
extra_headers?: Record<string, string>
}
/**
* Minimal model interface for shared utilities
* This is the subset of Model that shared code needs
*/
export type MinimalModel = {
id: string
endpoint_type?: string
}
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'cerebras'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>
export type SystemProviderIdTypeMap = typeof SystemProviderIds

View File

@ -0,0 +1 @@
export { getBaseModelName, getLowerBaseModelName } from './naming'

View File

@ -0,0 +1,31 @@
/**
* ID
*
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* ID
*
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}

View File

@ -5,6 +5,7 @@ import type { Request, Response } from 'express'
import express from 'express'
import { messagesService } from '../services/messages'
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
import { getProviderById, validateModelId } from '../utils'
const logger = loggerService.withContext('ApiServerMessagesRoutes')
@ -33,21 +34,35 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
}
interface HandleMessageProcessingOptions {
req: Request
res: Response
provider: Provider
request: MessageCreateParams
modelId?: string
}
/**
* Handle message processing using unified AI SDK
* All providers (including Anthropic) are handled through AI SDK:
* - Anthropic providers use @ai-sdk/anthropic which outputs native Anthropic SSE
* - Other providers use their respective AI SDK adapters, with output converted to Anthropic SSE
*/
async function handleMessageProcessing({
req,
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via unified AI SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream
})
try {
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
res.status(400).json({
@ -60,21 +75,23 @@ async function handleMessageProcessing({
return
}
const extraHeaders = messagesService.prepareHeaders(req.headers)
const { client, anthropicRequest } = await messagesService.processMessage({
provider,
request,
extraHeaders,
modelId
})
if (request.stream) {
await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
return
await streamUnifiedMessages({
response: res,
provider,
modelId: actualModelId,
params: request,
onError: (error) => {
logger.error('Stream error', error as Error)
},
onComplete: () => {
logger.debug('Stream completed')
}
})
} else {
const response = await generateUnifiedMessage(provider, actualModelId, request)
res.json(response)
}
const response = await client.messages.create(anthropicRequest)
res.json(response)
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
@ -235,7 +252,7 @@ router.post('/', async (req: Request, res: Response) => {
const provider = modelValidation.provider!
const modelId = modelValidation.modelId!
return handleMessageProcessing({ req, res, provider, request, modelId })
return handleMessageProcessing({ res, provider, request, modelId })
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
@ -393,7 +410,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
const request: MessageCreateParams = req.body
return handleMessageProcessing({ req, res, provider, request })
return handleMessageProcessing({ res, provider, request })
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
@ -401,4 +418,194 @@ providerRouter.post('/', async (req: Request, res: Response) => {
}
})
/**
* @swagger
* /v1/messages/count_tokens:
* post:
* summary: Count tokens for messages
* description: Count tokens for Anthropic Messages API format (required by Claude Code SDK)
* tags: [Messages]
* requestBody:
* required: true
* content:
* application/json:
* schema:
* type: object
* required:
* - model
* - messages
* properties:
* model:
* type: string
* description: Model ID
* messages:
* type: array
* items:
* type: object
* system:
* type: string
* description: System message
* responses:
* 200:
* description: Token count response
* content:
* application/json:
* schema:
* type: object
* properties:
* input_tokens:
* type: integer
* 400:
* description: Bad request
*/
router.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!model) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'model parameter is required'
}
})
}
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
// Simple token estimation based on character count
// This is a rough approximation: ~4 characters per token for English text
let totalChars = 0
// Count system message tokens
if (system) {
if (typeof system === 'string') {
totalChars += system.length
} else if (Array.isArray(system)) {
for (const block of system) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
}
// Count message tokens
for (const msg of messages) {
if (typeof msg.content === 'string') {
totalChars += msg.content.length
} else if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
// Add overhead for role
totalChars += 10
}
// Estimate tokens (~4 chars per token, with some overhead)
const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3
logger.debug('Token count estimated', {
model,
messageCount: messages.length,
totalChars,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
/**
* Provider-specific count_tokens endpoint
*/
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
// Simple token estimation
let totalChars = 0
if (system) {
if (typeof system === 'string') {
totalChars += system.length
} else if (Array.isArray(system)) {
for (const block of system) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
}
for (const msg of messages) {
if (typeof msg.content === 'string') {
totalChars += msg.content.length
} else if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
totalChars += 10
}
const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3
logger.debug('Token count estimated (provider route)', {
providerId: req.params.provider,
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }

View File

@ -1,13 +1,6 @@
import { isEmpty } from 'lodash'
import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
import { loggerService } from '../../services/LoggerService'
import {
getAvailableProviders,
getProviderAnthropicModelChecker,
listAllAvailableModels,
transformModelToOpenAI
} from '../utils'
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
const logger = loggerService.withContext('ModelsService')
@ -20,11 +13,12 @@ export class ModelsService {
try {
logger.debug('Getting available models from providers', { filter })
let providers = await getAvailableProviders()
const providers = await getAvailableProviders()
if (filter.providerType === 'anthropic') {
providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
}
// Note: When providerType === 'anthropic', we now return ALL available models
// because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
// any provider's response to Anthropic SSE format. This enables Claude Code Agent
// to work with OpenAI, Gemini, and other providers transparently.
const models = await listAllAvailableModels(providers)
// Use Map to deduplicate models by their full ID (provider:model_id)
@ -32,20 +26,11 @@ export class ModelsService {
for (const model of models) {
const provider = providers.find((p) => p.id === model.provider)
// logger.debug(`Processing model ${model.id}`)
if (!provider) {
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
continue
}
if (filter.providerType === 'anthropic') {
const checker = getProviderAnthropicModelChecker(provider.id)
if (!checker(model)) {
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
continue
}
}
const openAIModel = transformModelToOpenAI(model, provider)
const fullModelId = openAIModel.id // This is already in format "provider:model_id"

View File

@ -0,0 +1,455 @@
import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type { ImageBlockParam, MessageCreateParams, TextBlockParam } from '@anthropic-ai/sdk/resources/messages'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { reduxService } from '@main/services/ReduxService'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
import {
type AiSdkConfig,
type AiSdkConfigContext,
formatProviderApiHost,
initializeSharedProviders,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider
} from '@shared/provider'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart } from 'ai'
import { stepCountIs, streamText } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
const logger = loggerService.withContext('UnifiedMessagesService')
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
export interface UnifiedStreamConfig {
response: Response
provider: Provider
modelId: string
params: MessageCreateParams
onError?: (error: unknown) => void
onComplete?: () => void
}
// ============================================================================
// Provider Factory
// ============================================================================
/**
* Main process format context for formatProviderApiHost
* Unlike renderer, main process doesn't have direct access to store getters, so use reduxService cache
*/
function getMainProcessFormatContext(): ProviderFormatContext {
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
return {
vertex: {
project: vertexSettings?.projectId || 'default-project',
location: vertexSettings?.location || 'us-central1'
}
}
}
/**
* Main process context for providerToAiSdkConfig
* Main process doesn't have access to browser APIs like window.keyv
*/
const mainProcessSdkContext: AiSdkConfigContext = {
// Simple key rotation - just return first key (no persistent rotation in main process)
getRotatedApiKey: (provider) => {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
},
fetch: net.fetch as typeof globalThis.fetch
}
/**
* Get actual provider configuration for a model
*
* For aggregated providers (new-api, aihubmix, vertexai, azure-openai),
* this resolves the actual provider type based on the model's characteristics.
*/
function getActualProvider(provider: Provider, modelId: string): Provider {
// Find the model in provider's models list
const model = provider.models?.find((m) => m.id === modelId)
if (!model) {
// If model not found, return provider as-is
return provider
}
// Resolve actual provider based on model
return resolveActualProvider(provider, model)
}
/**
* Convert Cherry Studio Provider to AI SDK config
* Uses shared implementation with main process context
*/
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
// First resolve actual provider for aggregated providers
const actualProvider = getActualProvider(provider, modelId)
// Format the provider's apiHost for AI SDK
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
// Use shared implementation
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}
/**
* Create an AI SDK provider from Cherry Studio provider configuration
*/
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
try {
const provider = await createProviderCore(config.providerId, config.options)
logger.debug('AI SDK provider created', {
providerId: config.providerId,
hasOptions: !!config.options
})
return provider
} catch (error) {
logger.error('Failed to create AI SDK provider', error as Error, {
providerId: config.providerId
})
throw error
}
}
/**
* Create an AI SDK language model from a Cherry Studio provider configuration
* Uses shared provider utilities for consistent behavior with renderer
*/
async function createLanguageModel(provider: Provider, modelId: string): Promise<LanguageModel> {
logger.debug('Creating language model', {
providerId: provider.id,
providerType: provider.type,
modelId,
apiHost: provider.apiHost
})
// Convert provider config to AI SDK config
const config = providerToAiSdkConfig(provider, modelId)
// Create the AI SDK provider
const aiSdkProvider = await createAiSdkProvider(config)
if (!aiSdkProvider) {
throw new Error(`Failed to create AI SDK provider for ${provider.id}`)
}
// Get the language model
return aiSdkProvider.languageModel(modelId)
}
function convertAnthropicToolResultToAiSdk(
content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
if (typeof content === 'string') {
return {
type: 'text',
value: content
}
} else {
const values: Array<
| { type: 'text'; text: string }
| {
type: 'media'
/**
Base-64 encoded media data.
*/
data: string
/**
IANA media type.
@see https://www.iana.org/assignments/media-types/media-types.xhtml
*/
mediaType: string
}
> = []
for (const block of content) {
if (block.type === 'text') {
values.push({
type: 'text',
text: block.text
})
} else if (block.type === 'image') {
values.push({
type: 'media',
data: block.source.type === 'base64' ? block.source.data : block.source.url,
mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
})
}
}
return {
type: 'content',
value: []
}
}
}
/**
* Convert Anthropic MessageCreateParams to AI SDK message format
*/
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
const messages: ModelMessage[] = []
// Add system message if present
if (params.system) {
if (typeof params.system === 'string') {
messages.push({
role: 'system',
content: params.system
})
} else if (Array.isArray(params.system)) {
// Handle TextBlockParam array
const systemText = params.system
.filter((block) => block.type === 'text')
.map((block) => block.text)
.join('\n')
if (systemText) {
messages.push({
role: 'system',
content: systemText
})
}
}
}
// Convert user/assistant messages
for (const msg of params.messages) {
if (typeof msg.content === 'string') {
if (msg.role === 'user') {
messages.push({ role: 'user', content: msg.content })
} else {
messages.push({ role: 'assistant', content: msg.content })
}
} else if (Array.isArray(msg.content)) {
// Handle content blocks
const textParts: TextPart[] = []
const imageParts: ImagePart[] = []
const reasoningParts: ReasoningPart[] = []
const toolCallParts: ToolCallPart[] = []
const toolResultParts: ToolResultPart[] = []
for (const block of msg.content) {
if (block.type === 'text') {
textParts.push({ type: 'text', text: block.text })
} else if (block.type === 'thinking') {
reasoningParts.push({ type: 'reasoning', text: block.thinking })
} else if (block.type === 'redacted_thinking') {
reasoningParts.push({ type: 'reasoning', text: block.data })
} else if (block.type === 'image') {
const source = block.source
if (source.type === 'base64') {
imageParts.push({
type: 'image',
image: `data:${source.media_type};base64,${source.data}`
})
} else if (source.type === 'url') {
imageParts.push({
type: 'image',
image: source.url
})
}
} else if (block.type === 'tool_use') {
toolCallParts.push({
type: 'tool-call',
toolName: block.name,
toolCallId: block.id,
input: block.input
})
} else if (block.type === 'tool_result') {
toolResultParts.push({
type: 'tool-result',
toolCallId: block.tool_use_id,
toolName: toolCallParts.find((t) => t.toolCallId === block.tool_use_id)?.toolName || 'unknown',
output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
})
}
}
// Build the message based on role
if (msg.role === 'user') {
messages.push({
role: 'user',
content: [...textParts, ...imageParts]
})
} else {
// Assistant messages can only have text
if (textParts.length > 0) {
messages.push({
role: 'assistant',
content: [...reasoningParts, ...textParts, ...toolCallParts, ...toolResultParts]
})
}
}
}
}
return messages
}
/**
* Stream a message request using AI SDK and convert to Anthropic SSE format
*/
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
const { response, provider, modelId, params, onError, onComplete } = config
logger.info('Starting unified message stream', {
providerId: provider.id,
providerType: provider.type,
modelId,
stream: params.stream
})
try {
response.setHeader('Content-Type', 'text/event-stream')
response.setHeader('Cache-Control', 'no-cache')
response.setHeader('Connection', 'keep-alive')
response.setHeader('X-Accel-Buffering', 'no')
const model = await createLanguageModel(provider, modelId)
const coreMessages = convertAnthropicToAiMessages(params)
logger.debug('Converted messages', {
originalCount: params.messages.length,
convertedCount: coreMessages.length,
hasSystem: !!params.system
})
// Create the adapter
const adapter = new AiSdkToAnthropicSSE({
model: `${provider.id}:${modelId}`,
onEvent: (event) => {
const sseData = formatSSEEvent(event)
response.write(sseData)
}
})
// Start streaming
const result = streamText({
model,
messages: coreMessages,
maxOutputTokens: params.max_tokens,
temperature: params.temperature,
topP: params.top_p,
stopSequences: params.stop_sequences,
stopWhen: stepCountIs(100),
headers: defaultAppHeaders(),
providerOptions: {}
})
// Process the stream through the adapter
await adapter.processStream(result.fullStream)
// Send done marker
response.write(formatSSEDone())
response.end()
logger.info('Unified message stream completed', {
providerId: provider.id,
modelId
})
onComplete?.()
} catch (error) {
logger.error('Error in unified message stream', error as Error, {
providerId: provider.id,
modelId
})
// Try to send error event if response is still writable
if (!response.writableEnded) {
try {
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
response.write(
`event: error\ndata: ${JSON.stringify({
type: 'error',
error: {
type: 'api_error',
message: errorMessage
}
})}\n\n`
)
response.end()
} catch {
// Response already ended
}
}
onError?.(error)
throw error
}
}
/**
* Generate a non-streaming message response
*/
export async function generateUnifiedMessage(
provider: Provider,
modelId: string,
params: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
logger.info('Starting unified message generation', {
providerId: provider.id,
providerType: provider.type,
modelId
})
try {
// Create language model (async - uses @cherrystudio/ai-core)
const model = await createLanguageModel(provider, modelId)
// Convert messages
const coreMessages = convertAnthropicToAiMessages(params)
// Create adapter to collect the response
let finalResponse: ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse> | null = null
const adapter = new AiSdkToAnthropicSSE({
model: `${provider.id}:${modelId}`,
onEvent: () => {
// We don't need to emit events for non-streaming
}
})
// Generate text
const result = streamText({
model,
messages: coreMessages,
maxOutputTokens: params.max_tokens,
temperature: params.temperature,
topP: params.top_p,
stopSequences: params.stop_sequences,
headers: defaultAppHeaders(),
stopWhen: stepCountIs(100)
})
// Process the stream to build the response
await adapter.processStream(result.fullStream)
// Get the final response
finalResponse = adapter.buildNonStreamingResponse()
logger.info('Unified message generation completed', {
providerId: provider.id,
modelId
})
return finalResponse
} catch (error) {
logger.error('Error in unified message generation', error as Error, {
providerId: provider.id,
modelId
})
throw error
}
}
export default {
streamUnifiedMessages,
generateUnifiedMessage
}

View File

@ -84,18 +84,14 @@ class ClaudeCodeService implements AgentServiceInterface {
})
return aiStream
}
if (
(modelInfo.provider?.type !== 'anthropic' &&
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
modelInfo.provider.apiKey === ''
) {
logger.error('Anthropic provider configuration is missing', {
modelInfo
})
// Validate provider has required configuration
// Note: We no longer restrict to anthropic type only - the API Server's unified adapter
// handles format conversion for any provider type (OpenAI, Gemini, etc.)
if (!modelInfo.provider?.apiKey) {
logger.error('Provider API key is missing', { modelInfo })
aiStream.emit('data', {
type: 'error',
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`)
})
return aiStream
}
@ -106,15 +102,14 @@ class ClaudeCodeService implements AgentServiceInterface {
Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
) as Record<string, string>
// Route through local API Server which handles format conversion via unified adapter
// This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
// The API Server converts AI SDK responses to Anthropic SSE format transparently
const env = {
...loginShellEnvWithoutProxies,
// TODO: fix the proxy api server
// ANTHROPIC_API_KEY: apiConfig.apiKey,
// ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
// ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
ANTHROPIC_API_KEY: apiConfig.apiKey,
ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
ANTHROPIC_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,

View File

@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient {
this.anthropicVertexClient = new AnthropicVertexClient(provider)
// 如果传入的是普通 Provider转换为 VertexProvider
if (isVertexProvider(provider)) {
this.vertexProvider = provider
this.vertexProvider = provider as VertexProvider
} else {
this.vertexProvider = createVertexProvider(provider)
}

View File

@ -1,22 +0,0 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)

View File

@ -1,22 +0,0 @@
import type { Model, Provider } from '@renderer/types'
import type { RuleSet } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* Provider
* @param ruleSet
* @param model
* @param provider provider对象
* @returns provider对象
*/
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return ruleSet.fallbackRule(provider)
}

View File

@ -1,3 +1,7 @@
export { aihubmixProviderCreator } from './aihubmix'
export { newApiResolverCreator } from './newApi'
export { vertexAnthropicProviderCreator } from './vertext-anthropic'
// Re-export from shared config
export {
aihubmixProviderCreator,
azureAnthropicProviderCreator,
newApiResolverCreator,
vertexAnthropicProviderCreator
} from '@shared/provider/config'

View File

@ -1,9 +0,0 @@
import type { Model, Provider } from '@renderer/types'
export interface RuleSet {
rules: Array<{
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}>
fallbackRule: (provider: Provider) => Provider
}

View File

@ -1,19 +0,0 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)

View File

@ -1,8 +1,7 @@
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider'
import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
@ -22,68 +21,12 @@ const logger = loggerService.withContext('ProviderFactory')
}
})()
/**
* Provider映射表
* Cherry Studio特有的provider ID到AI SDK标准ID的映射
*/
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* provider标识符
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查AiCore是否支持包括别名支持
if (hasProviderConfigByAlias(identifier)) {
// 解析为真实的Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* AI SDK Provider ID
*
* Uses shared implementation with renderer-specific config checker
*/
export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
} else {
return 'azure'
}
}
if (resolvedFromId) {
return resolvedFromId
}
// 2. 尝试解析provider.type
// 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. 最后的fallback使用provider本身的id
return provider.id
return sharedGetAiSdkProviderId(provider)
}
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {

View File

@ -1,4 +1,4 @@
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { hasProviderConfig } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
getAwsBedrockAccessKeyId,
@ -10,22 +10,17 @@ import {
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import { isSystemProvider, type Model, type Provider } from '@renderer/types'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
type AiSdkConfigContext,
formatProviderApiHost as sharedFormatProviderApiHost,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider
} from '@shared/provider'
import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
@ -56,61 +51,51 @@ function getRotatedApiKey(provider: Provider): string {
}
/**
* provider的转换逻辑
* Renderer-specific context for providerToAiSdkConfig
* Provides implementations using browser APIs, store, and hooks
*/
function handleSpecialProviders(model: Model, provider: Provider): Provider {
if (isNewApiProvider(provider)) {
return newApiResolverCreator(model, provider)
function createRendererSdkContext(model: Model): AiSdkConfigContext {
return {
getRotatedApiKey: (provider) => getRotatedApiKey(provider as Provider),
isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model),
getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS,
getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {},
getAwsBedrockConfig: () => {
const authType = getAwsBedrockAuthType()
return {
authType,
region: getAwsBedrockRegion(),
apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined,
accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined,
secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined
}
},
getVertexConfig: (provider) => {
if (!isVertexAIConfigured()) {
return undefined
}
return createVertexProvider(provider as Provider)
},
getEndpointType: () => model.endpoint_type
}
if (isSystemProvider(provider)) {
if (provider.id === 'aihubmix') {
return aihubmixProviderCreator(model, provider)
}
if (provider.id === 'vertexai') {
return vertexAnthropicProviderCreator(model, provider)
}
}
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
}
/**
* AISdk的BaseURL格式
* @param provider
* @returns
* Uses shared implementation with renderer-specific context
*/
function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider }
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
function getRendererFormatContext(): ProviderFormatContext {
const vertexSettings = store.getState().llm.settings.vertexai
return {
vertex: {
project: vertexSettings.projectId || 'default-project',
location: vertexSettings.location || 'us-central1'
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
function formatProviderApiHost(provider: Provider): Provider {
return sharedFormatProviderApiHost(provider, getRendererFormatContext())
}
/**
@ -122,7 +107,9 @@ export function getActualProvider(model: Model): Provider {
// 按顺序处理各种转换
let actualProvider = cloneDeep(baseProvider)
actualProvider = handleSpecialProviders(model, actualProvider)
actualProvider = resolveActualProvider(actualProvider, model, {
isSystemProvider
}) as Provider
actualProvider = formatProviderApiHost(actualProvider)
return actualProvider
@ -130,121 +117,11 @@ export function getActualProvider(model: Model): Provider {
/**
* Provider AI SDK
*
* Uses shared implementation with renderer-specific context
*/
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
apiKey: getRotatedApiKey(actualProvider)
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
if (isCopilotProvider) {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
headers: {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders,
...actualProvider.extra_headers
},
name: actualProvider.id,
includeUsage: true
})
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// 处理OpenAI模式
const extraOptions: any = {}
extraOptions.endpoint = endpoint
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// 添加额外headers
if (actualProvider.extra_headers) {
extraOptions.headers = actualProvider.extra_headers
// copy from openaiBaseClient/openaiResponseApiClient
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...extraOptions.headers,
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// bedrock
if (aiSdkProviderId === 'bedrock') {
const authType = getAwsBedrockAuthType()
extraOptions.region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions.apiKey = getAwsBedrockApiKey()
} else {
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
}
// google-vertex
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions.project = project
extraOptions.location = location
extraOptions.googleCredentials = {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// 否则fallback到openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions,
includeUsage: true
}
}
const context = createRendererSdkContext(model)
return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig
}
/**
@ -287,13 +164,13 @@ export async function prepareSpecialProviderConfig(
break
}
case 'cherryai': {
config.options.fetch = async (url, options) => {
config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => {
// 在这里对最终参数进行签名
const signature = await window.api.cherryai.generateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: JSON.parse(options.body)
body: JSON.parse(options.body as string)
})
return fetch(url, {
...options,

View File

@ -1,113 +1,13 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider'
const logger = loggerService.withContext('ProviderConfigs')
/**
* Provider配置定义
* AI Providers
*/
export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS
/**
* Providers
* 使aiCore的动态注册功能
*/
export async function initializeNewProviders(): Promise<void> {
try {
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
if (successCount < NEW_PROVIDER_CONFIGS.length) {
logger.warn('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
}

View File

@ -7,6 +7,8 @@ import type { CSSProperties } from 'react'
export * from './file'
export * from './note'
import type { MinimalModel } from '@shared/provider/types'
import type { StreamTextParams } from './aiCoreTypes'
import type { Chunk } from './chunk'
import type { FileMetadata } from './file'
@ -256,7 +258,7 @@ export type ModelCapability = {
isUserSelected?: boolean
}
export type Model = {
export type Model = MinimalModel & {
id: string
provider: string
name: string

View File

@ -1,24 +1,14 @@
import type OpenAI from '@cherrystudio/openai'
import type { MinimalProvider } from '@shared/provider'
import type { ProviderType, SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types'
import { isSystemProviderId, SystemProviderIds } from '@shared/provider/types'
import type { Model } from '@types'
import * as z from 'zod'
import type { OpenAIVerbosity } from './aiCoreTypes'
export const ProviderTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
export type { ProviderType } from '@shared/provider'
export type { SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types'
export { isSystemProviderId, ProviderTypeSchema, SystemProviderIds } from '@shared/provider/types'
// undefined is treated as supported, enabled by default
export type ProviderApiOptions = {
@ -93,7 +83,7 @@ export function isAwsBedrockAuthType(type: string): type is AwsBedrockAuthType {
return Object.hasOwn(AwsBedrockAuthTypes, type)
}
export type Provider = {
export type Provider = MinimalProvider & {
id: string
type: ProviderType
name: string
@ -128,140 +118,6 @@ export type Provider = {
extra_headers?: Record<string, string>
}
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'cerebras'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>
type SystemProviderIdTypeMap = typeof SystemProviderIds
export type SystemProvider = Provider & {
id: SystemProviderId
isSystem: true

View File

@ -1,6 +1,15 @@
import store from '@renderer/store'
import type { VertexProvider } from '@renderer/types'
import { trim } from 'lodash'
export {
formatApiHost,
formatAzureOpenAIApiHost,
formatVertexApiHost,
getAiSdkBaseUrl,
hasAPIVersion,
routeToEndpoint,
SUPPORTED_ENDPOINT_LIST,
SUPPORTED_IMAGE_ENDPOINT_LIST,
validateApiHost,
withoutTrailingSlash
} from '@shared/api'
/**
* API key
@ -12,169 +21,6 @@ export function formatApiKeys(value: string): string {
return value.replaceAll('', ',').replaceAll('\n', ',')
}
/**
* host path /v1/v2beta
*
* @param host - host path
* @returns path true false
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
// 匹配路径中以 `/v<number>` 开头并可选跟随 `alpha` 或 `beta` 的版本段,
// 该段后面可以跟 `/` 或字符串结束(用于匹配诸如 `/v3alpha/resources` 的情况)。
const versionRegex = /\/v\d+(?:alpha|beta)?(?=\/|$)/i
try {
const url = new URL(host)
return versionRegex.test(url.pathname)
} catch {
// 若无法作为完整 URL 解析,则当作路径直接检测
return versionRegex.test(host)
}
}
/**
* Removes the trailing slash from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing slash
*
* @example
* ```ts
* withoutTrailingSlash('https://example.com/') // 'https://example.com'
* withoutTrailingSlash('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param isSupportedAPIVerion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, isSupportedAPIVerion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
if (normalizedHost.endsWith('#') || !isSupportedAPIVerion || hasAPIVersion(normalizedHost)) {
return normalizedHost
}
return `${normalizedHost}/${apiVersion}`
}
/**
* Azure OpenAI API
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: AISDK会添加上`v1`
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(provider: VertexProvider): string {
const { apiHost } = provider
const { projectId: project, location } = store.getState().llm.settings.vertexai
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location == 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
// 目前对话界面只支持这些端点
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* @param apiHost - The API host string to parse. Expected to be a trimmed URL that may end with '#' followed by an endpoint identifier.
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @description
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
* The '#' delimiter is removed before processing.
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = trim(apiHost)
// 前面已经确保apiHost合法
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// 去掉结尾的 #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // 去掉结尾可能存在的冒号(gemini的特殊情况)
return { baseURL, endpoint: endpointMatch }
}
/**
* API
*
* @param {string} apiHost - API
* @returns {boolean} URL true false
*/
export function validateApiHost(apiHost: string): boolean {
// 允许apiHost为空
if (!apiHost || !trim(apiHost)) {
return true
}
try {
const url = new URL(trim(apiHost))
// 验证协议是否为 http 或 https
if (url.protocol !== 'http:' && url.protocol !== 'https:') {
return false
}
return true
} catch {
return false
}
}
/**
* API key
*

View File

@ -2,6 +2,8 @@ import { getProviderLabel } from '@renderer/i18n/label'
import type { Provider } from '@renderer/types'
import { isSystemProvider } from '@renderer/types'
export { getBaseModelName, getLowerBaseModelName } from '@shared/utils/naming'
/**
* ID
*
@ -50,38 +52,6 @@ export const getDefaultGroupName = (id: string, provider?: string): string => {
return str
}
/**
* ID
*
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* ID
*
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}
/**
*
* @param provider

View File

@ -1,10 +1,20 @@
import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code'
import type { AzureOpenAIProvider, ProviderType, VertexProvider } from '@renderer/types'
import type { ProviderType } from '@renderer/types'
import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types'
export const isAzureResponsesEndpoint = (provider: AzureOpenAIProvider) => {
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
}
export {
isAIGatewayProvider,
isAnthropicProvider,
isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from '@shared/provider'
export const getClaudeSupportedProviders = (providers: Provider[]) => {
return providers.filter(
@ -119,55 +129,6 @@ export const isGeminiWebSearchProvider = (provider: Provider) => {
return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id)
}
export const isNewApiProvider = (provider: Provider) => {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
}
export function isCherryAIProvider(provider: Provider): boolean {
return provider.id === 'cherryai'
}
export function isPerplexityProvider(provider: Provider): boolean {
return provider.id === 'perplexity'
}
/**
* OpenAI
* @param {Provider} provider
* @returns {boolean} OpenAI
*/
export function isOpenAICompatibleProvider(provider: Provider): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider {
return provider.type === 'azure-openai'
}
export function isOpenAIProvider(provider: Provider): boolean {
return provider.type === 'openai-response'
}
export function isVertexProvider(provider: Provider): provider is VertexProvider {
return provider.type === 'vertexai'
}
export function isAwsBedrockProvider(provider: Provider): boolean {
return provider.type === 'aws-bedrock'
}
export function isAnthropicProvider(provider: Provider): boolean {
return provider.type === 'anthropic'
}
export function isGeminiProvider(provider: Provider): boolean {
return provider.type === 'gemini'
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'ai-gateway'
}
const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]
export const isSupportAPIVersionProvider = (provider: Provider) => {