mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 12:51:26 +08:00
chore(dependencies): update ai package version and enhance aiCore functionality
- Updated the 'ai' package version from 5.0.26 to 5.0.29 in package.json and yarn.lock. - Refactored aiCore's provider schemas to introduce new provider types and improve type safety. - Enhanced the RuntimeExecutor class to streamline model handling and plugin execution for various AI tasks. - Updated tests to reflect changes in parameter handling and ensure compatibility with new provider configurations.
This commit is contained in:
parent
4cceddc179
commit
782953cca1
@ -174,7 +174,7 @@
|
||||
"@viz-js/lang-dot": "^1.0.5",
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"ai": "^5.0.26",
|
||||
"ai": "^5.0.29",
|
||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
|
||||
@ -4,11 +4,13 @@
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||
import { createAzure } from '@ai-sdk/azure'
|
||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||
import { createDeepSeek } from '@ai-sdk/deepseek'
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import { createXai } from '@ai-sdk/xai'
|
||||
import { customProvider, type Provider } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
@ -16,12 +18,13 @@ import * as z from 'zod'
|
||||
*/
|
||||
export const baseProviderIds = [
|
||||
'openai',
|
||||
'openai-responses',
|
||||
'openai-chat',
|
||||
'openai-compatible',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'azure',
|
||||
'azure-responses',
|
||||
'deepseek'
|
||||
] as const
|
||||
|
||||
@ -38,7 +41,7 @@ export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
export const baseProviderSchema = z.object({
|
||||
id: baseProviderIdSchema,
|
||||
name: z.string(),
|
||||
creator: z.function().args(z.any()).returns(z.any()),
|
||||
creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
|
||||
supportsImageGeneration: z.boolean()
|
||||
})
|
||||
|
||||
@ -56,9 +59,17 @@ export const baseProviders = [
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-responses',
|
||||
name: 'OpenAI Responses',
|
||||
creator: (options: OpenAIProviderSettings) => createOpenAI(options).responses,
|
||||
id: 'openai-chat',
|
||||
name: 'OpenAI Chat',
|
||||
creator: (options: OpenAIProviderSettings) => {
|
||||
const provider = createOpenAI(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
@ -91,6 +102,20 @@ export const baseProviders = [
|
||||
creator: createAzure,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure-responses',
|
||||
name: 'Azure OpenAI Responses',
|
||||
creator: (options: AzureOpenAIProviderSettings) => {
|
||||
const provider = createAzure(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
|
||||
@ -4,7 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { type AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ImageGenerationError } from '../errors'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock dependencies
|
||||
@ -76,9 +76,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
describe('Basic functionality', () => {
|
||||
it('should generate a single image with minimal parameters', async () => {
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset'
|
||||
})
|
||||
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
|
||||
|
||||
@ -91,7 +89,8 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should generate image with pre-created model', async () => {
|
||||
const result = await executor.generateImage(mockImageModel, {
|
||||
const result = await executor.generateImage({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
@ -105,10 +104,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should support multiple images generation', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape',
|
||||
n: 3
|
||||
})
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -118,10 +114,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should support size specification', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A beautiful sunset',
|
||||
size: '1024x1024'
|
||||
})
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -131,10 +124,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should support aspect ratio specification', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A mountain landscape',
|
||||
aspectRatio: '16:9'
|
||||
})
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -144,10 +134,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should support seed for consistent output', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A cat in space',
|
||||
seed: 1234567890
|
||||
})
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -159,10 +146,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
it('should support abort signal', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A cityscape',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -172,7 +156,8 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should support provider-specific options', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A space station',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
@ -195,7 +180,8 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should support custom headers', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A robot',
|
||||
headers: {
|
||||
'X-Custom-Header': 'test-value'
|
||||
@ -242,9 +228,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
[testPlugin]
|
||||
)
|
||||
|
||||
const result = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||
|
||||
@ -287,9 +271,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
[modelResolutionPlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
||||
'dall-e-3',
|
||||
@ -312,6 +294,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
if (!context.isRecursiveCall && params.prompt === 'original') {
|
||||
// Make a recursive call with modified prompt
|
||||
await context.recursiveCall({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'modified'
|
||||
})
|
||||
}
|
||||
@ -327,9 +310,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
[recursivePlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'original'
|
||||
})
|
||||
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' })
|
||||
|
||||
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
|
||||
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
|
||||
@ -343,22 +324,47 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
throw modelError
|
||||
})
|
||||
|
||||
await expect(
|
||||
executor.generateImage('invalid-model', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow(ImageGenerationError)
|
||||
await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow(
|
||||
ImageGenerationError
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle ImageModelResolutionError correctly', async () => {
|
||||
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
throw resolutionError
|
||||
})
|
||||
|
||||
const thrownError = await executor
|
||||
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
|
||||
.catch((error) => error)
|
||||
|
||||
expect(thrownError).toBeInstanceOf(ImageGenerationError)
|
||||
expect(thrownError.message).toContain('Failed to generate image:')
|
||||
expect(thrownError.providerId).toBe('openai')
|
||||
expect(thrownError.modelId).toBe('invalid-model')
|
||||
expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
|
||||
expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
|
||||
})
|
||||
|
||||
it('should handle ImageModelResolutionError without provider', async () => {
|
||||
const resolutionError = new ImageModelResolutionError('unknown-model')
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
throw resolutionError
|
||||
})
|
||||
|
||||
await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow(
|
||||
ImageGenerationError
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle image generation API errors', async () => {
|
||||
const apiError = new Error('API request failed')
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image:')
|
||||
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||
'Failed to generate image:'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle NoImageGeneratedError', async () => {
|
||||
@ -370,11 +376,9 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
||||
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image:')
|
||||
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||
'Failed to generate image:'
|
||||
)
|
||||
})
|
||||
|
||||
it('should execute onError plugin hook on failure', async () => {
|
||||
@ -394,11 +398,9 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
[errorPlugin]
|
||||
)
|
||||
|
||||
await expect(
|
||||
executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image:')
|
||||
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||
'Failed to generate image:'
|
||||
)
|
||||
|
||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||
error,
|
||||
@ -418,10 +420,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
setTimeout(() => abortController.abort(), 10)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
|
||||
).rejects.toThrow('Failed to generate image:')
|
||||
})
|
||||
})
|
||||
@ -432,9 +431,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
apiKey: 'google-key'
|
||||
})
|
||||
|
||||
await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A landscape'
|
||||
})
|
||||
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
|
||||
})
|
||||
@ -444,9 +441,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
apiKey: 'xai-key'
|
||||
})
|
||||
|
||||
await xaiExecutor.generateImage('grok-2-image', {
|
||||
prompt: 'A futuristic robot'
|
||||
})
|
||||
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
|
||||
})
|
||||
@ -454,11 +449,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
describe('Advanced features', () => {
|
||||
it('should support batch image generation with maxImagesPerCall', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
n: 10,
|
||||
maxImagesPerCall: 5
|
||||
})
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -469,10 +460,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should support retries with maxRetries', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
maxRetries: 3
|
||||
})
|
||||
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -494,7 +482,8 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
|
||||
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
const result = await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A test image',
|
||||
size: '2048x2048' // Unsupported size
|
||||
})
|
||||
@ -504,9 +493,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
})
|
||||
|
||||
it('should provide access to provider metadata', async () => {
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(result.providerMetadata).toBeDefined()
|
||||
expect(result.providerMetadata.openai).toBeDefined()
|
||||
@ -526,9 +513,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
|
||||
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||
|
||||
expect(result.responses).toHaveLength(1)
|
||||
expect(result.responses[0].modelId).toBe('dall-e-3')
|
||||
|
||||
@ -72,42 +72,36 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
// === 高阶重载:直接使用模型 ===
|
||||
|
||||
/**
|
||||
* 流式文本生成 - 使用已创建的模型(高级用法)
|
||||
* 流式文本生成
|
||||
*/
|
||||
async streamText(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof streamText>>
|
||||
async streamText(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>>
|
||||
async streamText(
|
||||
modelOrId: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||
params: Parameters<typeof streamText>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
// 2. 执行插件处理
|
||||
return this.pluginEngine.executeStreamWithPlugins(
|
||||
'streamText',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams, streamTransforms) => {
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams, streamTransforms) => {
|
||||
const experimental_transform =
|
||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||
|
||||
const finalParams = {
|
||||
model,
|
||||
model: resolvedModel,
|
||||
...transformedParams,
|
||||
experimental_transform
|
||||
} as Parameters<typeof streamText>[0]
|
||||
@ -120,145 +114,126 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
// === 其他方法的重载 ===
|
||||
|
||||
/**
|
||||
* 生成文本 - 使用已创建的模型
|
||||
* 生成文本
|
||||
*/
|
||||
async generateText(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateText>>
|
||||
async generateText(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>>
|
||||
async generateText(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||
params: Parameters<typeof generateText>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateText',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) =>
|
||||
generateText({ model, ...transformedParams } as Parameters<typeof generateText>[0])
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成结构化对象 - 使用已创建的模型
|
||||
* 生成结构化对象
|
||||
*/
|
||||
async generateObject(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateObject>>
|
||||
async generateObject(
|
||||
modelOrId: string,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>>
|
||||
async generateObject(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||
params: Parameters<typeof generateObject>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateObject',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) =>
|
||||
generateObject({ model, ...transformedParams } as Parameters<typeof generateObject>[0])
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式生成结构化对象 - 使用已创建的模型
|
||||
* 流式生成结构化对象
|
||||
*/
|
||||
async streamObject(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof streamObject>>
|
||||
async streamObject(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>>
|
||||
async streamObject(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||
params: Parameters<typeof streamObject>[0],
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'streamObject',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) =>
|
||||
streamObject({ model, ...transformedParams } as Parameters<typeof streamObject>[0])
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像 - 使用已创建的图像模型
|
||||
* 生成图像
|
||||
*/
|
||||
async generateImage(
|
||||
model: ImageModelV2,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateImage>>
|
||||
async generateImage(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateImage>>
|
||||
async generateImage(
|
||||
modelOrId: ImageModelV2 | string,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
|
||||
): Promise<ReturnType<typeof generateImage>> {
|
||||
try {
|
||||
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
|
||||
const { model, ...restParams } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
|
||||
} else {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return await this.pluginEngine.executeImageWithPlugins(
|
||||
'generateImage',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) => {
|
||||
return await generateImage({ model, ...transformedParams })
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) => {
|
||||
return await generateImage({ model: resolvedModel, ...transformedParams })
|
||||
}
|
||||
)
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
const modelId = typeof params.model === 'string' ? params.model : params.model.modelId
|
||||
throw new ImageGenerationError(
|
||||
`Failed to generate image: ${error.message}`,
|
||||
this.config.providerId,
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
modelId,
|
||||
error
|
||||
)
|
||||
}
|
||||
|
||||
@ -46,13 +46,12 @@ export function createOpenAICompatibleExecutor(
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamText(modelId, params, { middlewares })
|
||||
return executor.streamText(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
@ -61,13 +60,12 @@ export async function streamText<T extends ProviderId>(
|
||||
export async function generateText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateText(modelId, params, { middlewares })
|
||||
return executor.generateText(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
@ -76,13 +74,12 @@ export async function generateText<T extends ProviderId>(
|
||||
export async function generateObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
|
||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateObject(modelId, params, { middlewares })
|
||||
return executor.generateObject(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
@ -91,13 +88,12 @@ export async function generateObject<T extends ProviderId>(
|
||||
export async function streamObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
|
||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamObject(modelId, params, { middlewares })
|
||||
return executor.streamObject(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
@ -106,13 +102,11 @@ export async function streamObject<T extends ProviderId>(
|
||||
export async function generateImage<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateImage']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateImage(modelId, params, { middlewares })
|
||||
return executor.generateImage(params)
|
||||
}
|
||||
|
||||
// === Agent 功能预留 ===
|
||||
|
||||
@ -64,11 +64,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
*/
|
||||
async executeWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
@ -76,7 +89,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
@ -88,17 +101,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(model, transformedParams)
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果(对于非流式调用)
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
@ -120,11 +140,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
*/
|
||||
async executeImageWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
model: ImageModelV2 | string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: ImageModelV2 | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
@ -132,7 +165,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
@ -144,17 +177,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Image model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(model, transformedParams)
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
@ -176,11 +216,24 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
*/
|
||||
async executeStreamWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
@ -188,7 +241,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
@ -200,11 +253,17 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
@ -214,7 +273,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||
|
||||
// 5. 执行流式 API 调用
|
||||
const result = await executor(model, transformedParams, streamTransforms)
|
||||
const result = await executor(resolvedModel, transformedParams, streamTransforms)
|
||||
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
|
||||
@ -7,22 +7,23 @@
|
||||
* 2. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '@cherrystudio/ai-core'
|
||||
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import type { Assistant, Model, Provider } from '@renderer/types'
|
||||
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import { CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { createAiSdkProvider } from './provider/factory'
|
||||
import {
|
||||
getActualProvider,
|
||||
isModernSdkSupported,
|
||||
@ -44,6 +45,7 @@ export default class ModernAiProvider {
|
||||
private config: ReturnType<typeof providerToAiSdkConfig>
|
||||
private actualProvider: Provider
|
||||
private model: Model
|
||||
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||
|
||||
constructor(model: Model, provider?: Provider) {
|
||||
this.actualProvider = provider || getActualProvider(model)
|
||||
@ -62,44 +64,58 @@ export default class ModernAiProvider {
|
||||
// 准备特殊配置
|
||||
await prepareSpecialProviderConfig(this.actualProvider, this.config)
|
||||
|
||||
logger.debug('this.config', this.config)
|
||||
// 提前创建本地 provider 实例
|
||||
if (!this.localProvider) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
}
|
||||
|
||||
// 提前构建中间件
|
||||
const middlewares = buildAiSdkMiddlewares({
|
||||
...config,
|
||||
provider: this.actualProvider
|
||||
})
|
||||
logger.debug('Built middlewares in completions', {
|
||||
middlewareCount: middlewares.length,
|
||||
isImageGeneration: config.isImageGenerationEndpoint
|
||||
})
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
}
|
||||
|
||||
// 根据endpoint类型创建对应的模型
|
||||
let model: AiSdkModel | undefined
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
model = this.localProvider.imageModel(modelId)
|
||||
} else {
|
||||
model = this.localProvider.languageModel(modelId)
|
||||
// 如果有中间件,应用到语言模型上
|
||||
if (middlewares.length > 0 && typeof model === 'object') {
|
||||
model = wrapLanguageModel({ model, middleware: middlewares })
|
||||
}
|
||||
}
|
||||
|
||||
if (config.topicId && getEnableDeveloperMode()) {
|
||||
// TypeScript类型窄化:确保topicId是string类型
|
||||
const traceConfig = {
|
||||
...config,
|
||||
topicId: config.topicId
|
||||
}
|
||||
return await this._completionsForTrace(modelId, params, traceConfig)
|
||||
return await this._completionsForTrace(model, params, traceConfig)
|
||||
} else {
|
||||
return await this._completions(modelId, params, config)
|
||||
return await this._completionsOrImageGeneration(model, params, config)
|
||||
}
|
||||
}
|
||||
|
||||
private async _completions(
|
||||
modelId: string,
|
||||
private async _completionsOrImageGeneration(
|
||||
model: AiSdkModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
// 初始化 provider 到全局管理器
|
||||
try {
|
||||
await createAndRegisterProvider(this.config.providerId, this.config.options)
|
||||
logger.debug('Provider initialized successfully', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options
|
||||
})
|
||||
} catch (error) {
|
||||
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
||||
logger.debug('Provider initialization skipped (may already be initialized)', {
|
||||
providerId: this.config.providerId,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
}
|
||||
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
return await this.modernImageGeneration(modelId, params, config)
|
||||
return await this.modernImageGeneration(model as ImageModel, params, config)
|
||||
}
|
||||
|
||||
return await this.modernCompletions(modelId, params, config)
|
||||
return await this.modernCompletions(model as LanguageModel, params, config)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -107,10 +123,11 @@ export default class ModernAiProvider {
|
||||
* 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中
|
||||
*/
|
||||
private async _completionsForTrace(
|
||||
modelId: string,
|
||||
model: AiSdkModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig & { topicId: string }
|
||||
): Promise<CompletionsResult> {
|
||||
const modelId = this.model.id
|
||||
const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
@ -136,7 +153,7 @@ export default class ModernAiProvider {
|
||||
modelId,
|
||||
traceName
|
||||
})
|
||||
return await this._completions(modelId, params, config)
|
||||
return await this._completionsOrImageGeneration(model, params, config)
|
||||
}
|
||||
|
||||
try {
|
||||
@ -148,7 +165,7 @@ export default class ModernAiProvider {
|
||||
parentSpanCreated: true
|
||||
})
|
||||
|
||||
const result = await this._completions(modelId, params, config)
|
||||
const result = await this._completionsOrImageGeneration(model, params, config)
|
||||
|
||||
logger.info('Completions finished, ending parent span', {
|
||||
spanId: span.spanContext().spanId,
|
||||
@ -190,10 +207,11 @@ export default class ModernAiProvider {
|
||||
* 使用现代化AI SDK的completions实现
|
||||
*/
|
||||
private async modernCompletions(
|
||||
modelId: string,
|
||||
model: LanguageModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const modelId = this.model.id
|
||||
logger.info('Starting modernCompletions', {
|
||||
modelId,
|
||||
providerId: this.config.providerId,
|
||||
@ -205,122 +223,48 @@ export default class ModernAiProvider {
|
||||
|
||||
// 根据条件构建插件数组
|
||||
const plugins = buildPlugins(config)
|
||||
logger.debug('Built plugins for AI SDK', {
|
||||
pluginCount: plugins.length,
|
||||
pluginNames: plugins.map((p) => p.name),
|
||||
providerId: this.config.providerId,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = createExecutor(this.config.providerId, this.config.options, plugins)
|
||||
logger.debug('Created AI SDK executor', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options,
|
||||
pluginCount: plugins.length
|
||||
})
|
||||
|
||||
// 动态构建中间件数组
|
||||
const middlewares = buildAiSdkMiddlewares(config)
|
||||
logger.debug('Built AI SDK middlewares', {
|
||||
middlewareCount: middlewares.length,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
if (config.onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
logger.info('Starting streaming with chunk adapter', {
|
||||
modelId,
|
||||
hasMiddlewares: middlewares.length > 0,
|
||||
middlewareCount: middlewares.length,
|
||||
hasMcpTools: !!config.mcpTools,
|
||||
mcpToolCount: config.mcpTools?.length || 0,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
const accumulate = this.model.supported_text_delta !== false // true and undefined
|
||||
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate)
|
||||
|
||||
logger.debug('Final params before streamText', {
|
||||
modelId,
|
||||
hasMessages: !!params.messages,
|
||||
messageCount: params.messages?.length || 0,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolNames: params.tools ? Object.keys(params.tools) : [],
|
||||
hasSystem: !!params.system,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
const streamResult = await executor.streamText(
|
||||
modelId,
|
||||
{ ...params, experimental_context: { onChunk: config.onChunk } },
|
||||
middlewares.length > 0 ? { middlewares } : undefined
|
||||
)
|
||||
|
||||
logger.info('StreamText call successful, processing stream', {
|
||||
modelId,
|
||||
topicId: config.topicId,
|
||||
hasFullStream: !!streamResult.fullStream
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model,
|
||||
experimental_context: { onChunk: config.onChunk }
|
||||
})
|
||||
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
logger.info('Stream processing completed', {
|
||||
modelId,
|
||||
topicId: config.topicId,
|
||||
finalTextLength: finalText.length
|
||||
})
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
// 流式处理但没有 onChunk 回调
|
||||
logger.info('Starting streaming without chunk callback', {
|
||||
modelId,
|
||||
hasMiddlewares: middlewares.length > 0,
|
||||
middlewareCount: middlewares.length,
|
||||
topicId: config.topicId
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model
|
||||
})
|
||||
|
||||
const streamResult = await executor.streamText(
|
||||
modelId,
|
||||
params,
|
||||
middlewares.length > 0 ? { middlewares } : undefined
|
||||
)
|
||||
|
||||
logger.info('StreamText call successful, waiting for text', {
|
||||
modelId,
|
||||
topicId: config.topicId
|
||||
})
|
||||
// 强制消费流,不然await streamResult.text会阻塞
|
||||
await streamResult?.consumeStream()
|
||||
|
||||
const finalText = await streamResult.text
|
||||
|
||||
logger.info('Text extraction completed', {
|
||||
modelId,
|
||||
topicId: config.topicId,
|
||||
finalTextLength: finalText.length
|
||||
})
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
}
|
||||
// }
|
||||
// catch (error) {
|
||||
// console.error('Modern AI SDK error:', error)
|
||||
// throw error
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||
*/
|
||||
private async modernImageGeneration(
|
||||
modelId: string,
|
||||
model: ImageModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
@ -364,7 +308,11 @@ export default class ModernAiProvider {
|
||||
}
|
||||
|
||||
// 调用新 AI SDK 的图像生成功能
|
||||
const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams)
|
||||
const executor = createExecutor(this.config.providerId, this.config.options, [])
|
||||
const result = await executor.generateImage({
|
||||
model,
|
||||
...imageParams
|
||||
})
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
@ -441,51 +389,64 @@ export default class ModernAiProvider {
|
||||
return this.legacyProvider.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
// 如果支持新的 AI SDK,使用现代化实现
|
||||
if (isModernSdkSupported(this.actualProvider)) {
|
||||
try {
|
||||
const result = await this.modernGenerateImage(params)
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
|
||||
// fallback 到传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
}
|
||||
// public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
// // 如果支持新的 AI SDK,使用现代化实现
|
||||
// if (isModernSdkSupported(this.actualProvider)) {
|
||||
// try {
|
||||
// // 确保本地provider已创建
|
||||
// if (!this.localProvider) {
|
||||
// await prepareSpecialProviderConfig(this.actualProvider, this.config)
|
||||
// this.localProvider = await createProvider(this.config.providerId, this.config.options)
|
||||
// logger.debug('Local provider created for standalone image generation', {
|
||||
// providerId: this.config.providerId
|
||||
// })
|
||||
// }
|
||||
|
||||
// 直接使用传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
// const result = await this.modernGenerateImage(params)
|
||||
// return result
|
||||
// } catch (error) {
|
||||
// logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
|
||||
// // fallback 到传统实现
|
||||
// return this.legacyProvider.generateImage(params)
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现
|
||||
*/
|
||||
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
const { model, prompt, imageSize, batchSize, signal } = params
|
||||
// // 直接使用传统实现
|
||||
// return this.legacyProvider.generateImage(params)
|
||||
// }
|
||||
|
||||
// 转换参数格式
|
||||
const aiSdkParams = {
|
||||
prompt,
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
n: batchSize || 1,
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
// /**
|
||||
// * 使用现代化 AI SDK 的图像生成实现
|
||||
// */
|
||||
// private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
// const { model, prompt, imageSize, batchSize, signal } = params
|
||||
|
||||
const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams)
|
||||
// // 转换参数格式
|
||||
// const aiSdkParams = {
|
||||
// prompt,
|
||||
// size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
// n: batchSize || 1,
|
||||
// ...(signal && { abortSignal: signal })
|
||||
// }
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:image/png;base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
// const executor = createExecutor(this.config.providerId, this.config.options, [])
|
||||
// const result = await executor.generateImage({
|
||||
// model: this.localProvider?.imageModel(model) as ImageModel,
|
||||
// ...aiSdkParams
|
||||
// })
|
||||
|
||||
return images
|
||||
}
|
||||
// // 转换结果格式
|
||||
// const images: string[] = []
|
||||
// if (result.images) {
|
||||
// for (const image of result.images) {
|
||||
// if ('base64' in image && image.base64) {
|
||||
// images.push(`data:image/png;base64,${image.base64}`)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return images
|
||||
// }
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.legacyProvider.getBaseURL()
|
||||
|
||||
@ -30,7 +30,7 @@ const logger = loggerService.withContext('parameterBuilder')
|
||||
* 这是主要的参数构建函数,整合所有转换逻辑
|
||||
*/
|
||||
export async function buildStreamTextParams(
|
||||
sdkMessages: StreamTextParams['messages'],
|
||||
sdkMessages: StreamTextParams['messages'] = [],
|
||||
assistant: Assistant,
|
||||
provider: Provider,
|
||||
options: {
|
||||
|
||||
@ -4,9 +4,8 @@
|
||||
import { isOpenAIModel } from '@renderer/config/models'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { startsWith } from './helper'
|
||||
import { provider2Provider } from './helper'
|
||||
import type { ModelRule } from './types'
|
||||
import { provider2Provider, startsWith } from './helper'
|
||||
import type { RuleSet } from './types'
|
||||
|
||||
const extraProviderConfig = (provider: Provider) => {
|
||||
return {
|
||||
@ -18,41 +17,41 @@ const extraProviderConfig = (provider: Provider) => {
|
||||
}
|
||||
}
|
||||
|
||||
const AIHUBMIX_RULES: ModelRule[] = [
|
||||
{
|
||||
name: 'claude',
|
||||
match: startsWith('claude'),
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
})
|
||||
const AIHUBMIX_RULES: RuleSet = {
|
||||
rules: [
|
||||
{
|
||||
match: startsWith('claude'),
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
})
|
||||
}
|
||||
},
|
||||
{
|
||||
match: (model) =>
|
||||
(startsWith('gemini')(model) || startsWith('imagen')(model)) &&
|
||||
!model.id.endsWith('-nothink') &&
|
||||
!model.id.endsWith('-search'),
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'gemini',
|
||||
apiHost: 'https://aihubmix.com/gemini'
|
||||
})
|
||||
}
|
||||
},
|
||||
{
|
||||
match: isOpenAIModel,
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'openai-response'
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'gemini',
|
||||
match: (model) =>
|
||||
(startsWith('gemini')(model) || startsWith('imagen')(model)) &&
|
||||
!model.id.endsWith('-nothink') &&
|
||||
!model.id.endsWith('-search'),
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'gemini',
|
||||
apiHost: 'https://aihubmix.com/gemini'
|
||||
})
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'openai',
|
||||
match: isOpenAIModel,
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'openai-response'
|
||||
})
|
||||
}
|
||||
}
|
||||
]
|
||||
],
|
||||
fallbackRule: (provider: Provider) => provider
|
||||
}
|
||||
|
||||
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)
|
||||
|
||||
@ -1,22 +1,22 @@
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
|
||||
import type { ModelRule } from './types'
|
||||
import type { RuleSet } from './types'
|
||||
|
||||
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
|
||||
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
|
||||
|
||||
/**
|
||||
* 解析模型对应的Provider ID
|
||||
* 解析模型对应的Provider
|
||||
* @param ruleSet 规则集对象
|
||||
* @param model 模型对象
|
||||
* @param rules 匹配规则数组
|
||||
* @param fallback 默认fallback的providerId
|
||||
* @returns 解析出的providerId
|
||||
* @param provider 原始provider对象
|
||||
* @returns 解析出的provider对象
|
||||
*/
|
||||
export function provider2Provider(rules: ModelRule[], model: Model, provider: Provider): Provider {
|
||||
for (const rule of rules) {
|
||||
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
|
||||
for (const rule of ruleSet.rules) {
|
||||
if (rule.match(model)) {
|
||||
return rule.provider(provider)
|
||||
}
|
||||
}
|
||||
return provider
|
||||
return ruleSet.fallbackRule(provider)
|
||||
}
|
||||
|
||||
@ -1,16 +1,3 @@
|
||||
// /**
|
||||
// * Provider解析规则模块导出
|
||||
// */
|
||||
|
||||
// // 导出类型
|
||||
// export type { ModelRule } from './types'
|
||||
|
||||
// // 导出匹配函数和解析器
|
||||
// export { endpointIs, resolveProvider, startsWith } from './helper'
|
||||
|
||||
// // 导出规则集
|
||||
// export { AIHUBMIX_RULES } from './aihubmix'
|
||||
// export { NEWAPI_RULES } from './newApi'
|
||||
|
||||
export { aihubmixProviderCreator } from './aihubmix'
|
||||
export { newApiResolverCreator } from './newApi'
|
||||
export { vertexAnthropicProviderCreator } from './vertext-anthropic'
|
||||
|
||||
@ -4,49 +4,48 @@
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { endpointIs, provider2Provider } from './helper'
|
||||
import type { ModelRule } from './types'
|
||||
import type { RuleSet } from './types'
|
||||
|
||||
const NEWAPI_RULES: ModelRule[] = [
|
||||
{
|
||||
name: 'anthropic',
|
||||
match: endpointIs('anthropic'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
const NEWAPI_RULES: RuleSet = {
|
||||
rules: [
|
||||
{
|
||||
match: endpointIs('anthropic'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
match: endpointIs('gemini'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'gemini'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
match: endpointIs('openai-response'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai-response'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'gemini',
|
||||
match: endpointIs('gemini'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'gemini'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'openai-response',
|
||||
match: endpointIs('openai-response'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai-response'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'openai',
|
||||
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
],
|
||||
fallbackRule: (provider: Provider) => provider
|
||||
}
|
||||
|
||||
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
|
||||
export interface ModelRule {
|
||||
name: string
|
||||
match: (model: Model) => boolean
|
||||
provider: (provider: Provider) => Provider
|
||||
export interface RuleSet {
|
||||
rules: Array<{
|
||||
match: (model: Model) => boolean
|
||||
provider: (provider: Provider) => Provider
|
||||
}>
|
||||
fallbackRule: (provider: Provider) => Provider
|
||||
}
|
||||
|
||||
19
src/renderer/src/aiCore/provider/config/vertext-anthropic.ts
Normal file
19
src/renderer/src/aiCore/provider/config/vertext-anthropic.ts
Normal file
@ -0,0 +1,19 @@
|
||||
import type { Provider } from '@renderer/types'
|
||||
|
||||
import { provider2Provider, startsWith } from './helper'
|
||||
import type { RuleSet } from './types'
|
||||
|
||||
const VERTEX_ANTHROPIC_RULES: RuleSet = {
|
||||
rules: [
|
||||
{
|
||||
match: startsWith('claude'),
|
||||
provider: (provider: Provider) => ({
|
||||
...provider,
|
||||
id: 'google-vertex-anthropic'
|
||||
})
|
||||
}
|
||||
],
|
||||
fallbackRule: (provider: Provider) => provider
|
||||
}
|
||||
|
||||
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)
|
||||
@ -1,6 +1,8 @@
|
||||
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
|
||||
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { Provider } from '@renderer/types'
|
||||
import type { Provider as AiSdkProvider } from 'ai'
|
||||
|
||||
import { initializeNewProviders } from './providerInitialization'
|
||||
|
||||
@ -70,3 +72,28 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
|
||||
export async function createAiSdkProvider(config) {
|
||||
let localProvider: Awaited<AiSdkProvider> | null = null
|
||||
try {
|
||||
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
|
||||
config.providerId = `${config.providerId}-chat`
|
||||
} else if (config.providerId === 'azure' && config.options?.mode === 'responses') {
|
||||
config.providerId = `${config.providerId}-responses`
|
||||
}
|
||||
localProvider = await createProviderCore(config.providerId, config.options)
|
||||
|
||||
logger.debug('Local provider created successfully', {
|
||||
providerId: config.providerId,
|
||||
hasOptions: !!config.options,
|
||||
localProvider: localProvider,
|
||||
options: config.options
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to create local provider', error as Error, {
|
||||
providerId: config.providerId
|
||||
})
|
||||
throw error
|
||||
}
|
||||
return localProvider
|
||||
}
|
||||
|
||||
@ -19,7 +19,7 @@ import type { Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep, isEmpty } from 'lodash'
|
||||
|
||||
import { aihubmixProviderCreator, newApiResolverCreator } from './config'
|
||||
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
|
||||
const logger = loggerService.withContext('ProviderConfigProcessor')
|
||||
@ -67,6 +67,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
||||
if (provider.id === 'newapi') {
|
||||
return newApiResolverCreator(model, provider)
|
||||
}
|
||||
if (provider.id === 'vertexai') {
|
||||
return vertexAnthropicProviderCreator(model, provider)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
@ -157,7 +160,7 @@ export function providerToAiSdkConfig(
|
||||
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
|
||||
}
|
||||
// google-vertex
|
||||
if (aiSdkProviderId === 'google-vertex') {
|
||||
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
|
||||
@ -24,6 +24,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['vertexai']
|
||||
},
|
||||
{
|
||||
id: 'google-vertex-anthropic',
|
||||
name: 'Google Vertex AI Anthropic',
|
||||
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
|
||||
creatorFunctionName: 'createVertexAnthropic',
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['vertexai-anthropic']
|
||||
},
|
||||
{
|
||||
id: 'bedrock',
|
||||
name: 'Amazon Bedrock',
|
||||
|
||||
@ -93,7 +93,6 @@ export async function fetchChatCompletion({
|
||||
modelId: assistant.model?.id,
|
||||
modelName: assistant.model?.name
|
||||
})
|
||||
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
@ -126,7 +125,6 @@ export async function fetchChatCompletion({
|
||||
streamOutput: assistant.settings?.streamOutput ?? true,
|
||||
onChunk: onChunkReceived,
|
||||
model: assistant.model,
|
||||
provider: provider,
|
||||
enableReasoning: capabilities.enableReasoning,
|
||||
isPromptToolUse: isPromptToolUse(assistant),
|
||||
isSupportedToolUse: isSupportedToolUse(assistant),
|
||||
|
||||
@ -1,6 +1,29 @@
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
import type { ImageModel, LanguageModel } from 'ai'
|
||||
import { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai'
|
||||
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model' | 'messages'> &
|
||||
(
|
||||
| {
|
||||
prompt: string | Array<ModelMessage>
|
||||
messages?: never
|
||||
}
|
||||
| {
|
||||
messages: Array<ModelMessage>
|
||||
prompt?: never
|
||||
}
|
||||
)
|
||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model' | 'messages'> &
|
||||
(
|
||||
| {
|
||||
prompt: string | Array<ModelMessage>
|
||||
messages?: never
|
||||
}
|
||||
| {
|
||||
messages: Array<ModelMessage>
|
||||
prompt?: never
|
||||
}
|
||||
)
|
||||
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
|
||||
export type AiSdkModel = LanguageModel | ImageModel
|
||||
|
||||
@ -356,6 +356,7 @@ export type ProviderType =
|
||||
| 'vertexai'
|
||||
| 'mistral'
|
||||
| 'aws-bedrock'
|
||||
| 'vertex-anthropic'
|
||||
|
||||
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank'
|
||||
|
||||
|
||||
10
yarn.lock
10
yarn.lock
@ -9209,7 +9209,7 @@ __metadata:
|
||||
"@viz-js/lang-dot": "npm:^1.0.5"
|
||||
"@viz-js/viz": "npm:^3.14.0"
|
||||
"@xyflow/react": "npm:^12.4.4"
|
||||
ai: "npm:^5.0.26"
|
||||
ai: "npm:^5.0.29"
|
||||
antd: "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch"
|
||||
archiver: "npm:^7.0.1"
|
||||
async-mutex: "npm:^0.5.0"
|
||||
@ -9425,9 +9425,9 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"ai@npm:^5.0.26":
|
||||
version: 5.0.26
|
||||
resolution: "ai@npm:5.0.26"
|
||||
"ai@npm:^5.0.29":
|
||||
version: 5.0.29
|
||||
resolution: "ai@npm:5.0.29"
|
||||
dependencies:
|
||||
"@ai-sdk/gateway": "npm:1.0.15"
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
@ -9435,7 +9435,7 @@ __metadata:
|
||||
"@opentelemetry/api": "npm:1.9.0"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/0423f296b1aa9f22ad106278e8d1e7a2ae9d068358720cdc23c0f222af5406ac1e5ccbce19833709aa1c62b841361d310fb3c42781f426966a9f9ca287ae7faa
|
||||
checksum: 10c0/526cd2fd59b35b19d902665e3dc1ba5a09f2bb1377295d642fb8a33e13a890874e4dd4b49a787de7f31f4ec6b07257be8514efac08f993081daeb430cf2f60ba
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user