mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 07:19:02 +08:00
feat: integrate image generation capabilities and enhance testing framework
- Added support for image generation in the `RuntimeExecutor` with a new `generateImage` method. - Updated `aiCore` package to include `vitest` for testing, with new test scripts added. - Enhanced type definitions to accommodate image model handling in plugins. - Introduced new methods for resolving and executing image generation with plugins. - Updated package dependencies in `package.json` to include `vitest` and ensure compatibility with new features.
This commit is contained in:
parent
7216e9943c
commit
ecc08bd3f7
167
packages/aiCore/examples/image-generation.ts
Normal file
167
packages/aiCore/examples/image-generation.ts
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
/**
|
||||||
|
* Image Generation Example
|
||||||
|
* 演示如何使用 aiCore 的文生图功能
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createExecutor, generateImage } from '../src/index'
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
// 方式1: 使用执行器实例
|
||||||
|
console.log('📸 创建 OpenAI 图像生成执行器...')
|
||||||
|
const executor = createExecutor('openai', {
|
||||||
|
apiKey: process.env.OPENAI_API_KEY!
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log('🎨 使用执行器生成图像...')
|
||||||
|
const result1 = await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A futuristic cityscape at sunset with flying cars',
|
||||||
|
size: '1024x1024',
|
||||||
|
n: 1
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ 图像生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result1.images.length,
|
||||||
|
mediaType: result1.image.mediaType,
|
||||||
|
hasBase64: !!result1.image.base64,
|
||||||
|
providerMetadata: result1.providerMetadata
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ 执行器生成失败:', error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式2: 使用直接调用 API
|
||||||
|
try {
|
||||||
|
console.log('🎨 使用直接 API 生成图像...')
|
||||||
|
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
||||||
|
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
||||||
|
aspectRatio: '16:9',
|
||||||
|
providerOptions: {
|
||||||
|
openai: {
|
||||||
|
quality: 'hd',
|
||||||
|
style: 'vivid'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ 直接 API 生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result2.images.length,
|
||||||
|
mediaType: result2.image.mediaType,
|
||||||
|
hasBase64: !!result2.image.base64
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ 直接 API 生成失败:', error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式3: 支持其他提供商 (Google Imagen)
|
||||||
|
if (process.env.GOOGLE_API_KEY) {
|
||||||
|
try {
|
||||||
|
console.log('🎨 使用 Google Imagen 生成图像...')
|
||||||
|
const googleExecutor = createExecutor('google', {
|
||||||
|
apiKey: process.env.GOOGLE_API_KEY!
|
||||||
|
})
|
||||||
|
|
||||||
|
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||||
|
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
||||||
|
aspectRatio: '1:1'
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ Google Imagen 生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result3.images.length,
|
||||||
|
mediaType: result3.image.mediaType,
|
||||||
|
hasBase64: !!result3.image.base64
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ Google Imagen 生成失败:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式4: 支持插件系统
|
||||||
|
const pluginExample = async () => {
|
||||||
|
console.log('🔌 演示插件系统...')
|
||||||
|
|
||||||
|
// 创建一个示例插件,用于修改提示词
|
||||||
|
const promptEnhancerPlugin = {
|
||||||
|
name: 'prompt-enhancer',
|
||||||
|
transformParams: async (params: any) => {
|
||||||
|
console.log('🔧 插件: 增强提示词...')
|
||||||
|
return {
|
||||||
|
...params,
|
||||||
|
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
||||||
|
}
|
||||||
|
},
|
||||||
|
transformResult: async (result: any) => {
|
||||||
|
console.log('🔧 插件: 处理结果...')
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
enhanced: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = createExecutor(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: process.env.OPENAI_API_KEY!
|
||||||
|
},
|
||||||
|
[promptEnhancerPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A cute robot playing in a garden'
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ 插件系统生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result4.images.length,
|
||||||
|
enhanced: (result4 as any).enhanced,
|
||||||
|
mediaType: result4.image.mediaType
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ 插件系统生成失败:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await pluginExample()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 错误处理演示
|
||||||
|
async function errorHandlingExample() {
|
||||||
|
console.log('⚠️ 演示错误处理...')
|
||||||
|
|
||||||
|
try {
|
||||||
|
const executor = createExecutor('openai', {
|
||||||
|
apiKey: 'invalid-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'Test image'
|
||||||
|
})
|
||||||
|
} catch (error: any) {
|
||||||
|
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||||
|
console.log('📋 错误信息:', error.message)
|
||||||
|
console.log('🏷️ 提供商ID:', error.providerId)
|
||||||
|
console.log('🏷️ 模型ID:', error.modelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 运行示例
|
||||||
|
if (require.main === module) {
|
||||||
|
main()
|
||||||
|
.then(() => {
|
||||||
|
console.log('🎉 所有示例完成!')
|
||||||
|
return errorHandlingExample()
|
||||||
|
})
|
||||||
|
.then(() => {
|
||||||
|
console.log('🎯 示例程序结束')
|
||||||
|
process.exit(0)
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('💥 程序执行出错:', error)
|
||||||
|
process.exit(1)
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -9,7 +9,9 @@
|
|||||||
"scripts": {
|
"scripts": {
|
||||||
"build": "tsdown",
|
"build": "tsdown",
|
||||||
"dev": "tsc -w",
|
"dev": "tsc -w",
|
||||||
"clean": "rm -rf dist"
|
"clean": "rm -rf dist",
|
||||||
|
"test": "vitest run",
|
||||||
|
"test:watch": "vitest"
|
||||||
},
|
},
|
||||||
"keywords": [
|
"keywords": [
|
||||||
"ai",
|
"ai",
|
||||||
@ -45,7 +47,8 @@
|
|||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"tsdown": "^0.12.9",
|
"tsdown": "^0.12.9",
|
||||||
"typescript": "^5.0.0"
|
"typescript": "^5.0.0",
|
||||||
|
"vitest": "^1.0.0"
|
||||||
},
|
},
|
||||||
"sideEffects": false,
|
"sideEffects": false,
|
||||||
"engines": {
|
"engines": {
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||||
|
|
||||||
import { type ProviderId } from '../providers/types'
|
import { type ProviderId } from '../providers/types'
|
||||||
@ -32,7 +33,10 @@ export interface AiPlugin {
|
|||||||
enforce?: 'pre' | 'post'
|
enforce?: 'pre' | 'post'
|
||||||
|
|
||||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||||
resolveModel?: (modelId: string, context: AiRequestContext) => Promise<LanguageModel | null> | LanguageModel | null
|
resolveModel?: (
|
||||||
|
modelId: string,
|
||||||
|
context: AiRequestContext
|
||||||
|
) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
|
||||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||||
|
|
||||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||||
|
|||||||
533
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
533
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
@ -0,0 +1,533 @@
|
|||||||
|
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||||
|
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { createImageModel } from '../../models/ModelCreator'
|
||||||
|
import { type AiPlugin } from '../../plugins'
|
||||||
|
import { ImageGenerationError } from '../errors'
|
||||||
|
import { RuntimeExecutor } from '../executor'
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('ai', () => ({
|
||||||
|
experimental_generateImage: vi.fn(),
|
||||||
|
NoImageGeneratedError: class NoImageGeneratedError extends Error {
|
||||||
|
static isInstance = vi.fn()
|
||||||
|
constructor() {
|
||||||
|
super('No image generated')
|
||||||
|
this.name = 'NoImageGeneratedError'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../models/ModelCreator', () => ({
|
||||||
|
createImageModel: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('RuntimeExecutor.generateImage', () => {
|
||||||
|
let executor: RuntimeExecutor<'openai'>
|
||||||
|
let mockImageModel: ImageModelV2
|
||||||
|
let mockGenerateImageResult: any
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Reset all mocks
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
// Create executor instance
|
||||||
|
executor = RuntimeExecutor.create('openai', {
|
||||||
|
apiKey: 'test-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mock image model
|
||||||
|
mockImageModel = {
|
||||||
|
modelId: 'dall-e-3',
|
||||||
|
provider: 'openai'
|
||||||
|
} as ImageModelV2
|
||||||
|
|
||||||
|
// Mock generateImage result
|
||||||
|
mockGenerateImageResult = {
|
||||||
|
image: {
|
||||||
|
base64: 'base64-encoded-image-data',
|
||||||
|
uint8Array: new Uint8Array([1, 2, 3]),
|
||||||
|
mediaType: 'image/png'
|
||||||
|
},
|
||||||
|
images: [
|
||||||
|
{
|
||||||
|
base64: 'base64-encoded-image-data',
|
||||||
|
uint8Array: new Uint8Array([1, 2, 3]),
|
||||||
|
mediaType: 'image/png'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
warnings: [],
|
||||||
|
providerMetadata: {
|
||||||
|
openai: {
|
||||||
|
images: [{ revisedPrompt: 'A detailed prompt' }]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
responses: []
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup mocks
|
||||||
|
vi.mocked(createImageModel).mockResolvedValue(mockImageModel)
|
||||||
|
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Basic functionality', () => {
|
||||||
|
it('should generate a single image with minimal parameters', async () => {
|
||||||
|
const result = await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A futuristic cityscape at sunset'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(createImageModel).toHaveBeenCalledWith('openai', 'dall-e-3', { apiKey: 'test-key' })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A futuristic cityscape at sunset'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toEqual(mockGenerateImageResult)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate image with pre-created model', async () => {
|
||||||
|
const result = await executor.generateImage(mockImageModel, {
|
||||||
|
prompt: 'A beautiful landscape'
|
||||||
|
})
|
||||||
|
|
||||||
|
// Note: createImageModel may still be called due to resolveImageModel logic
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A beautiful landscape'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toEqual(mockGenerateImageResult)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support multiple images generation', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A futuristic cityscape',
|
||||||
|
n: 3
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A futuristic cityscape',
|
||||||
|
n: 3
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support size specification', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A beautiful sunset',
|
||||||
|
size: '1024x1024'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A beautiful sunset',
|
||||||
|
size: '1024x1024'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support aspect ratio specification', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A mountain landscape',
|
||||||
|
aspectRatio: '16:9'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A mountain landscape',
|
||||||
|
aspectRatio: '16:9'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support seed for consistent output', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A cat in space',
|
||||||
|
seed: 1234567890
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A cat in space',
|
||||||
|
seed: 1234567890
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support abort signal', async () => {
|
||||||
|
const abortController = new AbortController()
|
||||||
|
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A cityscape',
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A cityscape',
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support provider-specific options', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A space station',
|
||||||
|
providerOptions: {
|
||||||
|
openai: {
|
||||||
|
quality: 'hd',
|
||||||
|
style: 'vivid'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A space station',
|
||||||
|
providerOptions: {
|
||||||
|
openai: {
|
||||||
|
quality: 'hd',
|
||||||
|
style: 'vivid'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support custom headers', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A robot',
|
||||||
|
headers: {
|
||||||
|
'X-Custom-Header': 'test-value'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A robot',
|
||||||
|
headers: {
|
||||||
|
'X-Custom-Header': 'test-value'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Plugin integration', () => {
|
||||||
|
it('should execute plugins in correct order', async () => {
|
||||||
|
const pluginCallOrder: string[] = []
|
||||||
|
|
||||||
|
const testPlugin: AiPlugin = {
|
||||||
|
name: 'test-plugin',
|
||||||
|
onRequestStart: vi.fn(async () => {
|
||||||
|
pluginCallOrder.push('onRequestStart')
|
||||||
|
}),
|
||||||
|
transformParams: vi.fn(async (params) => {
|
||||||
|
pluginCallOrder.push('transformParams')
|
||||||
|
return { ...params, size: '512x512' }
|
||||||
|
}),
|
||||||
|
transformResult: vi.fn(async (result) => {
|
||||||
|
pluginCallOrder.push('transformResult')
|
||||||
|
return { ...result, processed: true }
|
||||||
|
}),
|
||||||
|
onRequestEnd: vi.fn(async () => {
|
||||||
|
pluginCallOrder.push('onRequestEnd')
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[testPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
const result = await executorWithPlugin.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||||
|
|
||||||
|
expect(testPlugin.transformParams).toHaveBeenCalledWith(
|
||||||
|
{ prompt: 'A test image' },
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'dall-e-3'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A test image',
|
||||||
|
size: '512x512' // Should be transformed by plugin
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
...mockGenerateImageResult,
|
||||||
|
processed: true // Should be transformed by plugin
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle model resolution through plugins', async () => {
|
||||||
|
const customImageModel = {
|
||||||
|
modelId: 'custom-model',
|
||||||
|
provider: 'openai'
|
||||||
|
} as ImageModelV2
|
||||||
|
|
||||||
|
const modelResolutionPlugin: AiPlugin = {
|
||||||
|
name: 'model-resolver',
|
||||||
|
resolveModel: vi.fn(async () => customImageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[modelResolutionPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
await executorWithPlugin.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
||||||
|
'dall-e-3',
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'dall-e-3'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: customImageModel,
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support recursive calls from plugins', async () => {
|
||||||
|
const recursivePlugin: AiPlugin = {
|
||||||
|
name: 'recursive-plugin',
|
||||||
|
transformParams: vi.fn(async (params, context) => {
|
||||||
|
if (!context.isRecursiveCall && params.prompt === 'original') {
|
||||||
|
// Make a recursive call with modified prompt
|
||||||
|
await context.recursiveCall({
|
||||||
|
prompt: 'modified'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[recursivePlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
await executorWithPlugin.generateImage('dall-e-3', {
|
||||||
|
prompt: 'original'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error handling', () => {
|
||||||
|
it('should handle model creation errors', async () => {
|
||||||
|
const modelError = new Error('Failed to create image model')
|
||||||
|
vi.mocked(createImageModel).mockRejectedValue(modelError)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateImage('invalid-model', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
).rejects.toThrow(ImageGenerationError)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle image generation API errors', async () => {
|
||||||
|
const apiError = new Error('API request failed')
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Failed to generate image: API request failed')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle NoImageGeneratedError', async () => {
|
||||||
|
const noImageError = new NoImageGeneratedError({
|
||||||
|
cause: new Error('No image generated'),
|
||||||
|
responses: []
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
||||||
|
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Failed to generate image: No image generated')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should execute onError plugin hook on failure', async () => {
|
||||||
|
const error = new Error('Generation failed')
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(error)
|
||||||
|
|
||||||
|
const errorPlugin: AiPlugin = {
|
||||||
|
name: 'error-handler',
|
||||||
|
onError: vi.fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[errorPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executorWithPlugin.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Failed to generate image: Generation failed')
|
||||||
|
|
||||||
|
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||||
|
error,
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'dall-e-3'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle abort signal timeout', async () => {
|
||||||
|
const abortError = new Error('Operation was aborted')
|
||||||
|
abortError.name = 'AbortError'
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
|
||||||
|
|
||||||
|
const abortController = new AbortController()
|
||||||
|
setTimeout(() => abortController.abort(), 10)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image',
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Operation was aborted')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Multiple providers support', () => {
|
||||||
|
it('should work with different providers', async () => {
|
||||||
|
const googleExecutor = RuntimeExecutor.create('google', {
|
||||||
|
apiKey: 'google-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||||
|
prompt: 'A landscape'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(createImageModel).toHaveBeenCalledWith('google', 'imagen-3.0-generate-002', { apiKey: 'google-key' })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support xAI Grok image models', async () => {
|
||||||
|
const xaiExecutor = RuntimeExecutor.create('xai', {
|
||||||
|
apiKey: 'xai-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
await xaiExecutor.generateImage('grok-2-image', {
|
||||||
|
prompt: 'A futuristic robot'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(createImageModel).toHaveBeenCalledWith('xai', 'grok-2-image', { apiKey: 'xai-key' })
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Advanced features', () => {
|
||||||
|
it('should support batch image generation with maxImagesPerCall', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image',
|
||||||
|
n: 10,
|
||||||
|
maxImagesPerCall: 5
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A test image',
|
||||||
|
n: 10,
|
||||||
|
maxImagesPerCall: 5
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support retries with maxRetries', async () => {
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image',
|
||||||
|
maxRetries: 3
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A test image',
|
||||||
|
maxRetries: 3
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle warnings from the model', async () => {
|
||||||
|
const resultWithWarnings = {
|
||||||
|
...mockGenerateImageResult,
|
||||||
|
warnings: [
|
||||||
|
{
|
||||||
|
type: 'unsupported-setting',
|
||||||
|
message: 'Size parameter not supported for this model'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
|
||||||
|
|
||||||
|
const result = await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image',
|
||||||
|
size: '2048x2048' // Unsupported size
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.warnings).toHaveLength(1)
|
||||||
|
expect(result.warnings[0].type).toBe('unsupported-setting')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should provide access to provider metadata', async () => {
|
||||||
|
const result = await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerMetadata).toBeDefined()
|
||||||
|
expect(result.providerMetadata.openai).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should provide response metadata', async () => {
|
||||||
|
const resultWithMetadata = {
|
||||||
|
...mockGenerateImageResult,
|
||||||
|
responses: [
|
||||||
|
{
|
||||||
|
timestamp: new Date(),
|
||||||
|
modelId: 'dall-e-3',
|
||||||
|
headers: { 'x-request-id': 'test-123' }
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
|
||||||
|
|
||||||
|
const result = await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.responses).toHaveLength(1)
|
||||||
|
expect(result.responses[0].modelId).toBe('dall-e-3')
|
||||||
|
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/**
|
||||||
|
* Error classes for runtime operations
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Error thrown when image generation fails
|
||||||
|
*/
|
||||||
|
export class ImageGenerationError extends Error {
|
||||||
|
constructor(
|
||||||
|
message: string,
|
||||||
|
public providerId?: string,
|
||||||
|
public modelId?: string,
|
||||||
|
public cause?: Error
|
||||||
|
) {
|
||||||
|
super(message)
|
||||||
|
this.name = 'ImageGenerationError'
|
||||||
|
|
||||||
|
// Maintain proper stack trace (for V8 engines)
|
||||||
|
if (Error.captureStackTrace) {
|
||||||
|
Error.captureStackTrace(this, ImageGenerationError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Error thrown when model resolution fails during image generation
|
||||||
|
*/
|
||||||
|
export class ImageModelResolutionError extends ImageGenerationError {
|
||||||
|
constructor(modelId: string, providerId?: string, cause?: Error) {
|
||||||
|
super(
|
||||||
|
`Failed to resolve image model: ${modelId}${providerId ? ` for provider: ${providerId}` : ''}`,
|
||||||
|
providerId,
|
||||||
|
modelId,
|
||||||
|
cause
|
||||||
|
)
|
||||||
|
this.name = 'ImageModelResolutionError'
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,13 +2,21 @@
|
|||||||
* 运行时执行器
|
* 运行时执行器
|
||||||
* 专注于插件化的AI调用处理
|
* 专注于插件化的AI调用处理
|
||||||
*/
|
*/
|
||||||
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
import { generateObject, generateText, LanguageModel, streamObject, streamText } from 'ai'
|
import {
|
||||||
|
experimental_generateImage as generateImage,
|
||||||
|
generateObject,
|
||||||
|
generateText,
|
||||||
|
LanguageModel,
|
||||||
|
streamObject,
|
||||||
|
streamText
|
||||||
|
} from 'ai'
|
||||||
|
|
||||||
import { type ProviderId } from '../../types'
|
import { type ProviderId } from '../../types'
|
||||||
import { createModel, getProviderInfo } from '../models'
|
import { createImageModel, createModel, getProviderInfo } from '../models'
|
||||||
import { type ModelConfig } from '../models/types'
|
import { type ModelConfig } from '../models/types'
|
||||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||||
|
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||||
import { PluginEngine } from './pluginEngine'
|
import { PluginEngine } from './pluginEngine'
|
||||||
import { type RuntimeConfig } from './types'
|
import { type RuntimeConfig } from './types'
|
||||||
|
|
||||||
@ -42,6 +50,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
createResolveImageModelPlugin() {
|
||||||
|
return definePlugin({
|
||||||
|
name: '_internal_resolveImageModel',
|
||||||
|
enforce: 'post',
|
||||||
|
|
||||||
|
resolveModel: async (modelId: string) => {
|
||||||
|
return await this.resolveImageModel(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
createConfigureContextPlugin() {
|
createConfigureContextPlugin() {
|
||||||
return definePlugin({
|
return definePlugin({
|
||||||
name: '_internal_configureContext',
|
name: '_internal_configureContext',
|
||||||
@ -203,6 +222,48 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成图像 - 使用已创建的图像模型
|
||||||
|
*/
|
||||||
|
async generateImage(
|
||||||
|
model: ImageModelV2,
|
||||||
|
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||||
|
): Promise<ReturnType<typeof generateImage>>
|
||||||
|
async generateImage(
|
||||||
|
modelId: string,
|
||||||
|
params: Omit<Parameters<typeof generateImage>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof generateImage>>
|
||||||
|
async generateImage(
|
||||||
|
modelOrId: ImageModelV2 | string,
|
||||||
|
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||||
|
): Promise<ReturnType<typeof generateImage>> {
|
||||||
|
try {
|
||||||
|
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
|
||||||
|
|
||||||
|
return await this.pluginEngine.executeImageWithPlugins(
|
||||||
|
'generateImage',
|
||||||
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
|
params,
|
||||||
|
async (model, transformedParams) => {
|
||||||
|
return await generateImage({ model, ...transformedParams })
|
||||||
|
}
|
||||||
|
)
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
throw new ImageGenerationError(
|
||||||
|
`Failed to generate image: ${error.message}`,
|
||||||
|
this.config.providerId,
|
||||||
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
|
error
|
||||||
|
)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// === 辅助方法 ===
|
// === 辅助方法 ===
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -228,6 +289,27 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析图像模型:如果是字符串则创建图像模型,如果是模型则直接返回
|
||||||
|
*/
|
||||||
|
private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
|
||||||
|
try {
|
||||||
|
if (typeof modelOrId === 'string') {
|
||||||
|
// 字符串modelId,需要创建图像模型
|
||||||
|
return await createImageModel(this.config.providerId, modelOrId, this.config.providerSettings)
|
||||||
|
} else {
|
||||||
|
// 已经是模型,直接返回
|
||||||
|
return modelOrId
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
throw new ImageModelResolutionError(
|
||||||
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
|
this.config.providerId,
|
||||||
|
error instanceof Error ? error : undefined
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取客户端信息
|
* 获取客户端信息
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -100,6 +100,21 @@ export async function streamObject<T extends ProviderId>(
|
|||||||
return executor.streamObject(modelId, params, { middlewares })
|
return executor.streamObject(modelId, params, { middlewares })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接生成图像 - 支持middlewares
|
||||||
|
*/
|
||||||
|
export async function generateImage<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||||
|
modelId: string,
|
||||||
|
params: Parameters<RuntimeExecutor<T>['generateImage']>[1],
|
||||||
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||||
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
|
return executor.generateImage(modelId, params, { middlewares })
|
||||||
|
}
|
||||||
|
|
||||||
// === Agent 功能预留 ===
|
// === Agent 功能预留 ===
|
||||||
// 未来将在 ../agents/ 文件夹中添加:
|
// 未来将在 ../agents/ 文件夹中添加:
|
||||||
// - AgentExecutor.ts
|
// - AgentExecutor.ts
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||||
import { LanguageModel } from 'ai'
|
import { LanguageModel } from 'ai'
|
||||||
|
|
||||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||||
@ -113,6 +114,62 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行带插件的图像生成操作
|
||||||
|
* 提供给AiExecutor使用
|
||||||
|
*/
|
||||||
|
async executeImageWithPlugins<TParams, TResult>(
|
||||||
|
methodName: string,
|
||||||
|
modelId: string,
|
||||||
|
params: TParams,
|
||||||
|
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||||
|
_context?: ReturnType<typeof createContext>
|
||||||
|
): Promise<TResult> {
|
||||||
|
// 使用正确的createContext创建请求上下文
|
||||||
|
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||||
|
|
||||||
|
// 🔥 为上下文添加递归调用能力
|
||||||
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
|
// 递归调用自身,重新走完整的插件流程
|
||||||
|
context.isRecursiveCall = true
|
||||||
|
const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context)
|
||||||
|
context.isRecursiveCall = false
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 0. 配置上下文
|
||||||
|
await this.pluginManager.executeConfigureContext(context)
|
||||||
|
|
||||||
|
// 1. 触发请求开始事件
|
||||||
|
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||||
|
|
||||||
|
// 2. 解析模型
|
||||||
|
const model = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||||
|
if (!model) {
|
||||||
|
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 转换请求参数
|
||||||
|
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||||
|
|
||||||
|
// 4. 执行具体的 API 调用
|
||||||
|
const result = await executor(model, transformedParams)
|
||||||
|
|
||||||
|
// 5. 转换结果
|
||||||
|
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||||
|
|
||||||
|
// 6. 触发完成事件
|
||||||
|
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||||
|
|
||||||
|
return transformedResult
|
||||||
|
} catch (error) {
|
||||||
|
// 7. 触发错误事件
|
||||||
|
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 执行流式调用的通用逻辑(支持流转换器)
|
* 执行流式调用的通用逻辑(支持流转换器)
|
||||||
* 提供给AiExecutor使用
|
* 提供给AiExecutor使用
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import { createExecutor } from './core/runtime'
|
|||||||
export {
|
export {
|
||||||
createExecutor,
|
createExecutor,
|
||||||
createOpenAICompatibleExecutor,
|
createOpenAICompatibleExecutor,
|
||||||
|
generateImage,
|
||||||
generateObject,
|
generateObject,
|
||||||
generateText,
|
generateText,
|
||||||
streamText
|
streamText
|
||||||
|
|||||||
13
packages/aiCore/vitest.config.ts
Normal file
13
packages/aiCore/vitest.config.ts
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import { defineConfig } from 'vitest/config'
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
test: {
|
||||||
|
environment: 'node',
|
||||||
|
globals: true
|
||||||
|
},
|
||||||
|
resolve: {
|
||||||
|
alias: {
|
||||||
|
'@': './src'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
@ -12,6 +12,7 @@ import {
|
|||||||
AiCore,
|
AiCore,
|
||||||
AiPlugin,
|
AiPlugin,
|
||||||
createExecutor,
|
createExecutor,
|
||||||
|
generateImage,
|
||||||
ProviderConfigFactory,
|
ProviderConfigFactory,
|
||||||
type ProviderId,
|
type ProviderId,
|
||||||
type ProviderSettingsMap,
|
type ProviderSettingsMap,
|
||||||
@ -22,6 +23,7 @@ import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
|||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
import { formatApiHost } from '@renderer/utils/api'
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
|
|
||||||
@ -127,9 +129,11 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否为图像生成模型(暂时不支持)
|
// 图像生成模型现在支持新的 AI SDK
|
||||||
|
// (但需要确保 provider 是支持的
|
||||||
|
|
||||||
if (model && isDedicatedImageGenerationModel(model)) {
|
if (model && isDedicatedImageGenerationModel(model)) {
|
||||||
return false
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@ -211,6 +215,12 @@ export default class ModernAiProvider {
|
|||||||
middlewareConfig: AiSdkMiddlewareConfig
|
middlewareConfig: AiSdkMiddlewareConfig
|
||||||
): Promise<CompletionsResult> {
|
): Promise<CompletionsResult> {
|
||||||
console.log('completions', modelId, params, middlewareConfig)
|
console.log('completions', modelId, params, middlewareConfig)
|
||||||
|
|
||||||
|
// 检查是否为图像生成模型
|
||||||
|
if (middlewareConfig.model && isDedicatedImageGenerationModel(middlewareConfig.model)) {
|
||||||
|
return await this.modernImageGeneration(modelId, params, middlewareConfig)
|
||||||
|
}
|
||||||
|
|
||||||
return await this.modernCompletions(modelId, params, middlewareConfig)
|
return await this.modernCompletions(modelId, params, middlewareConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,6 +281,109 @@ export default class ModernAiProvider {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||||
|
*/
|
||||||
|
private async modernImageGeneration(
|
||||||
|
modelId: string,
|
||||||
|
params: StreamTextParams,
|
||||||
|
middlewareConfig: AiSdkMiddlewareConfig
|
||||||
|
): Promise<CompletionsResult> {
|
||||||
|
const { onChunk } = middlewareConfig
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 检查 messages 是否存在
|
||||||
|
if (!params.messages || params.messages.length === 0) {
|
||||||
|
throw new Error('No messages provided for image generation.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从最后一条用户消息中提取 prompt
|
||||||
|
const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
|
||||||
|
if (!lastUserMessage) {
|
||||||
|
throw new Error('No user message found for image generation.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用消息内容,避免类型转换问题
|
||||||
|
const prompt =
|
||||||
|
typeof lastUserMessage.content === 'string'
|
||||||
|
? lastUserMessage.content
|
||||||
|
: lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
|
||||||
|
|
||||||
|
if (!prompt) {
|
||||||
|
throw new Error('No prompt found in user message.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送图像生成开始事件
|
||||||
|
if (onChunk) {
|
||||||
|
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||||
|
}
|
||||||
|
|
||||||
|
const startTime = Date.now()
|
||||||
|
|
||||||
|
// 构建图像生成参数
|
||||||
|
const imageParams = {
|
||||||
|
prompt,
|
||||||
|
size: '1024x1024' as `${number}x${number}`, // 默认尺寸,使用正确的类型
|
||||||
|
n: 1,
|
||||||
|
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用新 AI SDK 的图像生成功能
|
||||||
|
const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams)
|
||||||
|
|
||||||
|
// 转换结果格式
|
||||||
|
const images: string[] = []
|
||||||
|
const imageType: 'url' | 'base64' = 'base64'
|
||||||
|
|
||||||
|
if (result.images) {
|
||||||
|
for (const image of result.images) {
|
||||||
|
if ('base64' in image && image.base64) {
|
||||||
|
images.push(`data:image/png;base64,${image.base64}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送图像生成完成事件
|
||||||
|
if (onChunk && images.length > 0) {
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.IMAGE_COMPLETE,
|
||||||
|
image: { type: imageType, images }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送响应完成事件
|
||||||
|
if (onChunk) {
|
||||||
|
const usage = {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0
|
||||||
|
}
|
||||||
|
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage,
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: usage.completion_tokens,
|
||||||
|
time_first_token_millsec: 0,
|
||||||
|
time_completion_millsec: Date.now() - startTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
getText: () => '' // 图像生成不返回文本
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// 发送错误事件
|
||||||
|
if (onChunk) {
|
||||||
|
onChunk({ type: ChunkType.ERROR, error: error as any })
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 代理其他方法到原有实现
|
// 代理其他方法到原有实现
|
||||||
public async models() {
|
public async models() {
|
||||||
return this.legacyProvider.models()
|
return this.legacyProvider.models()
|
||||||
@ -281,9 +394,51 @@ export default class ModernAiProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||||
|
// 如果支持新的 AI SDK,使用现代化实现
|
||||||
|
if (isModernSdkSupported(this.actualProvider)) {
|
||||||
|
try {
|
||||||
|
const result = await this.modernGenerateImage(params)
|
||||||
|
return result
|
||||||
|
} catch (error) {
|
||||||
|
console.warn('Modern AI SDK generateImage failed, falling back to legacy:', error)
|
||||||
|
// fallback 到传统实现
|
||||||
|
return this.legacyProvider.generateImage(params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用传统实现
|
||||||
return this.legacyProvider.generateImage(params)
|
return this.legacyProvider.generateImage(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用现代化 AI SDK 的图像生成实现
|
||||||
|
*/
|
||||||
|
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||||
|
const { model, prompt, imageSize, batchSize, signal } = params
|
||||||
|
|
||||||
|
// 转换参数格式
|
||||||
|
const aiSdkParams = {
|
||||||
|
prompt,
|
||||||
|
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||||
|
n: batchSize || 1,
|
||||||
|
...(signal && { abortSignal: signal })
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams)
|
||||||
|
|
||||||
|
// 转换结果格式
|
||||||
|
const images: string[] = []
|
||||||
|
if (result.images) {
|
||||||
|
for (const image of result.images) {
|
||||||
|
if ('base64' in image && image.base64) {
|
||||||
|
images.push(`data:image/png;base64,${image.base64}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return images
|
||||||
|
}
|
||||||
|
|
||||||
public getBaseURL(): string {
|
public getBaseURL(): string {
|
||||||
return this.legacyProvider.getBaseURL()
|
return this.legacyProvider.getBaseURL()
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user