From 782953cca1f0ef9a9e2f7dd4f58e166582c9b82c Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Tue, 2 Sep 2025 18:54:07 +0800 Subject: [PATCH] chore(dependencies): update ai package version and enhance aiCore functionality - Updated the 'ai' package version from 5.0.26 to 5.0.29 in package.json and yarn.lock. - Refactored aiCore's provider schemas to introduce new provider types and improve type safety. - Enhanced the RuntimeExecutor class to streamline model handling and plugin execution for various AI tasks. - Updated tests to reflect changes in parameter handling and ensure compatibility with new provider configurations. --- package.json | 2 +- packages/aiCore/src/core/providers/schemas.ts | 35 ++- .../runtime/__tests__/generateImage.test.ts | 149 +++++----- packages/aiCore/src/core/runtime/executor.ts | 193 ++++++------- packages/aiCore/src/core/runtime/index.ts | 28 +- .../aiCore/src/core/runtime/pluginEngine.ts | 101 +++++-- src/renderer/src/aiCore/index_new.ts | 265 ++++++++---------- .../aiCore/prepareParams/parameterBuilder.ts | 2 +- .../src/aiCore/provider/config/aihubmix.ts | 75 +++-- .../src/aiCore/provider/config/helper.ts | 16 +- .../src/aiCore/provider/config/index.ts | 15 +- .../src/aiCore/provider/config/newApi.ts | 81 +++--- .../src/aiCore/provider/config/types.ts | 10 +- .../provider/config/vertext-anthropic.ts | 19 ++ src/renderer/src/aiCore/provider/factory.ts | 27 ++ .../src/aiCore/provider/providerConfig.ts | 7 +- .../aiCore/provider/providerInitialization.ts | 8 + src/renderer/src/services/ApiService.ts | 2 - src/renderer/src/types/aiCoreTypes.ts | 29 +- src/renderer/src/types/index.ts | 1 + yarn.lock | 10 +- 21 files changed, 570 insertions(+), 505 deletions(-) create mode 100644 src/renderer/src/aiCore/provider/config/vertext-anthropic.ts diff --git a/package.json b/package.json index c2731e7c86..9997141c8c 100644 --- a/package.json +++ b/package.json @@ -174,7 +174,7 @@ "@viz-js/lang-dot": "^1.0.5", "@viz-js/viz": "^3.14.0", "@xyflow/react": "^12.4.4", - "ai": "^5.0.26", + "ai": "^5.0.29", "antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch", "archiver": "^7.0.1", "async-mutex": "^0.5.0", diff --git a/packages/aiCore/src/core/providers/schemas.ts b/packages/aiCore/src/core/providers/schemas.ts index 9b5441c8a0..0c1c847d98 100644 --- a/packages/aiCore/src/core/providers/schemas.ts +++ b/packages/aiCore/src/core/providers/schemas.ts @@ -4,11 +4,13 @@ import { createAnthropic } from '@ai-sdk/anthropic' import { createAzure } from '@ai-sdk/azure' +import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure' import { createDeepSeek } from '@ai-sdk/deepseek' import { createGoogleGenerativeAI } from '@ai-sdk/google' import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai' import { createOpenAICompatible } from '@ai-sdk/openai-compatible' import { createXai } from '@ai-sdk/xai' +import { customProvider, type Provider } from 'ai' import * as z from 'zod' /** @@ -16,12 +18,13 @@ import * as z from 'zod' */ export const baseProviderIds = [ 'openai', - 'openai-responses', + 'openai-chat', 'openai-compatible', 'anthropic', 'google', 'xai', 'azure', + 'azure-responses', 'deepseek' ] as const @@ -38,7 +41,7 @@ export type BaseProviderId = z.infer export const baseProviderSchema = z.object({ id: baseProviderIdSchema, name: z.string(), - creator: z.function().args(z.any()).returns(z.any()), + creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>, supportsImageGeneration: z.boolean() }) @@ -56,9 +59,17 @@ export const baseProviders = [ supportsImageGeneration: true }, { - id: 'openai-responses', - name: 'OpenAI Responses', - creator: (options: OpenAIProviderSettings) => createOpenAI(options).responses, + id: 'openai-chat', + name: 'OpenAI Chat', + creator: (options: OpenAIProviderSettings) => { + const provider = createOpenAI(options) + return customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.chat(modelId) + } + }) + }, supportsImageGeneration: true }, { @@ -91,6 +102,20 @@ export const baseProviders = [ creator: createAzure, supportsImageGeneration: true }, + { + id: 'azure-responses', + name: 'Azure OpenAI Responses', + creator: (options: AzureOpenAIProviderSettings) => { + const provider = createAzure(options) + return customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.responses(modelId) + } + }) + }, + supportsImageGeneration: true + }, { id: 'deepseek', name: 'DeepSeek', diff --git a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts index 872c7c4940..bde5779fd9 100644 --- a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts +++ b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts @@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { type AiPlugin } from '../../plugins' import { globalRegistryManagement } from '../../providers/RegistryManagement' -import { ImageGenerationError } from '../errors' +import { ImageGenerationError, ImageModelResolutionError } from '../errors' import { RuntimeExecutor } from '../executor' // Mock dependencies @@ -76,9 +76,7 @@ describe('RuntimeExecutor.generateImage', () => { describe('Basic functionality', () => { it('should generate a single image with minimal parameters', async () => { - const result = await executor.generateImage('dall-e-3', { - prompt: 'A futuristic cityscape at sunset' - }) + const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' }) expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3') @@ -91,7 +89,8 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should generate image with pre-created model', async () => { - const result = await executor.generateImage(mockImageModel, { + const result = await executor.generateImage({ + model: mockImageModel, prompt: 'A beautiful landscape' }) @@ -105,10 +104,7 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should support multiple images generation', async () => { - await executor.generateImage('dall-e-3', { - prompt: 'A futuristic cityscape', - n: 3 - }) + await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 }) expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -118,10 +114,7 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should support size specification', async () => { - await executor.generateImage('dall-e-3', { - prompt: 'A beautiful sunset', - size: '1024x1024' - }) + await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' }) expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -131,10 +124,7 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should support aspect ratio specification', async () => { - await executor.generateImage('dall-e-3', { - prompt: 'A mountain landscape', - aspectRatio: '16:9' - }) + await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' }) expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -144,10 +134,7 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should support seed for consistent output', async () => { - await executor.generateImage('dall-e-3', { - prompt: 'A cat in space', - seed: 1234567890 - }) + await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 }) expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -159,10 +146,7 @@ describe('RuntimeExecutor.generateImage', () => { it('should support abort signal', async () => { const abortController = new AbortController() - await executor.generateImage('dall-e-3', { - prompt: 'A cityscape', - abortSignal: abortController.signal - }) + await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal }) expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -172,7 +156,8 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should support provider-specific options', async () => { - await executor.generateImage('dall-e-3', { + await executor.generateImage({ + model: 'dall-e-3', prompt: 'A space station', providerOptions: { openai: { @@ -195,7 +180,8 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should support custom headers', async () => { - await executor.generateImage('dall-e-3', { + await executor.generateImage({ + model: 'dall-e-3', prompt: 'A robot', headers: { 'X-Custom-Header': 'test-value' @@ -242,9 +228,7 @@ describe('RuntimeExecutor.generateImage', () => { [testPlugin] ) - const result = await executorWithPlugin.generateImage('dall-e-3', { - prompt: 'A test image' - }) + const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd']) @@ -287,9 +271,7 @@ describe('RuntimeExecutor.generateImage', () => { [modelResolutionPlugin] ) - await executorWithPlugin.generateImage('dall-e-3', { - prompt: 'A test image' - }) + await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith( 'dall-e-3', @@ -312,6 +294,7 @@ describe('RuntimeExecutor.generateImage', () => { if (!context.isRecursiveCall && params.prompt === 'original') { // Make a recursive call with modified prompt await context.recursiveCall({ + model: 'dall-e-3', prompt: 'modified' }) } @@ -327,9 +310,7 @@ describe('RuntimeExecutor.generateImage', () => { [recursivePlugin] ) - await executorWithPlugin.generateImage('dall-e-3', { - prompt: 'original' - }) + await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' }) expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2) expect(aiGenerateImage).toHaveBeenCalledTimes(2) @@ -343,22 +324,47 @@ describe('RuntimeExecutor.generateImage', () => { throw modelError }) - await expect( - executor.generateImage('invalid-model', { - prompt: 'A test image' - }) - ).rejects.toThrow(ImageGenerationError) + await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow( + ImageGenerationError + ) + }) + + it('should handle ImageModelResolutionError correctly', async () => { + const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found')) + vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => { + throw resolutionError + }) + + const thrownError = await executor + .generateImage({ model: 'invalid-model', prompt: 'A test image' }) + .catch((error) => error) + + expect(thrownError).toBeInstanceOf(ImageGenerationError) + expect(thrownError.message).toContain('Failed to generate image:') + expect(thrownError.providerId).toBe('openai') + expect(thrownError.modelId).toBe('invalid-model') + expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError) + expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model') + }) + + it('should handle ImageModelResolutionError without provider', async () => { + const resolutionError = new ImageModelResolutionError('unknown-model') + vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => { + throw resolutionError + }) + + await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow( + ImageGenerationError + ) }) it('should handle image generation API errors', async () => { const apiError = new Error('API request failed') vi.mocked(aiGenerateImage).mockRejectedValue(apiError) - await expect( - executor.generateImage('dall-e-3', { - prompt: 'A test image' - }) - ).rejects.toThrow('Failed to generate image:') + await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow( + 'Failed to generate image:' + ) }) it('should handle NoImageGeneratedError', async () => { @@ -370,11 +376,9 @@ describe('RuntimeExecutor.generateImage', () => { vi.mocked(aiGenerateImage).mockRejectedValue(noImageError) vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true) - await expect( - executor.generateImage('dall-e-3', { - prompt: 'A test image' - }) - ).rejects.toThrow('Failed to generate image:') + await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow( + 'Failed to generate image:' + ) }) it('should execute onError plugin hook on failure', async () => { @@ -394,11 +398,9 @@ describe('RuntimeExecutor.generateImage', () => { [errorPlugin] ) - await expect( - executorWithPlugin.generateImage('dall-e-3', { - prompt: 'A test image' - }) - ).rejects.toThrow('Failed to generate image:') + await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow( + 'Failed to generate image:' + ) expect(errorPlugin.onError).toHaveBeenCalledWith( error, @@ -418,10 +420,7 @@ describe('RuntimeExecutor.generateImage', () => { setTimeout(() => abortController.abort(), 10) await expect( - executor.generateImage('dall-e-3', { - prompt: 'A test image', - abortSignal: abortController.signal - }) + executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal }) ).rejects.toThrow('Failed to generate image:') }) }) @@ -432,9 +431,7 @@ describe('RuntimeExecutor.generateImage', () => { apiKey: 'google-key' }) - await googleExecutor.generateImage('imagen-3.0-generate-002', { - prompt: 'A landscape' - }) + await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' }) expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002') }) @@ -444,9 +441,7 @@ describe('RuntimeExecutor.generateImage', () => { apiKey: 'xai-key' }) - await xaiExecutor.generateImage('grok-2-image', { - prompt: 'A futuristic robot' - }) + await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' }) expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image') }) @@ -454,11 +449,7 @@ describe('RuntimeExecutor.generateImage', () => { describe('Advanced features', () => { it('should support batch image generation with maxImagesPerCall', async () => { - await executor.generateImage('dall-e-3', { - prompt: 'A test image', - n: 10, - maxImagesPerCall: 5 - }) + await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 }) expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -469,10 +460,7 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should support retries with maxRetries', async () => { - await executor.generateImage('dall-e-3', { - prompt: 'A test image', - maxRetries: 3 - }) + await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 }) expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -494,7 +482,8 @@ describe('RuntimeExecutor.generateImage', () => { vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings) - const result = await executor.generateImage('dall-e-3', { + const result = await executor.generateImage({ + model: 'dall-e-3', prompt: 'A test image', size: '2048x2048' // Unsupported size }) @@ -504,9 +493,7 @@ describe('RuntimeExecutor.generateImage', () => { }) it('should provide access to provider metadata', async () => { - const result = await executor.generateImage('dall-e-3', { - prompt: 'A test image' - }) + const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) expect(result.providerMetadata).toBeDefined() expect(result.providerMetadata.openai).toBeDefined() @@ -526,9 +513,7 @@ describe('RuntimeExecutor.generateImage', () => { vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata) - const result = await executor.generateImage('dall-e-3', { - prompt: 'A test image' - }) + const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' }) expect(result.responses).toHaveLength(1) expect(result.responses[0].modelId).toBe('dall-e-3') diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts index e53bae6474..ab80f9cecc 100644 --- a/packages/aiCore/src/core/runtime/executor.ts +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -72,42 +72,36 @@ export class RuntimeExecutor { // === 高阶重载:直接使用模型 === /** - * 流式文本生成 - 使用已创建的模型(高级用法) + * 流式文本生成 */ async streamText( - model: LanguageModel, - params: Omit[0], 'model'> - ): Promise> - async streamText( - modelId: string, - params: Omit[0], 'model'>, - options?: { - middlewares?: LanguageModelV2Middleware[] - } - ): Promise> - async streamText( - modelOrId: LanguageModel, - params: Omit[0], 'model'>, + params: Parameters[0], options?: { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - this.pluginEngine.usePlugins([ - this.createResolveModelPlugin(options?.middlewares), - this.createConfigureContextPlugin() - ]) + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } - // 2. 执行插件处理 return this.pluginEngine.executeStreamWithPlugins( 'streamText', - typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, - params, - async (model, transformedParams, streamTransforms) => { + model, + restParams, + async (resolvedModel, transformedParams, streamTransforms) => { const experimental_transform = params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined) const finalParams = { - model, + model: resolvedModel, ...transformedParams, experimental_transform } as Parameters[0] @@ -120,145 +114,126 @@ export class RuntimeExecutor { // === 其他方法的重载 === /** - * 生成文本 - 使用已创建的模型 + * 生成文本 */ async generateText( - model: LanguageModel, - params: Omit[0], 'model'> - ): Promise> - async generateText( - modelId: string, - params: Omit[0], 'model'>, - options?: { - middlewares?: LanguageModelV2Middleware[] - } - ): Promise> - async generateText( - modelOrId: LanguageModel | string, - params: Omit[0], 'model'>, + params: Parameters[0], options?: { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - this.pluginEngine.usePlugins([ - this.createResolveModelPlugin(options?.middlewares), - this.createConfigureContextPlugin() - ]) + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } return this.pluginEngine.executeWithPlugins( 'generateText', - typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, - params, - async (model, transformedParams) => - generateText({ model, ...transformedParams } as Parameters[0]) + model, + restParams, + async (resolvedModel, transformedParams) => + generateText({ model: resolvedModel, ...transformedParams } as Parameters[0]) ) } /** - * 生成结构化对象 - 使用已创建的模型 + * 生成结构化对象 */ async generateObject( - model: LanguageModel, - params: Omit[0], 'model'> - ): Promise> - async generateObject( - modelOrId: string, - params: Omit[0], 'model'>, - options?: { - middlewares?: LanguageModelV2Middleware[] - } - ): Promise> - async generateObject( - modelOrId: LanguageModel | string, - params: Omit[0], 'model'>, + params: Parameters[0], options?: { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - this.pluginEngine.usePlugins([ - this.createResolveModelPlugin(options?.middlewares), - this.createConfigureContextPlugin() - ]) + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } return this.pluginEngine.executeWithPlugins( 'generateObject', - typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, - params, - async (model, transformedParams) => - generateObject({ model, ...transformedParams } as Parameters[0]) + model, + restParams, + async (resolvedModel, transformedParams) => + generateObject({ model: resolvedModel, ...transformedParams } as Parameters[0]) ) } /** - * 流式生成结构化对象 - 使用已创建的模型 + * 流式生成结构化对象 */ async streamObject( - model: LanguageModel, - params: Omit[0], 'model'> - ): Promise> - async streamObject( - modelId: string, - params: Omit[0], 'model'>, - options?: { - middlewares?: LanguageModelV2Middleware[] - } - ): Promise> - async streamObject( - modelOrId: LanguageModel | string, - params: Omit[0], 'model'>, + params: Parameters[0], options?: { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - this.pluginEngine.usePlugins([ - this.createResolveModelPlugin(options?.middlewares), - this.createConfigureContextPlugin() - ]) + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([ + this.createResolveModelPlugin(options?.middlewares), + this.createConfigureContextPlugin() + ]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } return this.pluginEngine.executeWithPlugins( 'streamObject', - typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, - params, - async (model, transformedParams) => - streamObject({ model, ...transformedParams } as Parameters[0]) + model, + restParams, + async (resolvedModel, transformedParams) => + streamObject({ model: resolvedModel, ...transformedParams } as Parameters[0]) ) } /** - * 生成图像 - 使用已创建的图像模型 + * 生成图像 */ async generateImage( - model: ImageModelV2, - params: Omit[0], 'model'> - ): Promise> - async generateImage( - modelId: string, - params: Omit[0], 'model'>, - options?: { - middlewares?: LanguageModelV2Middleware[] - } - ): Promise> - async generateImage( - modelOrId: ImageModelV2 | string, - params: Omit[0], 'model'> + params: Omit[0], 'model'> & { model: string | ImageModelV2 } ): Promise> { try { - this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()]) + const { model, ...restParams } = params + + // 根据 model 类型决定插件配置 + if (typeof model === 'string') { + this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()]) + } else { + this.pluginEngine.usePlugins([this.createConfigureContextPlugin()]) + } return await this.pluginEngine.executeImageWithPlugins( 'generateImage', - typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, - params, - async (model, transformedParams) => { - return await generateImage({ model, ...transformedParams }) + model, + restParams, + async (resolvedModel, transformedParams) => { + return await generateImage({ model: resolvedModel, ...transformedParams }) } ) } catch (error) { if (error instanceof Error) { + const modelId = typeof params.model === 'string' ? params.model : params.model.modelId throw new ImageGenerationError( `Failed to generate image: ${error.message}`, this.config.providerId, - typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, + modelId, error ) } diff --git a/packages/aiCore/src/core/runtime/index.ts b/packages/aiCore/src/core/runtime/index.ts index 279c6126b2..37aa4fec34 100644 --- a/packages/aiCore/src/core/runtime/index.ts +++ b/packages/aiCore/src/core/runtime/index.ts @@ -46,13 +46,12 @@ export function createOpenAICompatibleExecutor( export async function streamText( providerId: T, options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, - modelId: string, - params: Parameters['streamText']>[1], + params: Parameters['streamText']>[0], plugins?: AiPlugin[], middlewares?: LanguageModelV2Middleware[] ): Promise['streamText']>> { const executor = createExecutor(providerId, options, plugins) - return executor.streamText(modelId, params, { middlewares }) + return executor.streamText(params, { middlewares }) } /** @@ -61,13 +60,12 @@ export async function streamText( export async function generateText( providerId: T, options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, - modelId: string, - params: Parameters['generateText']>[1], + params: Parameters['generateText']>[0], plugins?: AiPlugin[], middlewares?: LanguageModelV2Middleware[] ): Promise['generateText']>> { const executor = createExecutor(providerId, options, plugins) - return executor.generateText(modelId, params, { middlewares }) + return executor.generateText(params, { middlewares }) } /** @@ -76,13 +74,12 @@ export async function generateText( export async function generateObject( providerId: T, options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, - modelId: string, - params: Parameters['generateObject']>[1], + params: Parameters['generateObject']>[0], plugins?: AiPlugin[], middlewares?: LanguageModelV2Middleware[] ): Promise['generateObject']>> { const executor = createExecutor(providerId, options, plugins) - return executor.generateObject(modelId, params, { middlewares }) + return executor.generateObject(params, { middlewares }) } /** @@ -91,13 +88,12 @@ export async function generateObject( export async function streamObject( providerId: T, options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, - modelId: string, - params: Parameters['streamObject']>[1], + params: Parameters['streamObject']>[0], plugins?: AiPlugin[], middlewares?: LanguageModelV2Middleware[] ): Promise['streamObject']>> { const executor = createExecutor(providerId, options, plugins) - return executor.streamObject(modelId, params, { middlewares }) + return executor.streamObject(params, { middlewares }) } /** @@ -106,13 +102,11 @@ export async function streamObject( export async function generateImage( providerId: T, options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }, - modelId: string, - params: Parameters['generateImage']>[1], - plugins?: AiPlugin[], - middlewares?: LanguageModelV2Middleware[] + params: Parameters['generateImage']>[0], + plugins?: AiPlugin[] ): Promise['generateImage']>> { const executor = createExecutor(providerId, options, plugins) - return executor.generateImage(modelId, params, { middlewares }) + return executor.generateImage(params) } // === Agent 功能预留 === diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts index 17e0c1de7f..7a4bb440f7 100644 --- a/packages/aiCore/src/core/runtime/pluginEngine.ts +++ b/packages/aiCore/src/core/runtime/pluginEngine.ts @@ -64,11 +64,24 @@ export class PluginEngine { */ async executeWithPlugins( methodName: string, - modelId: string, + model: LanguageModel, params: TParams, executor: (model: LanguageModel, transformedParams: TParams) => Promise, _context?: ReturnType ): Promise { + // 统一处理模型解析 + let resolvedModel: LanguageModel | undefined + let modelId: string + + if (typeof model === 'string') { + // 字符串:需要通过插件解析 + modelId = model + } else { + // 模型对象:直接使用 + resolvedModel = model + modelId = model.modelId + } + // 使用正确的createContext创建请求上下文 const context = _context ? _context : createContext(this.providerId, modelId, params) @@ -76,7 +89,7 @@ export class PluginEngine { context.recursiveCall = async (newParams: any): Promise => { // 递归调用自身,重新走完整的插件流程 context.isRecursiveCall = true - const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context) + const result = await this.executeWithPlugins(methodName, model, newParams, executor, context) context.isRecursiveCall = false return result } @@ -88,17 +101,24 @@ export class PluginEngine { // 1. 触发请求开始事件 await this.pluginManager.executeParallel('onRequestStart', context) - // 2. 解析模型 - const model = await this.pluginManager.executeFirst('resolveModel', modelId, context) - if (!model) { - throw new Error(`Failed to resolve model: ${modelId}`) + // 2. 解析模型(如果是字符串) + if (typeof model === 'string') { + const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + if (!resolved) { + throw new Error(`Failed to resolve model: ${modelId}`) + } + resolvedModel = resolved + } + + if (!resolvedModel) { + throw new Error(`Model resolution failed: no model available`) } // 3. 转换请求参数 const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) // 4. 执行具体的 API 调用 - const result = await executor(model, transformedParams) + const result = await executor(resolvedModel, transformedParams) // 5. 转换结果(对于非流式调用) const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) @@ -120,11 +140,24 @@ export class PluginEngine { */ async executeImageWithPlugins( methodName: string, - modelId: string, + model: ImageModelV2 | string, params: TParams, executor: (model: ImageModelV2, transformedParams: TParams) => Promise, _context?: ReturnType ): Promise { + // 统一处理模型解析 + let resolvedModel: ImageModelV2 | undefined + let modelId: string + + if (typeof model === 'string') { + // 字符串:需要通过插件解析 + modelId = model + } else { + // 模型对象:直接使用 + resolvedModel = model + modelId = model.modelId + } + // 使用正确的createContext创建请求上下文 const context = _context ? _context : createContext(this.providerId, modelId, params) @@ -132,7 +165,7 @@ export class PluginEngine { context.recursiveCall = async (newParams: any): Promise => { // 递归调用自身,重新走完整的插件流程 context.isRecursiveCall = true - const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context) + const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context) context.isRecursiveCall = false return result } @@ -144,17 +177,24 @@ export class PluginEngine { // 1. 触发请求开始事件 await this.pluginManager.executeParallel('onRequestStart', context) - // 2. 解析模型 - const model = await this.pluginManager.executeFirst('resolveModel', modelId, context) - if (!model) { - throw new Error(`Failed to resolve image model: ${modelId}`) + // 2. 解析模型(如果是字符串) + if (typeof model === 'string') { + const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + if (!resolved) { + throw new Error(`Failed to resolve image model: ${modelId}`) + } + resolvedModel = resolved + } + + if (!resolvedModel) { + throw new Error(`Image model resolution failed: no model available`) } // 3. 转换请求参数 const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) // 4. 执行具体的 API 调用 - const result = await executor(model, transformedParams) + const result = await executor(resolvedModel, transformedParams) // 5. 转换结果 const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) @@ -176,11 +216,24 @@ export class PluginEngine { */ async executeStreamWithPlugins( methodName: string, - modelId: string, + model: LanguageModel, params: TParams, executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise, _context?: ReturnType ): Promise { + // 统一处理模型解析 + let resolvedModel: LanguageModel | undefined + let modelId: string + + if (typeof model === 'string') { + // 字符串:需要通过插件解析 + modelId = model + } else { + // 模型对象:直接使用 + resolvedModel = model + modelId = model.modelId + } + // 创建请求上下文 const context = _context ? _context : createContext(this.providerId, modelId, params) @@ -188,7 +241,7 @@ export class PluginEngine { context.recursiveCall = async (newParams: any): Promise => { // 递归调用自身,重新走完整的插件流程 context.isRecursiveCall = true - const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context) + const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context) context.isRecursiveCall = false return result } @@ -200,11 +253,17 @@ export class PluginEngine { // 1. 触发请求开始事件 await this.pluginManager.executeParallel('onRequestStart', context) - // 2. 解析模型 - const model = await this.pluginManager.executeFirst('resolveModel', modelId, context) + // 2. 解析模型(如果是字符串) + if (typeof model === 'string') { + const resolved = await this.pluginManager.executeFirst('resolveModel', modelId, context) + if (!resolved) { + throw new Error(`Failed to resolve model: ${modelId}`) + } + resolvedModel = resolved + } - if (!model) { - throw new Error(`Failed to resolve model: ${modelId}`) + if (!resolvedModel) { + throw new Error(`Model resolution failed: no model available`) } // 3. 转换请求参数 @@ -214,7 +273,7 @@ export class PluginEngine { const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context) // 5. 执行流式 API 调用 - const result = await executor(model, transformedParams, streamTransforms) + const result = await executor(resolvedModel, transformedParams, streamTransforms) const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index b5998e013d..0a498d58ae 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -7,22 +7,23 @@ * 2. 暂时保持接口兼容性 */ -import { createExecutor, generateImage } from '@cherrystudio/ai-core' -import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider' +import { createExecutor } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import { isNotSupportedImageSizeModel } from '@renderer/config/models' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' import { addSpan, endSpan } from '@renderer/services/SpanManagerService' import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' -import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' -import type { StreamTextParams } from '@renderer/types/aiCoreTypes' +import type { Assistant, Model, Provider } from '@renderer/types' +import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes' import { ChunkType } from '@renderer/types/chunk' +import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai' import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter' import LegacyAiProvider from './legacy/index' import { CompletionsResult } from './legacy/middleware/schemas' import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' import { buildPlugins } from './plugins/PluginBuilder' +import { createAiSdkProvider } from './provider/factory' import { getActualProvider, isModernSdkSupported, @@ -44,6 +45,7 @@ export default class ModernAiProvider { private config: ReturnType private actualProvider: Provider private model: Model + private localProvider: Awaited | null = null constructor(model: Model, provider?: Provider) { this.actualProvider = provider || getActualProvider(model) @@ -62,44 +64,58 @@ export default class ModernAiProvider { // 准备特殊配置 await prepareSpecialProviderConfig(this.actualProvider, this.config) - logger.debug('this.config', this.config) + // 提前创建本地 provider 实例 + if (!this.localProvider) { + this.localProvider = await createAiSdkProvider(this.config) + } + + // 提前构建中间件 + const middlewares = buildAiSdkMiddlewares({ + ...config, + provider: this.actualProvider + }) + logger.debug('Built middlewares in completions', { + middlewareCount: middlewares.length, + isImageGeneration: config.isImageGenerationEndpoint + }) + if (!this.localProvider) { + throw new Error('Local provider not created') + } + + // 根据endpoint类型创建对应的模型 + let model: AiSdkModel | undefined + if (config.isImageGenerationEndpoint) { + model = this.localProvider.imageModel(modelId) + } else { + model = this.localProvider.languageModel(modelId) + // 如果有中间件,应用到语言模型上 + if (middlewares.length > 0 && typeof model === 'object') { + model = wrapLanguageModel({ model, middleware: middlewares }) + } + } + if (config.topicId && getEnableDeveloperMode()) { // TypeScript类型窄化:确保topicId是string类型 const traceConfig = { ...config, topicId: config.topicId } - return await this._completionsForTrace(modelId, params, traceConfig) + return await this._completionsForTrace(model, params, traceConfig) } else { - return await this._completions(modelId, params, config) + return await this._completionsOrImageGeneration(model, params, config) } } - private async _completions( - modelId: string, + private async _completionsOrImageGeneration( + model: AiSdkModel, params: StreamTextParams, config: ModernAiProviderConfig ): Promise { - // 初始化 provider 到全局管理器 - try { - await createAndRegisterProvider(this.config.providerId, this.config.options) - logger.debug('Provider initialized successfully', { - providerId: this.config.providerId, - hasOptions: !!this.config.options - }) - } catch (error) { - // 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略 - logger.debug('Provider initialization skipped (may already be initialized)', { - providerId: this.config.providerId, - error: error instanceof Error ? error.message : String(error) - }) - } - if (config.isImageGenerationEndpoint) { - return await this.modernImageGeneration(modelId, params, config) + return await this.modernImageGeneration(model as ImageModel, params, config) } - return await this.modernCompletions(modelId, params, config) + return await this.modernCompletions(model as LanguageModel, params, config) } /** @@ -107,10 +123,11 @@ export default class ModernAiProvider { * 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中 */ private async _completionsForTrace( - modelId: string, + model: AiSdkModel, params: StreamTextParams, config: ModernAiProviderConfig & { topicId: string } ): Promise { + const modelId = this.model.id const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}` const traceParams: StartSpanParams = { name: traceName, @@ -136,7 +153,7 @@ export default class ModernAiProvider { modelId, traceName }) - return await this._completions(modelId, params, config) + return await this._completionsOrImageGeneration(model, params, config) } try { @@ -148,7 +165,7 @@ export default class ModernAiProvider { parentSpanCreated: true }) - const result = await this._completions(modelId, params, config) + const result = await this._completionsOrImageGeneration(model, params, config) logger.info('Completions finished, ending parent span', { spanId: span.spanContext().spanId, @@ -190,10 +207,11 @@ export default class ModernAiProvider { * 使用现代化AI SDK的completions实现 */ private async modernCompletions( - modelId: string, + model: LanguageModel, params: StreamTextParams, config: ModernAiProviderConfig ): Promise { + const modelId = this.model.id logger.info('Starting modernCompletions', { modelId, providerId: this.config.providerId, @@ -205,122 +223,48 @@ export default class ModernAiProvider { // 根据条件构建插件数组 const plugins = buildPlugins(config) - logger.debug('Built plugins for AI SDK', { - pluginCount: plugins.length, - pluginNames: plugins.map((p) => p.name), - providerId: this.config.providerId, - topicId: config.topicId - }) // 用构建好的插件数组创建executor const executor = createExecutor(this.config.providerId, this.config.options, plugins) - logger.debug('Created AI SDK executor', { - providerId: this.config.providerId, - hasOptions: !!this.config.options, - pluginCount: plugins.length - }) - - // 动态构建中间件数组 - const middlewares = buildAiSdkMiddlewares(config) - logger.debug('Built AI SDK middlewares', { - middlewareCount: middlewares.length, - topicId: config.topicId - }) // 创建带有中间件的执行器 if (config.onChunk) { - // 流式处理 - 使用适配器 - logger.info('Starting streaming with chunk adapter', { - modelId, - hasMiddlewares: middlewares.length > 0, - middlewareCount: middlewares.length, - hasMcpTools: !!config.mcpTools, - mcpToolCount: config.mcpTools?.length || 0, - topicId: config.topicId - }) - const accumulate = this.model.supported_text_delta !== false // true and undefined const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate) - logger.debug('Final params before streamText', { - modelId, - hasMessages: !!params.messages, - messageCount: params.messages?.length || 0, - hasTools: !!params.tools && Object.keys(params.tools).length > 0, - toolNames: params.tools ? Object.keys(params.tools) : [], - hasSystem: !!params.system, - topicId: config.topicId - }) - - const streamResult = await executor.streamText( - modelId, - { ...params, experimental_context: { onChunk: config.onChunk } }, - middlewares.length > 0 ? { middlewares } : undefined - ) - - logger.info('StreamText call successful, processing stream', { - modelId, - topicId: config.topicId, - hasFullStream: !!streamResult.fullStream + const streamResult = await executor.streamText({ + ...params, + model, + experimental_context: { onChunk: config.onChunk } }) const finalText = await adapter.processStream(streamResult) - logger.info('Stream processing completed', { - modelId, - topicId: config.topicId, - finalTextLength: finalText.length - }) - return { getText: () => finalText } } else { - // 流式处理但没有 onChunk 回调 - logger.info('Starting streaming without chunk callback', { - modelId, - hasMiddlewares: middlewares.length > 0, - middlewareCount: middlewares.length, - topicId: config.topicId + const streamResult = await executor.streamText({ + ...params, + model }) - const streamResult = await executor.streamText( - modelId, - params, - middlewares.length > 0 ? { middlewares } : undefined - ) - - logger.info('StreamText call successful, waiting for text', { - modelId, - topicId: config.topicId - }) // 强制消费流,不然await streamResult.text会阻塞 await streamResult?.consumeStream() const finalText = await streamResult.text - logger.info('Text extraction completed', { - modelId, - topicId: config.topicId, - finalTextLength: finalText.length - }) - return { getText: () => finalText } } - // } - // catch (error) { - // console.error('Modern AI SDK error:', error) - // throw error - // } } /** * 使用现代化 AI SDK 的图像生成实现,支持流式输出 */ private async modernImageGeneration( - modelId: string, + model: ImageModel, params: StreamTextParams, config: ModernAiProviderConfig ): Promise { @@ -364,7 +308,11 @@ export default class ModernAiProvider { } // 调用新 AI SDK 的图像生成功能 - const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams) + const executor = createExecutor(this.config.providerId, this.config.options, []) + const result = await executor.generateImage({ + model, + ...imageParams + }) // 转换结果格式 const images: string[] = [] @@ -441,51 +389,64 @@ export default class ModernAiProvider { return this.legacyProvider.getEmbeddingDimensions(model) } - public async generateImage(params: GenerateImageParams): Promise { - // 如果支持新的 AI SDK,使用现代化实现 - if (isModernSdkSupported(this.actualProvider)) { - try { - const result = await this.modernGenerateImage(params) - return result - } catch (error) { - logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error) - // fallback 到传统实现 - return this.legacyProvider.generateImage(params) - } - } + // public async generateImage(params: GenerateImageParams): Promise { + // // 如果支持新的 AI SDK,使用现代化实现 + // if (isModernSdkSupported(this.actualProvider)) { + // try { + // // 确保本地provider已创建 + // if (!this.localProvider) { + // await prepareSpecialProviderConfig(this.actualProvider, this.config) + // this.localProvider = await createProvider(this.config.providerId, this.config.options) + // logger.debug('Local provider created for standalone image generation', { + // providerId: this.config.providerId + // }) + // } - // 直接使用传统实现 - return this.legacyProvider.generateImage(params) - } + // const result = await this.modernGenerateImage(params) + // return result + // } catch (error) { + // logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error) + // // fallback 到传统实现 + // return this.legacyProvider.generateImage(params) + // } + // } - /** - * 使用现代化 AI SDK 的图像生成实现 - */ - private async modernGenerateImage(params: GenerateImageParams): Promise { - const { model, prompt, imageSize, batchSize, signal } = params + // // 直接使用传统实现 + // return this.legacyProvider.generateImage(params) + // } - // 转换参数格式 - const aiSdkParams = { - prompt, - size: (imageSize || '1024x1024') as `${number}x${number}`, - n: batchSize || 1, - ...(signal && { abortSignal: signal }) - } + // /** + // * 使用现代化 AI SDK 的图像生成实现 + // */ + // private async modernGenerateImage(params: GenerateImageParams): Promise { + // const { model, prompt, imageSize, batchSize, signal } = params - const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams) + // // 转换参数格式 + // const aiSdkParams = { + // prompt, + // size: (imageSize || '1024x1024') as `${number}x${number}`, + // n: batchSize || 1, + // ...(signal && { abortSignal: signal }) + // } - // 转换结果格式 - const images: string[] = [] - if (result.images) { - for (const image of result.images) { - if ('base64' in image && image.base64) { - images.push(`data:image/png;base64,${image.base64}`) - } - } - } + // const executor = createExecutor(this.config.providerId, this.config.options, []) + // const result = await executor.generateImage({ + // model: this.localProvider?.imageModel(model) as ImageModel, + // ...aiSdkParams + // }) - return images - } + // // 转换结果格式 + // const images: string[] = [] + // if (result.images) { + // for (const image of result.images) { + // if ('base64' in image && image.base64) { + // images.push(`data:image/png;base64,${image.base64}`) + // } + // } + // } + + // return images + // } public getBaseURL(): string { return this.legacyProvider.getBaseURL() diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index f0e0a35b0f..46355d30b3 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -30,7 +30,7 @@ const logger = loggerService.withContext('parameterBuilder') * 这是主要的参数构建函数,整合所有转换逻辑 */ export async function buildStreamTextParams( - sdkMessages: StreamTextParams['messages'], + sdkMessages: StreamTextParams['messages'] = [], assistant: Assistant, provider: Provider, options: { diff --git a/src/renderer/src/aiCore/provider/config/aihubmix.ts b/src/renderer/src/aiCore/provider/config/aihubmix.ts index 928b30acfa..88453ca38e 100644 --- a/src/renderer/src/aiCore/provider/config/aihubmix.ts +++ b/src/renderer/src/aiCore/provider/config/aihubmix.ts @@ -4,9 +4,8 @@ import { isOpenAIModel } from '@renderer/config/models' import { Provider } from '@renderer/types' -import { startsWith } from './helper' -import { provider2Provider } from './helper' -import type { ModelRule } from './types' +import { provider2Provider, startsWith } from './helper' +import type { RuleSet } from './types' const extraProviderConfig = (provider: Provider) => { return { @@ -18,41 +17,41 @@ const extraProviderConfig = (provider: Provider) => { } } -const AIHUBMIX_RULES: ModelRule[] = [ - { - name: 'claude', - match: startsWith('claude'), - provider: (provider: Provider) => { - return extraProviderConfig({ - ...provider, - type: 'anthropic' - }) +const AIHUBMIX_RULES: RuleSet = { + rules: [ + { + match: startsWith('claude'), + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'anthropic' + }) + } + }, + { + match: (model) => + (startsWith('gemini')(model) || startsWith('imagen')(model)) && + !model.id.endsWith('-nothink') && + !model.id.endsWith('-search'), + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'gemini', + apiHost: 'https://aihubmix.com/gemini' + }) + } + }, + { + match: isOpenAIModel, + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'openai-response' + }) + } } - }, - { - name: 'gemini', - match: (model) => - (startsWith('gemini')(model) || startsWith('imagen')(model)) && - !model.id.endsWith('-nothink') && - !model.id.endsWith('-search'), - provider: (provider: Provider) => { - return extraProviderConfig({ - ...provider, - type: 'gemini', - apiHost: 'https://aihubmix.com/gemini' - }) - } - }, - { - name: 'openai', - match: isOpenAIModel, - provider: (provider: Provider) => { - return extraProviderConfig({ - ...provider, - type: 'openai-response' - }) - } - } -] + ], + fallbackRule: (provider: Provider) => provider +} export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES) diff --git a/src/renderer/src/aiCore/provider/config/helper.ts b/src/renderer/src/aiCore/provider/config/helper.ts index 31ae9b4eb0..656911fc76 100644 --- a/src/renderer/src/aiCore/provider/config/helper.ts +++ b/src/renderer/src/aiCore/provider/config/helper.ts @@ -1,22 +1,22 @@ import type { Model, Provider } from '@renderer/types' -import type { ModelRule } from './types' +import type { RuleSet } from './types' export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase()) export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type /** - * 解析模型对应的Provider ID + * 解析模型对应的Provider + * @param ruleSet 规则集对象 * @param model 模型对象 - * @param rules 匹配规则数组 - * @param fallback 默认fallback的providerId - * @returns 解析出的providerId + * @param provider 原始provider对象 + * @returns 解析出的provider对象 */ -export function provider2Provider(rules: ModelRule[], model: Model, provider: Provider): Provider { - for (const rule of rules) { +export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider { + for (const rule of ruleSet.rules) { if (rule.match(model)) { return rule.provider(provider) } } - return provider + return ruleSet.fallbackRule(provider) } diff --git a/src/renderer/src/aiCore/provider/config/index.ts b/src/renderer/src/aiCore/provider/config/index.ts index 7c19231d4e..2f51234cec 100644 --- a/src/renderer/src/aiCore/provider/config/index.ts +++ b/src/renderer/src/aiCore/provider/config/index.ts @@ -1,16 +1,3 @@ -// /** -// * Provider解析规则模块导出 -// */ - -// // 导出类型 -// export type { ModelRule } from './types' - -// // 导出匹配函数和解析器 -// export { endpointIs, resolveProvider, startsWith } from './helper' - -// // 导出规则集 -// export { AIHUBMIX_RULES } from './aihubmix' -// export { NEWAPI_RULES } from './newApi' - export { aihubmixProviderCreator } from './aihubmix' export { newApiResolverCreator } from './newApi' +export { vertexAnthropicProviderCreator } from './vertext-anthropic' diff --git a/src/renderer/src/aiCore/provider/config/newApi.ts b/src/renderer/src/aiCore/provider/config/newApi.ts index e7de0bb328..5277495cdb 100644 --- a/src/renderer/src/aiCore/provider/config/newApi.ts +++ b/src/renderer/src/aiCore/provider/config/newApi.ts @@ -4,49 +4,48 @@ import { Provider } from '@renderer/types' import { endpointIs, provider2Provider } from './helper' -import type { ModelRule } from './types' +import type { RuleSet } from './types' -const NEWAPI_RULES: ModelRule[] = [ - { - name: 'anthropic', - match: endpointIs('anthropic'), - provider: (provider: Provider) => { - return { - ...provider, - type: 'anthropic' +const NEWAPI_RULES: RuleSet = { + rules: [ + { + match: endpointIs('anthropic'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'anthropic' + } + } + }, + { + match: endpointIs('gemini'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'gemini' + } + } + }, + { + match: endpointIs('openai-response'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'openai-response' + } + } + }, + { + match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model), + provider: (provider: Provider) => { + return { + ...provider, + type: 'openai' + } } } - }, - { - name: 'gemini', - match: endpointIs('gemini'), - provider: (provider: Provider) => { - return { - ...provider, - type: 'gemini' - } - } - }, - { - name: 'openai-response', - match: endpointIs('openai-response'), - provider: (provider: Provider) => { - return { - ...provider, - type: 'openai-response' - } - } - }, - { - name: 'openai', - match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model), - provider: (provider: Provider) => { - return { - ...provider, - type: 'openai' - } - } - } -] + ], + fallbackRule: (provider: Provider) => provider +} export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES) diff --git a/src/renderer/src/aiCore/provider/config/types.ts b/src/renderer/src/aiCore/provider/config/types.ts index 5f3cc5a56b..f3938b84d1 100644 --- a/src/renderer/src/aiCore/provider/config/types.ts +++ b/src/renderer/src/aiCore/provider/config/types.ts @@ -1,7 +1,9 @@ import type { Model, Provider } from '@renderer/types' -export interface ModelRule { - name: string - match: (model: Model) => boolean - provider: (provider: Provider) => Provider +export interface RuleSet { + rules: Array<{ + match: (model: Model) => boolean + provider: (provider: Provider) => Provider + }> + fallbackRule: (provider: Provider) => Provider } diff --git a/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts b/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts new file mode 100644 index 0000000000..23c8b5185c --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts @@ -0,0 +1,19 @@ +import type { Provider } from '@renderer/types' + +import { provider2Provider, startsWith } from './helper' +import type { RuleSet } from './types' + +const VERTEX_ANTHROPIC_RULES: RuleSet = { + rules: [ + { + match: startsWith('claude'), + provider: (provider: Provider) => ({ + ...provider, + id: 'google-vertex-anthropic' + }) + } + ], + fallbackRule: (provider: Provider) => provider +} + +export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES) diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index 752450ea0c..617758753e 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -1,6 +1,8 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider' +import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import { Provider } from '@renderer/types' +import type { Provider as AiSdkProvider } from 'ai' import { initializeNewProviders } from './providerInitialization' @@ -70,3 +72,28 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com // 3. 最后的fallback(通常会成为openai-compatible) return provider.id as ProviderId } + +export async function createAiSdkProvider(config) { + let localProvider: Awaited | null = null + try { + if (config.providerId === 'openai' && config.options?.mode === 'chat') { + config.providerId = `${config.providerId}-chat` + } else if (config.providerId === 'azure' && config.options?.mode === 'responses') { + config.providerId = `${config.providerId}-responses` + } + localProvider = await createProviderCore(config.providerId, config.options) + + logger.debug('Local provider created successfully', { + providerId: config.providerId, + hasOptions: !!config.options, + localProvider: localProvider, + options: config.options + }) + } catch (error) { + logger.error('Failed to create local provider', error as Error, { + providerId: config.providerId + }) + throw error + } + return localProvider +} diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 3b7c25ee42..cdc8621e62 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -19,7 +19,7 @@ import type { Model, Provider } from '@renderer/types' import { formatApiHost } from '@renderer/utils/api' import { cloneDeep, isEmpty } from 'lodash' -import { aihubmixProviderCreator, newApiResolverCreator } from './config' +import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' import { getAiSdkProviderId } from './factory' const logger = loggerService.withContext('ProviderConfigProcessor') @@ -67,6 +67,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider { if (provider.id === 'newapi') { return newApiResolverCreator(model, provider) } + if (provider.id === 'vertexai') { + return vertexAnthropicProviderCreator(model, provider) + } return provider } @@ -157,7 +160,7 @@ export function providerToAiSdkConfig( extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey() } // google-vertex - if (aiSdkProviderId === 'google-vertex') { + if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') { if (!isVertexAIConfigured()) { throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') } diff --git a/src/renderer/src/aiCore/provider/providerInitialization.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts index 49b8544b38..cf3366d70a 100644 --- a/src/renderer/src/aiCore/provider/providerInitialization.ts +++ b/src/renderer/src/aiCore/provider/providerInitialization.ts @@ -24,6 +24,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [ supportsImageGeneration: true, aliases: ['vertexai'] }, + { + id: 'google-vertex-anthropic', + name: 'Google Vertex AI Anthropic', + import: () => import('@ai-sdk/google-vertex/anthropic/edge'), + creatorFunctionName: 'createVertexAnthropic', + supportsImageGeneration: true, + aliases: ['vertexai-anthropic'] + }, { id: 'bedrock', name: 'Amazon Bedrock', diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 4438cb94d9..5bd6201d98 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -93,7 +93,6 @@ export async function fetchChatCompletion({ modelId: assistant.model?.id, modelName: assistant.model?.name }) - const AI = new AiProviderNew(assistant.model || getDefaultModel()) const provider = AI.getActualProvider() @@ -126,7 +125,6 @@ export async function fetchChatCompletion({ streamOutput: assistant.settings?.streamOutput ?? true, onChunk: onChunkReceived, model: assistant.model, - provider: provider, enableReasoning: capabilities.enableReasoning, isPromptToolUse: isPromptToolUse(assistant), isSupportedToolUse: isSupportedToolUse(assistant), diff --git a/src/renderer/src/types/aiCoreTypes.ts b/src/renderer/src/types/aiCoreTypes.ts index 14edaf5abb..e93218ab9e 100644 --- a/src/renderer/src/types/aiCoreTypes.ts +++ b/src/renderer/src/types/aiCoreTypes.ts @@ -1,6 +1,29 @@ -import { generateObject, generateText, streamObject, streamText } from 'ai' +import type { ImageModel, LanguageModel } from 'ai' +import { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai' -export type StreamTextParams = Omit[0], 'model'> -export type GenerateTextParams = Omit[0], 'model'> +export type StreamTextParams = Omit[0], 'model' | 'messages'> & + ( + | { + prompt: string | Array + messages?: never + } + | { + messages: Array + prompt?: never + } + ) +export type GenerateTextParams = Omit[0], 'model' | 'messages'> & + ( + | { + prompt: string | Array + messages?: never + } + | { + messages: Array + prompt?: never + } + ) export type StreamObjectParams = Omit[0], 'model'> export type GenerateObjectParams = Omit[0], 'model'> + +export type AiSdkModel = LanguageModel | ImageModel diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index d21b8972b9..ae8ce30e70 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -356,6 +356,7 @@ export type ProviderType = | 'vertexai' | 'mistral' | 'aws-bedrock' + | 'vertex-anthropic' export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank' diff --git a/yarn.lock b/yarn.lock index a41869b827..5899227c0e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -9209,7 +9209,7 @@ __metadata: "@viz-js/lang-dot": "npm:^1.0.5" "@viz-js/viz": "npm:^3.14.0" "@xyflow/react": "npm:^12.4.4" - ai: "npm:^5.0.26" + ai: "npm:^5.0.29" antd: "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch" archiver: "npm:^7.0.1" async-mutex: "npm:^0.5.0" @@ -9425,9 +9425,9 @@ __metadata: languageName: node linkType: hard -"ai@npm:^5.0.26": - version: 5.0.26 - resolution: "ai@npm:5.0.26" +"ai@npm:^5.0.29": + version: 5.0.29 + resolution: "ai@npm:5.0.29" dependencies: "@ai-sdk/gateway": "npm:1.0.15" "@ai-sdk/provider": "npm:2.0.0" @@ -9435,7 +9435,7 @@ __metadata: "@opentelemetry/api": "npm:1.9.0" peerDependencies: zod: ^3.25.76 || ^4 - checksum: 10c0/0423f296b1aa9f22ad106278e8d1e7a2ae9d068358720cdc23c0f222af5406ac1e5ccbce19833709aa1c62b841361d310fb3c42781f426966a9f9ca287ae7faa + checksum: 10c0/526cd2fd59b35b19d902665e3dc1ba5a09f2bb1377295d642fb8a33e13a890874e4dd4b49a787de7f31f4ec6b07257be8514efac08f993081daeb430cf2f60ba languageName: node linkType: hard