mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 07:19:02 +08:00
feat: add CherryAI signed fetch wrapper and enhance tool conversion to Zod schema
This commit is contained in:
parent
95c18d192a
commit
ed769ac4f7
@ -88,6 +88,12 @@ export interface AiSdkConfigContext {
|
||||
* Renderer process: use browser fetch (default)
|
||||
*/
|
||||
fetch?: typeof globalThis.fetch
|
||||
|
||||
/**
|
||||
* Get CherryAI signed fetch wrapper
|
||||
* Returns a fetch function that adds signature headers to requests
|
||||
*/
|
||||
getCherryAISignedFetch?: () => typeof globalThis.fetch
|
||||
}
|
||||
|
||||
/**
|
||||
@ -220,8 +226,13 @@ export function providerToAiSdkConfig(
|
||||
}
|
||||
}
|
||||
|
||||
// Inject custom fetch if provided
|
||||
if (context.fetch) {
|
||||
// Handle cherryai signed fetch
|
||||
if (provider.id === 'cherryai') {
|
||||
const signedFetch = context.getCherryAISignedFetch?.()
|
||||
if (signedFetch) {
|
||||
extraOptions.fetch = signedFetch
|
||||
}
|
||||
} else if (context.fetch) {
|
||||
extraOptions.fetch = context.fetch
|
||||
}
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import type {
|
||||
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
|
||||
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
|
||||
import anthropicService from '@main/services/AnthropicService'
|
||||
import copilotService from '@main/services/CopilotService'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
@ -26,10 +27,11 @@ import {
|
||||
import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import type { Provider } from '@types'
|
||||
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
|
||||
import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
|
||||
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
|
||||
import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
|
||||
import { net } from 'electron'
|
||||
import type { Response } from 'express'
|
||||
import * as z from 'zod'
|
||||
|
||||
import { reasoningCache } from './cache'
|
||||
|
||||
@ -124,19 +126,119 @@ function convertAnthropicToolResultToAiSdk(
|
||||
return { type: 'content', value: values }
|
||||
}
|
||||
|
||||
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
|
||||
// Type alias for JSON Schema (compatible with recursive calls)
|
||||
type JsonSchemaLike = AnthropicTool.InputSchema | Record<string, unknown>
|
||||
|
||||
/**
|
||||
* Convert JSON Schema to Zod schema
|
||||
* This avoids non-standard fields like input_examples that Anthropic doesn't support
|
||||
*/
|
||||
function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
|
||||
const s = schema as Record<string, unknown>
|
||||
const schemaType = s.type as string | string[] | undefined
|
||||
const enumValues = s.enum as unknown[] | undefined
|
||||
const description = s.description as string | undefined
|
||||
|
||||
// Handle enum first
|
||||
if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
|
||||
if (enumValues.every((v) => typeof v === 'string')) {
|
||||
const zodEnum = z.enum(enumValues as [string, ...string[]])
|
||||
return description ? zodEnum.describe(description) : zodEnum
|
||||
}
|
||||
// For non-string enums, use union of literals
|
||||
const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
|
||||
if (literals.length === 1) {
|
||||
return description ? literals[0].describe(description) : literals[0]
|
||||
}
|
||||
const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
|
||||
return description ? zodUnion.describe(description) : zodUnion
|
||||
}
|
||||
|
||||
// Handle union types (type: ["string", "null"])
|
||||
if (Array.isArray(schemaType)) {
|
||||
const schemas = schemaType.map((t) => jsonSchemaToZod({ ...s, type: t, enum: undefined }))
|
||||
if (schemas.length === 1) {
|
||||
return schemas[0]
|
||||
}
|
||||
return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
|
||||
}
|
||||
|
||||
// Handle by type
|
||||
switch (schemaType) {
|
||||
case 'string': {
|
||||
let zodString = z.string()
|
||||
if (typeof s.minLength === 'number') zodString = zodString.min(s.minLength)
|
||||
if (typeof s.maxLength === 'number') zodString = zodString.max(s.maxLength)
|
||||
if (typeof s.pattern === 'string') zodString = zodString.regex(new RegExp(s.pattern))
|
||||
return description ? zodString.describe(description) : zodString
|
||||
}
|
||||
|
||||
case 'number':
|
||||
case 'integer': {
|
||||
let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
|
||||
if (typeof s.minimum === 'number') zodNumber = zodNumber.min(s.minimum)
|
||||
if (typeof s.maximum === 'number') zodNumber = zodNumber.max(s.maximum)
|
||||
return description ? zodNumber.describe(description) : zodNumber
|
||||
}
|
||||
|
||||
case 'boolean': {
|
||||
const zodBoolean = z.boolean()
|
||||
return description ? zodBoolean.describe(description) : zodBoolean
|
||||
}
|
||||
|
||||
case 'null':
|
||||
return z.null()
|
||||
|
||||
case 'array': {
|
||||
const items = s.items as Record<string, unknown> | undefined
|
||||
let zodArray = items ? z.array(jsonSchemaToZod(items)) : z.array(z.unknown())
|
||||
if (typeof s.minItems === 'number') zodArray = zodArray.min(s.minItems)
|
||||
if (typeof s.maxItems === 'number') zodArray = zodArray.max(s.maxItems)
|
||||
return description ? zodArray.describe(description) : zodArray
|
||||
}
|
||||
|
||||
case 'object': {
|
||||
const properties = s.properties as Record<string, Record<string, unknown>> | undefined
|
||||
const required = (s.required as string[]) || []
|
||||
|
||||
// Always use z.object() to ensure "properties" field is present in output schema
|
||||
// OpenAI requires explicit properties field even for empty objects
|
||||
const shape: Record<string, z.ZodTypeAny> = {}
|
||||
if (properties) {
|
||||
for (const [key, propSchema] of Object.entries(properties)) {
|
||||
const zodProp = jsonSchemaToZod(propSchema)
|
||||
shape[key] = required.includes(key) ? zodProp : zodProp.optional()
|
||||
}
|
||||
}
|
||||
|
||||
const zodObject = z.object(shape)
|
||||
return description ? zodObject.describe(description) : zodObject
|
||||
}
|
||||
|
||||
default:
|
||||
// Unknown type, use z.unknown()
|
||||
return z.unknown()
|
||||
}
|
||||
}
|
||||
|
||||
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, AiSdkTool> | undefined {
|
||||
if (!tools || tools.length === 0) return undefined
|
||||
|
||||
const aiSdkTools: Record<string, Tool> = {}
|
||||
const aiSdkTools: Record<string, AiSdkTool> = {}
|
||||
for (const anthropicTool of tools) {
|
||||
if (anthropicTool.type === 'bash_20250124') continue
|
||||
const toolDef = anthropicTool as AnthropicTool
|
||||
const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]
|
||||
aiSdkTools[toolDef.name] = tool({
|
||||
const rawSchema = toolDef.input_schema
|
||||
const schema = jsonSchemaToZod(rawSchema)
|
||||
|
||||
// Use tool() with inputSchema (AI SDK v5 API)
|
||||
const aiTool = tool({
|
||||
description: toolDef.description || '',
|
||||
inputSchema: jsonSchema(parameters),
|
||||
execute: async (input: Record<string, unknown>) => input
|
||||
inputSchema: zodSchema(schema)
|
||||
})
|
||||
|
||||
logger.debug('Converted Anthropic tool to AI SDK tool', aiTool)
|
||||
aiSdkTools[toolDef.name] = aiTool
|
||||
}
|
||||
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
|
||||
}
|
||||
@ -343,8 +445,30 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
|
||||
}
|
||||
break
|
||||
}
|
||||
// Note: cherryai requires request-level signing which is not easily supported here
|
||||
// It would need custom fetch implementation similar to renderer
|
||||
case 'cherryai': {
|
||||
// Create a signed fetch wrapper for cherryai
|
||||
const baseFetch = net.fetch as typeof globalThis.fetch
|
||||
config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
|
||||
if (!options?.body) {
|
||||
return baseFetch(url, options)
|
||||
}
|
||||
const signature = cherryaiGenerateSignature({
|
||||
method: 'POST',
|
||||
path: '/chat/completions',
|
||||
query: '',
|
||||
body: JSON.parse(options.body as string)
|
||||
})
|
||||
return baseFetch(url, {
|
||||
...options,
|
||||
headers: {
|
||||
...(options.headers as Record<string, string>),
|
||||
...signature
|
||||
}
|
||||
})
|
||||
}
|
||||
logger.debug('CherryAI signed fetch configured')
|
||||
break
|
||||
}
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user