chore(dependencies): update ai package version and enhance aiCore functionality

- Updated the 'ai' package version from 5.0.26 to 5.0.29 in package.json and yarn.lock.
- Refactored aiCore's provider schemas: replaced the 'openai-responses' provider with 'openai-chat', added 'azure-responses', and tightened the creator function type.
- Reworked the RuntimeExecutor class so every method takes the model inside its params object, unifying model resolution and plugin execution across text, object, and image tasks.
- Updated tests to reflect the new parameter shape and to cover the new provider configurations and image-model resolution errors.
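For illustration, a minimal sketch of the new unified call shape in TypeScript (the executor setup and API key are assumptions for the example; only the params-object signature comes from this commit):

import { createExecutor } from '@cherrystudio/ai-core'

// Hypothetical setup - providerId and options depend on your configuration.
const executor = createExecutor('openai', { apiKey: process.env.OPENAI_API_KEY }, [])

// Before: executor.generateImage('dall-e-3', { prompt: '...' })
// After: the model travels inside the params object.
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })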
This commit is contained in:
MyPrototypeWhat 2025-09-02 18:54:07 +08:00
parent 4cceddc179
commit 782953cca1
21 changed files with 570 additions and 505 deletions

View File

@ -174,7 +174,7 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.26",
"ai": "^5.0.29",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",

View File

@ -4,11 +4,13 @@
import { createAnthropic } from '@ai-sdk/anthropic'
import { createAzure } from '@ai-sdk/azure'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createXai } from '@ai-sdk/xai'
import { customProvider, type Provider } from 'ai'
import * as z from 'zod'
/**
@ -16,12 +18,13 @@ import * as z from 'zod'
*/
export const baseProviderIds = [
'openai',
'openai-responses',
'openai-chat',
'openai-compatible',
'anthropic',
'google',
'xai',
'azure',
'azure-responses',
'deepseek'
] as const
@ -38,7 +41,7 @@ export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export const baseProviderSchema = z.object({
id: baseProviderIdSchema,
name: z.string(),
creator: z.function().args(z.any()).returns(z.any()),
creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
supportsImageGeneration: z.boolean()
})
@ -56,9 +59,17 @@ export const baseProviders = [
supportsImageGeneration: true
},
{
id: 'openai-responses',
name: 'OpenAI Responses',
creator: (options: OpenAIProviderSettings) => createOpenAI(options).responses,
id: 'openai-chat',
name: 'OpenAI Chat',
creator: (options: OpenAIProviderSettings) => {
const provider = createOpenAI(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
},
supportsImageGeneration: true
},
{
@ -91,6 +102,20 @@ export const baseProviders = [
creator: createAzure,
supportsImageGeneration: true
},
{
id: 'azure-responses',
name: 'Azure OpenAI Responses',
creator: (options: AzureOpenAIProviderSettings) => {
const provider = createAzure(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.responses(modelId)
}
})
},
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',

View File

@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { type AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { ImageGenerationError } from '../errors'
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
import { RuntimeExecutor } from '../executor'
// Mock dependencies
@ -76,9 +76,7 @@ describe('RuntimeExecutor.generateImage', () => {
describe('Basic functionality', () => {
it('should generate a single image with minimal parameters', async () => {
const result = await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape at sunset'
})
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
@ -91,7 +89,8 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should generate image with pre-created model', async () => {
const result = await executor.generateImage(mockImageModel, {
const result = await executor.generateImage({
model: mockImageModel,
prompt: 'A beautiful landscape'
})
@ -105,10 +104,7 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should support multiple images generation', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape',
n: 3
})
await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -118,10 +114,7 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should support size specification', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A beautiful sunset',
size: '1024x1024'
})
await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -131,10 +124,7 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should support aspect ratio specification', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A mountain landscape',
aspectRatio: '16:9'
})
await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -144,10 +134,7 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should support seed for consistent output', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A cat in space',
seed: 1234567890
})
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -159,10 +146,7 @@ describe('RuntimeExecutor.generateImage', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateImage('dall-e-3', {
prompt: 'A cityscape',
abortSignal: abortController.signal
})
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -172,7 +156,8 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should support provider-specific options', async () => {
await executor.generateImage('dall-e-3', {
await executor.generateImage({
model: 'dall-e-3',
prompt: 'A space station',
providerOptions: {
openai: {
@ -195,7 +180,8 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should support custom headers', async () => {
await executor.generateImage('dall-e-3', {
await executor.generateImage({
model: 'dall-e-3',
prompt: 'A robot',
headers: {
'X-Custom-Header': 'test-value'
@ -242,9 +228,7 @@ describe('RuntimeExecutor.generateImage', () => {
[testPlugin]
)
const result = await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
@ -287,9 +271,7 @@ describe('RuntimeExecutor.generateImage', () => {
[modelResolutionPlugin]
)
await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
'dall-e-3',
@ -312,6 +294,7 @@ describe('RuntimeExecutor.generateImage', () => {
if (!context.isRecursiveCall && params.prompt === 'original') {
// Make a recursive call with modified prompt
await context.recursiveCall({
model: 'dall-e-3',
prompt: 'modified'
})
}
@ -327,9 +310,7 @@ describe('RuntimeExecutor.generateImage', () => {
[recursivePlugin]
)
await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'original'
})
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' })
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
@ -343,22 +324,47 @@ describe('RuntimeExecutor.generateImage', () => {
throw modelError
})
await expect(
executor.generateImage('invalid-model', {
prompt: 'A test image'
})
).rejects.toThrow(ImageGenerationError)
await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow(
ImageGenerationError
)
})
it('should handle ImageModelResolutionError correctly', async () => {
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
throw resolutionError
})
const thrownError = await executor
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
.catch((error) => error)
expect(thrownError).toBeInstanceOf(ImageGenerationError)
expect(thrownError.message).toContain('Failed to generate image:')
expect(thrownError.providerId).toBe('openai')
expect(thrownError.modelId).toBe('invalid-model')
expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
})
it('should handle ImageModelResolutionError without provider', async () => {
const resolutionError = new ImageModelResolutionError('unknown-model')
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
throw resolutionError
})
await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow(
ImageGenerationError
)
})
it('should handle image generation API errors', async () => {
const apiError = new Error('API request failed')
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image:')
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'Failed to generate image:'
)
})
it('should handle NoImageGeneratedError', async () => {
@ -370,11 +376,9 @@ describe('RuntimeExecutor.generateImage', () => {
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image:')
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'Failed to generate image:'
)
})
it('should execute onError plugin hook on failure', async () => {
@ -394,11 +398,9 @@ describe('RuntimeExecutor.generateImage', () => {
[errorPlugin]
)
await expect(
executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image:')
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'Failed to generate image:'
)
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
@ -418,10 +420,7 @@ describe('RuntimeExecutor.generateImage', () => {
setTimeout(() => abortController.abort(), 10)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image',
abortSignal: abortController.signal
})
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
).rejects.toThrow('Failed to generate image:')
})
})
@ -432,9 +431,7 @@ describe('RuntimeExecutor.generateImage', () => {
apiKey: 'google-key'
})
await googleExecutor.generateImage('imagen-3.0-generate-002', {
prompt: 'A landscape'
})
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
})
@ -444,9 +441,7 @@ describe('RuntimeExecutor.generateImage', () => {
apiKey: 'xai-key'
})
await xaiExecutor.generateImage('grok-2-image', {
prompt: 'A futuristic robot'
})
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
})
@ -454,11 +449,7 @@ describe('RuntimeExecutor.generateImage', () => {
describe('Advanced features', () => {
it('should support batch image generation with maxImagesPerCall', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A test image',
n: 10,
maxImagesPerCall: 5
})
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -469,10 +460,7 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should support retries with maxRetries', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A test image',
maxRetries: 3
})
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -494,7 +482,8 @@ describe('RuntimeExecutor.generateImage', () => {
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
const result = await executor.generateImage('dall-e-3', {
const result = await executor.generateImage({
model: 'dall-e-3',
prompt: 'A test image',
size: '2048x2048' // Unsupported size
})
@ -504,9 +493,7 @@ describe('RuntimeExecutor.generateImage', () => {
})
it('should provide access to provider metadata', async () => {
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(result.providerMetadata).toBeDefined()
expect(result.providerMetadata.openai).toBeDefined()
@ -526,9 +513,7 @@ describe('RuntimeExecutor.generateImage', () => {
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
expect(result.responses).toHaveLength(1)
expect(result.responses[0].modelId).toBe('dall-e-3')

View File

@ -72,42 +72,36 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
// === High-level overloads: use a model directly ===
/**
* Stream text - `params.model` accepts either a model instance or a model id string
*/
async streamText(
model: LanguageModel,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>>
async streamText(
modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamText>>
async streamText(
modelOrId: LanguageModel,
params: Omit<Parameters<typeof streamText>[0], 'model'>,
params: Parameters<typeof streamText>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamText>> {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
const { model, ...restParams } = params
// Choose the plugin set based on the model type
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
// 2. Run the plugin pipeline
return this.pluginEngine.executeStreamWithPlugins(
'streamText',
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (model, transformedParams, streamTransforms) => {
model,
restParams,
async (resolvedModel, transformedParams, streamTransforms) => {
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
const finalParams = {
model,
model: resolvedModel,
...transformedParams,
experimental_transform
} as Parameters<typeof streamText>[0]
@ -120,145 +114,126 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
// === Overloads for the remaining methods ===
/**
* Generate text - `params.model` accepts either a model instance or a model id string
*/
async generateText(
model: LanguageModel,
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>>
async generateText(
modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateText>>
async generateText(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof generateText>[0], 'model'>,
params: Parameters<typeof generateText>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateText>> {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
const { model, ...restParams } = params
// Choose the plugin set based on the model type
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
'generateText',
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (model, transformedParams) =>
generateText({ model, ...transformedParams } as Parameters<typeof generateText>[0])
model,
restParams,
async (resolvedModel, transformedParams) =>
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
)
}
/**
* Generate a structured object - `params.model` accepts either a model instance or a model id string
*/
async generateObject(
model: LanguageModel,
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>>
async generateObject(
modelOrId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateObject>>
async generateObject(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
params: Parameters<typeof generateObject>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateObject>> {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
const { model, ...restParams } = params
// Choose the plugin set based on the model type
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
'generateObject',
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (model, transformedParams) =>
generateObject({ model, ...transformedParams } as Parameters<typeof generateObject>[0])
model,
restParams,
async (resolvedModel, transformedParams) =>
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
)
}
/**
* Stream a structured object - `params.model` accepts either a model instance or a model id string
*/
async streamObject(
model: LanguageModel,
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>>
async streamObject(
modelId: string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamObject>>
async streamObject(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
params: Parameters<typeof streamObject>[0],
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamObject>> {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
const { model, ...restParams } = params
// Choose the plugin set based on the model type
if (typeof model === 'string') {
this.pluginEngine.usePlugins([
this.createResolveModelPlugin(options?.middlewares),
this.createConfigureContextPlugin()
])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return this.pluginEngine.executeWithPlugins(
'streamObject',
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (model, transformedParams) =>
streamObject({ model, ...transformedParams } as Parameters<typeof streamObject>[0])
model,
restParams,
async (resolvedModel, transformedParams) =>
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
)
}
/**
* Generate images - `params.model` accepts either a model instance or an image model id string
*/
async generateImage(
model: ImageModelV2,
params: Omit<Parameters<typeof generateImage>[0], 'model'>
): Promise<ReturnType<typeof generateImage>>
async generateImage(
modelId: string,
params: Omit<Parameters<typeof generateImage>[0], 'model'>,
options?: {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateImage>>
async generateImage(
modelOrId: ImageModelV2 | string,
params: Omit<Parameters<typeof generateImage>[0], 'model'>
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
): Promise<ReturnType<typeof generateImage>> {
try {
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
const { model, ...restParams } = params
// Choose the plugin set based on the model type
if (typeof model === 'string') {
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
} else {
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
}
return await this.pluginEngine.executeImageWithPlugins(
'generateImage',
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (model, transformedParams) => {
return await generateImage({ model, ...transformedParams })
model,
restParams,
async (resolvedModel, transformedParams) => {
return await generateImage({ model: resolvedModel, ...transformedParams })
}
)
} catch (error) {
if (error instanceof Error) {
const modelId = typeof params.model === 'string' ? params.model : params.model.modelId
throw new ImageGenerationError(
`Failed to generate image: ${error.message}`,
this.config.providerId,
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
modelId,
error
)
}

View File

@ -46,13 +46,12 @@ export function createOpenAICompatibleExecutor(
export async function streamText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
modelId: string,
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamText(modelId, params, { middlewares })
return executor.streamText(params, { middlewares })
}
/**
@ -61,13 +60,12 @@ export async function streamText<T extends ProviderId>(
export async function generateText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
modelId: string,
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateText(modelId, params, { middlewares })
return executor.generateText(params, { middlewares })
}
/**
@ -76,13 +74,12 @@ export async function generateText<T extends ProviderId>(
export async function generateObject<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
modelId: string,
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
params: Parameters<RuntimeExecutor<T>['generateObject']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateObject(modelId, params, { middlewares })
return executor.generateObject(params, { middlewares })
}
/**
@ -91,13 +88,12 @@ export async function generateObject<T extends ProviderId>(
export async function streamObject<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
modelId: string,
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
params: Parameters<RuntimeExecutor<T>['streamObject']>[0],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamObject(modelId, params, { middlewares })
return executor.streamObject(params, { middlewares })
}
/**
@ -106,13 +102,11 @@ export async function streamObject<T extends ProviderId>(
export async function generateImage<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
modelId: string,
params: Parameters<RuntimeExecutor<T>['generateImage']>[1],
plugins?: AiPlugin[],
middlewares?: LanguageModelV2Middleware[]
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
plugins?: AiPlugin[]
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateImage(modelId, params, { middlewares })
return executor.generateImage(params)
}
// === Reserved for future Agent functionality ===

View File

@ -64,11 +64,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
*/
async executeWithPlugins<TParams, TResult>(
methodName: string,
modelId: string,
model: LanguageModel,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// Normalize model resolution
let resolvedModel: LanguageModel | undefined
let modelId: string
if (typeof model === 'string') {
// String: resolve it through the plugins
modelId = model
} else {
// Model object: use it directly
resolvedModel = model
modelId = model.modelId
}
// Create the request context with createContext
const context = _context ? _context : createContext(this.providerId, modelId, params)
@ -76,7 +89,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// Re-enter the full plugin pipeline recursively
context.isRecursiveCall = true
const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context)
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
context.isRecursiveCall = false
return result
}
@ -88,17 +101,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
// 1. Fire the request-start hooks
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. Resolve the model
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
if (!model) {
throw new Error(`Failed to resolve model: ${modelId}`)
// 2. Resolve the model (when it was given as a string)
if (typeof model === 'string') {
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
if (!resolved) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
resolvedModel = resolved
}
if (!resolvedModel) {
throw new Error(`Model resolution failed: no model available`)
}
// 3. Transform the request params
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. Execute the actual API call
const result = await executor(model, transformedParams)
const result = await executor(resolvedModel, transformedParams)
// 5. Transform the result (for non-streaming calls)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
@ -120,11 +140,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
*/
async executeImageWithPlugins<TParams, TResult>(
methodName: string,
modelId: string,
model: ImageModelV2 | string,
params: TParams,
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// Normalize model resolution
let resolvedModel: ImageModelV2 | undefined
let modelId: string
if (typeof model === 'string') {
// String: resolve it through the plugins
modelId = model
} else {
// Model object: use it directly
resolvedModel = model
modelId = model.modelId
}
// Create the request context with createContext
const context = _context ? _context : createContext(this.providerId, modelId, params)
@ -132,7 +165,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// Re-enter the full plugin pipeline recursively
context.isRecursiveCall = true
const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context)
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
context.isRecursiveCall = false
return result
}
@ -144,17 +177,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
// 1. Fire the request-start hooks
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. Resolve the model
const model = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
if (!model) {
throw new Error(`Failed to resolve image model: ${modelId}`)
// 2. Resolve the model (when it was given as a string)
if (typeof model === 'string') {
const resolved = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
if (!resolved) {
throw new Error(`Failed to resolve image model: ${modelId}`)
}
resolvedModel = resolved
}
if (!resolvedModel) {
throw new Error(`Image model resolution failed: no model available`)
}
// 3. Transform the request params
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. Execute the actual API call
const result = await executor(model, transformedParams)
const result = await executor(resolvedModel, transformedParams)
// 5. Transform the result
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
@ -176,11 +216,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
*/
async executeStreamWithPlugins<TParams, TResult>(
methodName: string,
modelId: string,
model: LanguageModel,
params: TParams,
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// Normalize model resolution
let resolvedModel: LanguageModel | undefined
let modelId: string
if (typeof model === 'string') {
// String: resolve it through the plugins
modelId = model
} else {
// Model object: use it directly
resolvedModel = model
modelId = model.modelId
}
// Create the request context
const context = _context ? _context : createContext(this.providerId, modelId, params)
@ -188,7 +241,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
context.recursiveCall = async (newParams: any): Promise<TResult> => {
// Re-enter the full plugin pipeline recursively
context.isRecursiveCall = true
const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context)
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
context.isRecursiveCall = false
return result
}
@ -200,11 +253,17 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
// 1. Fire the request-start hooks
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. Resolve the model
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
// 2. Resolve the model (when it was given as a string)
if (typeof model === 'string') {
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
if (!resolved) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
resolvedModel = resolved
}
if (!model) {
throw new Error(`Failed to resolve model: ${modelId}`)
if (!resolvedModel) {
throw new Error(`Model resolution failed: no model available`)
}
// 3. Transform the request params
@ -214,7 +273,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
// 5. Execute the streaming API call
const result = await executor(model, transformedParams, streamTransforms)
const result = await executor(resolvedModel, transformedParams, streamTransforms)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)

View File

@ -7,22 +7,23 @@
* 2.
*/
import { createExecutor, generateImage } from '@cherrystudio/ai-core'
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
import { createExecutor } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import type { Assistant, Model, Provider } from '@renderer/types'
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
import { ChunkType } from '@renderer/types/chunk'
import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './legacy/index'
import { CompletionsResult } from './legacy/middleware/schemas'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import { createAiSdkProvider } from './provider/factory'
import {
getActualProvider,
isModernSdkSupported,
@ -44,6 +45,7 @@ export default class ModernAiProvider {
private config: ReturnType<typeof providerToAiSdkConfig>
private actualProvider: Provider
private model: Model
private localProvider: Awaited<AiSdkProvider> | null = null
constructor(model: Model, provider?: Provider) {
this.actualProvider = provider || getActualProvider(model)
@ -62,44 +64,58 @@ export default class ModernAiProvider {
// Prepare provider-specific configuration
await prepareSpecialProviderConfig(this.actualProvider, this.config)
logger.debug('this.config', this.config)
// Create the local provider instance up front
if (!this.localProvider) {
this.localProvider = await createAiSdkProvider(this.config)
}
// Build the middlewares up front
const middlewares = buildAiSdkMiddlewares({
...config,
provider: this.actualProvider
})
logger.debug('Built middlewares in completions', {
middlewareCount: middlewares.length,
isImageGeneration: config.isImageGenerationEndpoint
})
if (!this.localProvider) {
throw new Error('Local provider not created')
}
// Create the matching model for the endpoint type
let model: AiSdkModel | undefined
if (config.isImageGenerationEndpoint) {
model = this.localProvider.imageModel(modelId)
} else {
model = this.localProvider.languageModel(modelId)
// Apply any middlewares to the language model
if (middlewares.length > 0 && typeof model === 'object') {
model = wrapLanguageModel({ model, middleware: middlewares })
}
}
if (config.topicId && getEnableDeveloperMode()) {
// TypeScript narrowing: guarantees topicId is a string
const traceConfig = {
...config,
topicId: config.topicId
}
return await this._completionsForTrace(modelId, params, traceConfig)
return await this._completionsForTrace(model, params, traceConfig)
} else {
return await this._completions(modelId, params, config)
return await this._completionsOrImageGeneration(model, params, config)
}
}
private async _completions(
modelId: string,
private async _completionsOrImageGeneration(
model: AiSdkModel,
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
// Register the provider with the global manager
try {
await createAndRegisterProvider(this.config.providerId, this.config.options)
logger.debug('Provider initialized successfully', {
providerId: this.config.providerId,
hasOptions: !!this.config.options
})
} catch (error) {
// If the provider was already initialized this may throw; safe to ignore
logger.debug('Provider initialization skipped (may already be initialized)', {
providerId: this.config.providerId,
error: error instanceof Error ? error.message : String(error)
})
}
if (config.isImageGenerationEndpoint) {
return await this.modernImageGeneration(modelId, params, config)
return await this.modernImageGeneration(model as ImageModel, params, config)
}
return await this.modernCompletions(modelId, params, config)
return await this.modernCompletions(model as LanguageModel, params, config)
}
/**
* Mirrors the legacy completionsForTrace so AI SDK spans land in the correct trace context
*/
*/
private async _completionsForTrace(
modelId: string,
model: AiSdkModel,
params: StreamTextParams,
config: ModernAiProviderConfig & { topicId: string }
): Promise<CompletionsResult> {
const modelId = this.model.id
const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
const traceParams: StartSpanParams = {
name: traceName,
@ -136,7 +153,7 @@ export default class ModernAiProvider {
modelId,
traceName
})
return await this._completions(modelId, params, config)
return await this._completionsOrImageGeneration(model, params, config)
}
try {
@ -148,7 +165,7 @@ export default class ModernAiProvider {
parentSpanCreated: true
})
const result = await this._completions(modelId, params, config)
const result = await this._completionsOrImageGeneration(model, params, config)
logger.info('Completions finished, ending parent span', {
spanId: span.spanContext().spanId,
@ -190,10 +207,11 @@ export default class ModernAiProvider {
* Completions implementation built on the AI SDK
*/
private async modernCompletions(
modelId: string,
model: LanguageModel,
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
const modelId = this.model.id
logger.info('Starting modernCompletions', {
modelId,
providerId: this.config.providerId,
@ -205,122 +223,48 @@ export default class ModernAiProvider {
// Build the plugin array based on the config
const plugins = buildPlugins(config)
logger.debug('Built plugins for AI SDK', {
pluginCount: plugins.length,
pluginNames: plugins.map((p) => p.name),
providerId: this.config.providerId,
topicId: config.topicId
})
// Create the executor with the assembled plugins
const executor = createExecutor(this.config.providerId, this.config.options, plugins)
logger.debug('Created AI SDK executor', {
providerId: this.config.providerId,
hasOptions: !!this.config.options,
pluginCount: plugins.length
})
// Build the middleware array dynamically
const middlewares = buildAiSdkMiddlewares(config)
logger.debug('Built AI SDK middlewares', {
middlewareCount: middlewares.length,
topicId: config.topicId
})
// Create an executor with the middlewares applied
if (config.onChunk) {
// Streaming: route through the chunk adapter
logger.info('Starting streaming with chunk adapter', {
modelId,
hasMiddlewares: middlewares.length > 0,
middlewareCount: middlewares.length,
hasMcpTools: !!config.mcpTools,
mcpToolCount: config.mcpTools?.length || 0,
topicId: config.topicId
})
const accumulate = this.model.supported_text_delta !== false // treats true and undefined as enabled
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate)
logger.debug('Final params before streamText', {
modelId,
hasMessages: !!params.messages,
messageCount: params.messages?.length || 0,
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
toolNames: params.tools ? Object.keys(params.tools) : [],
hasSystem: !!params.system,
topicId: config.topicId
})
const streamResult = await executor.streamText(
modelId,
{ ...params, experimental_context: { onChunk: config.onChunk } },
middlewares.length > 0 ? { middlewares } : undefined
)
logger.info('StreamText call successful, processing stream', {
modelId,
topicId: config.topicId,
hasFullStream: !!streamResult.fullStream
const streamResult = await executor.streamText({
...params,
model,
experimental_context: { onChunk: config.onChunk }
})
const finalText = await adapter.processStream(streamResult)
logger.info('Stream processing completed', {
modelId,
topicId: config.topicId,
finalTextLength: finalText.length
})
return {
getText: () => finalText
}
} else {
// Streaming without an onChunk callback
logger.info('Starting streaming without chunk callback', {
modelId,
hasMiddlewares: middlewares.length > 0,
middlewareCount: middlewares.length,
topicId: config.topicId
const streamResult = await executor.streamText({
...params,
model
})
const streamResult = await executor.streamText(
modelId,
params,
middlewares.length > 0 ? { middlewares } : undefined
)
logger.info('StreamText call successful, waiting for text', {
modelId,
topicId: config.topicId
})
// Drain the stream; otherwise awaiting streamResult.text would block
await streamResult?.consumeStream()
const finalText = await streamResult.text
logger.info('Text extraction completed', {
modelId,
topicId: config.topicId,
finalTextLength: finalText.length
})
return {
getText: () => finalText
}
}
// }
// catch (error) {
// console.error('Modern AI SDK error:', error)
// throw error
// }
}
/**
* Image generation built on the modern AI SDK
*/
private async modernImageGeneration(
modelId: string,
model: ImageModel,
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
@ -364,7 +308,11 @@ export default class ModernAiProvider {
}
// Invoke the new AI SDK's image generation
const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams)
const executor = createExecutor(this.config.providerId, this.config.options, [])
const result = await executor.generateImage({
model,
...imageParams
})
// Convert the result format
const images: string[] = []
@ -441,51 +389,64 @@ export default class ModernAiProvider {
return this.legacyProvider.getEmbeddingDimensions(model)
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
// If the modern AI SDK is supported, use the modern implementation
if (isModernSdkSupported(this.actualProvider)) {
try {
const result = await this.modernGenerateImage(params)
return result
} catch (error) {
logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
// Fall back to the legacy implementation
return this.legacyProvider.generateImage(params)
}
}
// public async generateImage(params: GenerateImageParams): Promise<string[]> {
// // If the modern AI SDK is supported, use the modern implementation
// if (isModernSdkSupported(this.actualProvider)) {
// try {
// // Make sure the local provider has been created
// if (!this.localProvider) {
// await prepareSpecialProviderConfig(this.actualProvider, this.config)
// this.localProvider = await createProvider(this.config.providerId, this.config.options)
// logger.debug('Local provider created for standalone image generation', {
// providerId: this.config.providerId
// })
// }
// Use the legacy implementation directly
return this.legacyProvider.generateImage(params)
}
// const result = await this.modernGenerateImage(params)
// return result
// } catch (error) {
// logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
// // Fall back to the legacy implementation
// return this.legacyProvider.generateImage(params)
// }
// }
/**
* Image generation implementation using the modern AI SDK
*/
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
const { model, prompt, imageSize, batchSize, signal } = params
// // Use the legacy implementation directly
// return this.legacyProvider.generateImage(params)
// }
// Convert the parameter format
const aiSdkParams = {
prompt,
size: (imageSize || '1024x1024') as `${number}x${number}`,
n: batchSize || 1,
...(signal && { abortSignal: signal })
}
// /**
// * Image generation implementation using the modern AI SDK
// */
// private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
// const { model, prompt, imageSize, batchSize, signal } = params
const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams)
// // Convert the parameter format
// const aiSdkParams = {
// prompt,
// size: (imageSize || '1024x1024') as `${number}x${number}`,
// n: batchSize || 1,
// ...(signal && { abortSignal: signal })
// }
// Convert the result format
const images: string[] = []
if (result.images) {
for (const image of result.images) {
if ('base64' in image && image.base64) {
images.push(`data:image/png;base64,${image.base64}`)
}
}
}
// const executor = createExecutor(this.config.providerId, this.config.options, [])
// const result = await executor.generateImage({
// model: this.localProvider?.imageModel(model) as ImageModel,
// ...aiSdkParams
// })
return images
}
// // Convert the result format
// const images: string[] = []
// if (result.images) {
// for (const image of result.images) {
// if ('base64' in image && image.base64) {
// images.push(`data:image/png;base64,${image.base64}`)
// }
// }
// }
// return images
// }
public getBaseURL(): string {
return this.legacyProvider.getBaseURL()

View File

@ -30,7 +30,7 @@ const logger = loggerService.withContext('parameterBuilder')
*
*/
export async function buildStreamTextParams(
sdkMessages: StreamTextParams['messages'],
sdkMessages: StreamTextParams['messages'] = [],
assistant: Assistant,
provider: Provider,
options: {

View File

@ -4,9 +4,8 @@
import { isOpenAIModel } from '@renderer/config/models'
import { Provider } from '@renderer/types'
import { startsWith } from './helper'
import { provider2Provider } from './helper'
import type { ModelRule } from './types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const extraProviderConfig = (provider: Provider) => {
return {
@ -18,41 +17,41 @@ const extraProviderConfig = (provider: Provider) => {
}
}
const AIHUBMIX_RULES: ModelRule[] = [
{
name: 'claude',
match: startsWith('claude'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'anthropic'
})
const AIHUBMIX_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'anthropic'
})
}
},
{
match: (model) =>
(startsWith('gemini')(model) || startsWith('imagen')(model)) &&
!model.id.endsWith('-nothink') &&
!model.id.endsWith('-search'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'gemini',
apiHost: 'https://aihubmix.com/gemini'
})
}
},
{
match: isOpenAIModel,
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'openai-response'
})
}
}
},
{
name: 'gemini',
match: (model) =>
(startsWith('gemini')(model) || startsWith('imagen')(model)) &&
!model.id.endsWith('-nothink') &&
!model.id.endsWith('-search'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'gemini',
apiHost: 'https://aihubmix.com/gemini'
})
}
},
{
name: 'openai',
match: isOpenAIModel,
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'openai-response'
})
}
}
]
],
fallbackRule: (provider: Provider) => provider
}
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)

View File

@ -1,22 +1,22 @@
import type { Model, Provider } from '@renderer/types'
import type { ModelRule } from './types'
import type { RuleSet } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* Resolves a model to a Provider ID
* Resolves a model to a Provider
* @param ruleSet the rule set to evaluate
* @param model the model to match
* @param rules
* @param fallback the fallback providerId
* @returns the providerId
* @param provider the provider object
* @returns the resolved provider object
*/
export function provider2Provider(rules: ModelRule[], model: Model, provider: Provider): Provider {
for (const rule of rules) {
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return provider
return ruleSet.fallbackRule(provider)
}
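A minimal usage sketch of the rule evaluation above (someClaudeModel and baseProvider are hypothetical placeholders; RuleSet, startsWith, and provider2Provider are this file's exports):

const ruleSet: RuleSet = {
  rules: [{ match: startsWith('claude'), provider: (p: Provider) => ({ ...p, type: 'anthropic' }) }],
  fallbackRule: (p: Provider) => p
}

// A model id beginning with 'claude' rewrites the provider type;
// any other model falls through to fallbackRule.
const resolved = provider2Provider(ruleSet, someClaudeModel, baseProvider)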

View File

@ -1,16 +1,3 @@
// /**
// * Provider resolution rules - module exports
// */
// // Export types
// export type { ModelRule } from './types'
// // Export matchers and resolvers
// export { endpointIs, resolveProvider, startsWith } from './helper'
// // Export rule sets
// export { AIHUBMIX_RULES } from './aihubmix'
// export { NEWAPI_RULES } from './newApi'
export { aihubmixProviderCreator } from './aihubmix'
export { newApiResolverCreator } from './newApi'
export { vertexAnthropicProviderCreator } from './vertext-anthropic'

View File

@ -4,49 +4,48 @@
import { Provider } from '@renderer/types'
import { endpointIs, provider2Provider } from './helper'
import type { ModelRule } from './types'
import type { RuleSet } from './types'
const NEWAPI_RULES: ModelRule[] = [
{
name: 'anthropic',
match: endpointIs('anthropic'),
provider: (provider: Provider) => {
return {
...provider,
type: 'anthropic'
const NEWAPI_RULES: RuleSet = {
rules: [
{
match: endpointIs('anthropic'),
provider: (provider: Provider) => {
return {
...provider,
type: 'anthropic'
}
}
},
{
match: endpointIs('gemini'),
provider: (provider: Provider) => {
return {
...provider,
type: 'gemini'
}
}
},
{
match: endpointIs('openai-response'),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai-response'
}
}
},
{
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai'
}
}
}
},
{
name: 'gemini',
match: endpointIs('gemini'),
provider: (provider: Provider) => {
return {
...provider,
type: 'gemini'
}
}
},
{
name: 'openai-response',
match: endpointIs('openai-response'),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai-response'
}
}
},
{
name: 'openai',
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai'
}
}
}
]
],
fallbackRule: (provider: Provider) => provider
}
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)

View File

@ -1,7 +1,9 @@
import type { Model, Provider } from '@renderer/types'
export interface ModelRule {
name: string
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
export interface RuleSet {
rules: Array<{
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}>
fallbackRule: (provider: Provider) => Provider
}

View File

@ -0,0 +1,19 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)

View File

@ -1,6 +1,8 @@
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { Provider } from '@renderer/types'
import type { Provider as AiSdkProvider } from 'ai'
import { initializeNewProviders } from './providerInitialization'
@ -70,3 +72,28 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
// 3. Final fallback: usually ends up as openai-compatible
return provider.id as ProviderId
}
export async function createAiSdkProvider(config) {
let localProvider: Awaited<AiSdkProvider> | null = null
try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
config.providerId = `${config.providerId}-chat`
} else if (config.providerId === 'azure' && config.options?.mode === 'responses') {
config.providerId = `${config.providerId}-responses`
}
localProvider = await createProviderCore(config.providerId, config.options)
logger.debug('Local provider created successfully', {
providerId: config.providerId,
hasOptions: !!config.options,
localProvider: localProvider,
options: config.options
})
} catch (error) {
logger.error('Failed to create local provider', error as Error, {
providerId: config.providerId
})
throw error
}
return localProvider
}
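A short sketch of how the mode remapping above plays out (the model id is a placeholder; createAiSdkProvider and the config shape come from this file):

// providerId 'openai' with mode 'chat' is rewritten internally to 'openai-chat'.
const provider = await createAiSdkProvider({
  providerId: 'openai',
  options: { mode: 'chat' }
})
const model = provider.languageModel('gpt-4o') // hypothetical model id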

View File

@ -19,7 +19,7 @@ import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep, isEmpty } from 'lodash'
import { aihubmixProviderCreator, newApiResolverCreator } from './config'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { getAiSdkProviderId } from './factory'
const logger = loggerService.withContext('ProviderConfigProcessor')
@ -67,6 +67,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
if (provider.id === 'newapi') {
return newApiResolverCreator(model, provider)
}
if (provider.id === 'vertexai') {
return vertexAnthropicProviderCreator(model, provider)
}
return provider
}
@ -157,7 +160,7 @@ export function providerToAiSdkConfig(
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
// google-vertex
if (aiSdkProviderId === 'google-vertex') {
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}

View File

@ -24,6 +24,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',

View File

@ -93,7 +93,6 @@ export async function fetchChatCompletion({
modelId: assistant.model?.id,
modelName: assistant.model?.name
})
const AI = new AiProviderNew(assistant.model || getDefaultModel())
const provider = AI.getActualProvider()
@ -126,7 +125,6 @@ export async function fetchChatCompletion({
streamOutput: assistant.settings?.streamOutput ?? true,
onChunk: onChunkReceived,
model: assistant.model,
provider: provider,
enableReasoning: capabilities.enableReasoning,
isPromptToolUse: isPromptToolUse(assistant),
isSupportedToolUse: isSupportedToolUse(assistant),

View File

@ -1,6 +1,29 @@
import { generateObject, generateText, streamObject, streamText } from 'ai'
import type { ImageModel, LanguageModel } from 'ai'
import { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai'
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model' | 'messages'> &
(
| {
prompt: string | Array<ModelMessage>
messages?: never
}
| {
messages: Array<ModelMessage>
prompt?: never
}
)
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model' | 'messages'> &
(
| {
prompt: string | Array<ModelMessage>
messages?: never
}
| {
messages: Array<ModelMessage>
prompt?: never
}
)
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
export type AiSdkModel = LanguageModel | ImageModel
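A quick sketch of what the new discriminated union enforces (values are illustrative):

const byPrompt: StreamTextParams = { prompt: 'Hello' }
const byMessages: StreamTextParams = { messages: [{ role: 'user', content: 'Hello' }] }
// Supplying both prompt and messages is now a compile-time error:
// const invalid: StreamTextParams = { prompt: 'Hi', messages: [] }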

View File

@ -356,6 +356,7 @@ export type ProviderType =
| 'vertexai'
| 'mistral'
| 'aws-bedrock'
| 'vertex-anthropic'
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank'

View File

@ -9209,7 +9209,7 @@ __metadata:
"@viz-js/lang-dot": "npm:^1.0.5"
"@viz-js/viz": "npm:^3.14.0"
"@xyflow/react": "npm:^12.4.4"
ai: "npm:^5.0.26"
ai: "npm:^5.0.29"
antd: "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch"
archiver: "npm:^7.0.1"
async-mutex: "npm:^0.5.0"
@ -9425,9 +9425,9 @@ __metadata:
languageName: node
linkType: hard
"ai@npm:^5.0.26":
version: 5.0.26
resolution: "ai@npm:5.0.26"
"ai@npm:^5.0.29":
version: 5.0.29
resolution: "ai@npm:5.0.29"
dependencies:
"@ai-sdk/gateway": "npm:1.0.15"
"@ai-sdk/provider": "npm:2.0.0"
@ -9435,7 +9435,7 @@ __metadata:
"@opentelemetry/api": "npm:1.9.0"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/0423f296b1aa9f22ad106278e8d1e7a2ae9d068358720cdc23c0f222af5406ac1e5ccbce19833709aa1c62b841361d310fb3c42781f426966a9f9ca287ae7faa
checksum: 10c0/526cd2fd59b35b19d902665e3dc1ba5a09f2bb1377295d642fb8a33e13a890874e4dd4b49a787de7f31f4ec6b07257be8514efac08f993081daeb430cf2f60ba
languageName: node
linkType: hard