mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: integrate image generation capabilities and enhance testing framework
- Added support for image generation in the `RuntimeExecutor` with a new `generateImage` method. - Updated `aiCore` package to include `vitest` for testing, with new test scripts added. - Enhanced type definitions to accommodate image model handling in plugins. - Introduced new methods for resolving and executing image generation with plugins. - Updated package dependencies in `package.json` to include `vitest` and ensure compatibility with new features.
This commit is contained in:
parent
7216e9943c
commit
ecc08bd3f7
167
packages/aiCore/examples/image-generation.ts
Normal file
167
packages/aiCore/examples/image-generation.ts
Normal file
@ -0,0 +1,167 @@
|
||||
/**
|
||||
* Image Generation Example
|
||||
* 演示如何使用 aiCore 的文生图功能
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '../src/index'
|
||||
|
||||
async function main() {
|
||||
// 方式1: 使用执行器实例
|
||||
console.log('📸 创建 OpenAI 图像生成执行器...')
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
})
|
||||
|
||||
try {
|
||||
console.log('🎨 使用执行器生成图像...')
|
||||
const result1 = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset with flying cars',
|
||||
size: '1024x1024',
|
||||
n: 1
|
||||
})
|
||||
|
||||
console.log('✅ 图像生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result1.images.length,
|
||||
mediaType: result1.image.mediaType,
|
||||
hasBase64: !!result1.image.base64,
|
||||
providerMetadata: result1.providerMetadata
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 执行器生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式2: 使用直接调用 API
|
||||
try {
|
||||
console.log('🎨 使用直接 API 生成图像...')
|
||||
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
||||
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
||||
aspectRatio: '16:9',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
console.log('✅ 直接 API 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result2.images.length,
|
||||
mediaType: result2.image.mediaType,
|
||||
hasBase64: !!result2.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 直接 API 生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式3: 支持其他提供商 (Google Imagen)
|
||||
if (process.env.GOOGLE_API_KEY) {
|
||||
try {
|
||||
console.log('🎨 使用 Google Imagen 生成图像...')
|
||||
const googleExecutor = createExecutor('google', {
|
||||
apiKey: process.env.GOOGLE_API_KEY!
|
||||
})
|
||||
|
||||
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
||||
aspectRatio: '1:1'
|
||||
})
|
||||
|
||||
console.log('✅ Google Imagen 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result3.images.length,
|
||||
mediaType: result3.image.mediaType,
|
||||
hasBase64: !!result3.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ Google Imagen 生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 方式4: 支持插件系统
|
||||
const pluginExample = async () => {
|
||||
console.log('🔌 演示插件系统...')
|
||||
|
||||
// 创建一个示例插件,用于修改提示词
|
||||
const promptEnhancerPlugin = {
|
||||
name: 'prompt-enhancer',
|
||||
transformParams: async (params: any) => {
|
||||
console.log('🔧 插件: 增强提示词...')
|
||||
return {
|
||||
...params,
|
||||
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
||||
}
|
||||
},
|
||||
transformResult: async (result: any) => {
|
||||
console.log('🔧 插件: 处理结果...')
|
||||
return {
|
||||
...result,
|
||||
enhanced: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const executorWithPlugin = createExecutor(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
},
|
||||
[promptEnhancerPlugin]
|
||||
)
|
||||
|
||||
try {
|
||||
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A cute robot playing in a garden'
|
||||
})
|
||||
|
||||
console.log('✅ 插件系统生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result4.images.length,
|
||||
enhanced: (result4 as any).enhanced,
|
||||
mediaType: result4.image.mediaType
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 插件系统生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
await pluginExample()
|
||||
}
|
||||
|
||||
// 错误处理演示
|
||||
async function errorHandlingExample() {
|
||||
console.log('⚠️ 演示错误处理...')
|
||||
|
||||
try {
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: 'invalid-key'
|
||||
})
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'Test image'
|
||||
})
|
||||
} catch (error: any) {
|
||||
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||
console.log('📋 错误信息:', error.message)
|
||||
console.log('🏷️ 提供商ID:', error.providerId)
|
||||
console.log('🏷️ 模型ID:', error.modelId)
|
||||
}
|
||||
}
|
||||
|
||||
// 运行示例
|
||||
if (require.main === module) {
|
||||
main()
|
||||
.then(() => {
|
||||
console.log('🎉 所有示例完成!')
|
||||
return errorHandlingExample()
|
||||
})
|
||||
.then(() => {
|
||||
console.log('🎯 示例程序结束')
|
||||
process.exit(0)
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('💥 程序执行出错:', error)
|
||||
process.exit(1)
|
||||
})
|
||||
}
|
||||
@ -9,7 +9,9 @@
|
||||
"scripts": {
|
||||
"build": "tsdown",
|
||||
"dev": "tsc -w",
|
||||
"clean": "rm -rf dist"
|
||||
"clean": "rm -rf dist",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"keywords": [
|
||||
"ai",
|
||||
@ -45,7 +47,8 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"tsdown": "^0.12.9",
|
||||
"typescript": "^5.0.0"
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^1.0.0"
|
||||
},
|
||||
"sideEffects": false,
|
||||
"engines": {
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { type ProviderId } from '../providers/types'
|
||||
@ -32,7 +33,10 @@ export interface AiPlugin {
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => Promise<LanguageModel | null> | LanguageModel | null
|
||||
resolveModel?: (
|
||||
modelId: string,
|
||||
context: AiRequestContext
|
||||
) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
|
||||
533
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
533
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
@ -0,0 +1,533 @@
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createImageModel } from '../../models/ModelCreator'
|
||||
import { type AiPlugin } from '../../plugins'
|
||||
import { ImageGenerationError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('ai', () => ({
|
||||
experimental_generateImage: vi.fn(),
|
||||
NoImageGeneratedError: class NoImageGeneratedError extends Error {
|
||||
static isInstance = vi.fn()
|
||||
constructor() {
|
||||
super('No image generated')
|
||||
this.name = 'NoImageGeneratedError'
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../models/ModelCreator', () => ({
|
||||
createImageModel: vi.fn()
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.generateImage', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let mockImageModel: ImageModelV2
|
||||
let mockGenerateImageResult: any
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create executor instance
|
||||
executor = RuntimeExecutor.create('openai', {
|
||||
apiKey: 'test-key'
|
||||
})
|
||||
|
||||
// Mock image model
|
||||
mockImageModel = {
|
||||
modelId: 'dall-e-3',
|
||||
provider: 'openai'
|
||||
} as ImageModelV2
|
||||
|
||||
// Mock generateImage result
|
||||
mockGenerateImageResult = {
|
||||
image: {
|
||||
base64: 'base64-encoded-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3]),
|
||||
mediaType: 'image/png'
|
||||
},
|
||||
images: [
|
||||
{
|
||||
base64: 'base64-encoded-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3]),
|
||||
mediaType: 'image/png'
|
||||
}
|
||||
],
|
||||
warnings: [],
|
||||
providerMetadata: {
|
||||
openai: {
|
||||
images: [{ revisedPrompt: 'A detailed prompt' }]
|
||||
}
|
||||
},
|
||||
responses: []
|
||||
}
|
||||
|
||||
// Setup mocks
|
||||
vi.mocked(createImageModel).mockResolvedValue(mockImageModel)
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
describe('Basic functionality', () => {
|
||||
it('should generate a single image with minimal parameters', async () => {
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset'
|
||||
})
|
||||
|
||||
expect(createImageModel).toHaveBeenCalledWith('openai', 'dall-e-3', { apiKey: 'test-key' })
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A futuristic cityscape at sunset'
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
it('should generate image with pre-created model', async () => {
|
||||
const result = await executor.generateImage(mockImageModel, {
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
// Note: createImageModel may still be called due to resolveImageModel logic
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
it('should support multiple images generation', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape',
|
||||
n: 3
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A futuristic cityscape',
|
||||
n: 3
|
||||
})
|
||||
})
|
||||
|
||||
it('should support size specification', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A beautiful sunset',
|
||||
size: '1024x1024'
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful sunset',
|
||||
size: '1024x1024'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support aspect ratio specification', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A mountain landscape',
|
||||
aspectRatio: '16:9'
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A mountain landscape',
|
||||
aspectRatio: '16:9'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support seed for consistent output', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A cat in space',
|
||||
seed: 1234567890
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A cat in space',
|
||||
seed: 1234567890
|
||||
})
|
||||
})
|
||||
|
||||
it('should support abort signal', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A cityscape',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A cityscape',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
})
|
||||
|
||||
it('should support provider-specific options', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A space station',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A space station',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should support custom headers', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A robot',
|
||||
headers: {
|
||||
'X-Custom-Header': 'test-value'
|
||||
}
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A robot',
|
||||
headers: {
|
||||
'X-Custom-Header': 'test-value'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin integration', () => {
|
||||
it('should execute plugins in correct order', async () => {
|
||||
const pluginCallOrder: string[] = []
|
||||
|
||||
const testPlugin: AiPlugin = {
|
||||
name: 'test-plugin',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
pluginCallOrder.push('onRequestStart')
|
||||
}),
|
||||
transformParams: vi.fn(async (params) => {
|
||||
pluginCallOrder.push('transformParams')
|
||||
return { ...params, size: '512x512' }
|
||||
}),
|
||||
transformResult: vi.fn(async (result) => {
|
||||
pluginCallOrder.push('transformResult')
|
||||
return { ...result, processed: true }
|
||||
}),
|
||||
onRequestEnd: vi.fn(async () => {
|
||||
pluginCallOrder.push('onRequestEnd')
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[testPlugin]
|
||||
)
|
||||
|
||||
const result = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||
|
||||
expect(testPlugin.transformParams).toHaveBeenCalledWith(
|
||||
{ prompt: 'A test image' },
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
size: '512x512' // Should be transformed by plugin
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
...mockGenerateImageResult,
|
||||
processed: true // Should be transformed by plugin
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle model resolution through plugins', async () => {
|
||||
const customImageModel = {
|
||||
modelId: 'custom-model',
|
||||
provider: 'openai'
|
||||
} as ImageModelV2
|
||||
|
||||
const modelResolutionPlugin: AiPlugin = {
|
||||
name: 'model-resolver',
|
||||
resolveModel: vi.fn(async () => customImageModel)
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[modelResolutionPlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
||||
'dall-e-3',
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: customImageModel,
|
||||
prompt: 'A test image'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support recursive calls from plugins', async () => {
|
||||
const recursivePlugin: AiPlugin = {
|
||||
name: 'recursive-plugin',
|
||||
transformParams: vi.fn(async (params, context) => {
|
||||
if (!context.isRecursiveCall && params.prompt === 'original') {
|
||||
// Make a recursive call with modified prompt
|
||||
await context.recursiveCall({
|
||||
prompt: 'modified'
|
||||
})
|
||||
}
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[recursivePlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'original'
|
||||
})
|
||||
|
||||
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
|
||||
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should handle model creation errors', async () => {
|
||||
const modelError = new Error('Failed to create image model')
|
||||
vi.mocked(createImageModel).mockRejectedValue(modelError)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('invalid-model', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow(ImageGenerationError)
|
||||
})
|
||||
|
||||
it('should handle image generation API errors', async () => {
|
||||
const apiError = new Error('API request failed')
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image: API request failed')
|
||||
})
|
||||
|
||||
it('should handle NoImageGeneratedError', async () => {
|
||||
const noImageError = new NoImageGeneratedError({
|
||||
cause: new Error('No image generated'),
|
||||
responses: []
|
||||
})
|
||||
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
||||
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image: No image generated')
|
||||
})
|
||||
|
||||
it('should execute onError plugin hook on failure', async () => {
|
||||
const error = new Error('Generation failed')
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(error)
|
||||
|
||||
const errorPlugin: AiPlugin = {
|
||||
name: 'error-handler',
|
||||
onError: vi.fn()
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[errorPlugin]
|
||||
)
|
||||
|
||||
await expect(
|
||||
executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image: Generation failed')
|
||||
|
||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||
error,
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle abort signal timeout', async () => {
|
||||
const abortError = new Error('Operation was aborted')
|
||||
abortError.name = 'AbortError'
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
|
||||
|
||||
const abortController = new AbortController()
|
||||
setTimeout(() => abortController.abort(), 10)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
).rejects.toThrow('Operation was aborted')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple providers support', () => {
|
||||
it('should work with different providers', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', {
|
||||
apiKey: 'google-key'
|
||||
})
|
||||
|
||||
await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A landscape'
|
||||
})
|
||||
|
||||
expect(createImageModel).toHaveBeenCalledWith('google', 'imagen-3.0-generate-002', { apiKey: 'google-key' })
|
||||
})
|
||||
|
||||
it('should support xAI Grok image models', async () => {
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', {
|
||||
apiKey: 'xai-key'
|
||||
})
|
||||
|
||||
await xaiExecutor.generateImage('grok-2-image', {
|
||||
prompt: 'A futuristic robot'
|
||||
})
|
||||
|
||||
expect(createImageModel).toHaveBeenCalledWith('xai', 'grok-2-image', { apiKey: 'xai-key' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('Advanced features', () => {
|
||||
it('should support batch image generation with maxImagesPerCall', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
n: 10,
|
||||
maxImagesPerCall: 5
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
n: 10,
|
||||
maxImagesPerCall: 5
|
||||
})
|
||||
})
|
||||
|
||||
it('should support retries with maxRetries', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
maxRetries: 3
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
maxRetries: 3
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle warnings from the model', async () => {
|
||||
const resultWithWarnings = {
|
||||
...mockGenerateImageResult,
|
||||
warnings: [
|
||||
{
|
||||
type: 'unsupported-setting',
|
||||
message: 'Size parameter not supported for this model'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
|
||||
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
size: '2048x2048' // Unsupported size
|
||||
})
|
||||
|
||||
expect(result.warnings).toHaveLength(1)
|
||||
expect(result.warnings[0].type).toBe('unsupported-setting')
|
||||
})
|
||||
|
||||
it('should provide access to provider metadata', async () => {
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(result.providerMetadata).toBeDefined()
|
||||
expect(result.providerMetadata.openai).toBeDefined()
|
||||
})
|
||||
|
||||
it('should provide response metadata', async () => {
|
||||
const resultWithMetadata = {
|
||||
...mockGenerateImageResult,
|
||||
responses: [
|
||||
{
|
||||
timestamp: new Date(),
|
||||
modelId: 'dall-e-3',
|
||||
headers: { 'x-request-id': 'test-123' }
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
|
||||
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(result.responses).toHaveLength(1)
|
||||
expect(result.responses[0].modelId).toBe('dall-e-3')
|
||||
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
|
||||
})
|
||||
})
|
||||
})
|
||||
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Error classes for runtime operations
|
||||
*/
|
||||
|
||||
/**
|
||||
* Error thrown when image generation fails
|
||||
*/
|
||||
export class ImageGenerationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public modelId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ImageGenerationError'
|
||||
|
||||
// Maintain proper stack trace (for V8 engines)
|
||||
if (Error.captureStackTrace) {
|
||||
Error.captureStackTrace(this, ImageGenerationError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when model resolution fails during image generation
|
||||
*/
|
||||
export class ImageModelResolutionError extends ImageGenerationError {
|
||||
constructor(modelId: string, providerId?: string, cause?: Error) {
|
||||
super(
|
||||
`Failed to resolve image model: ${modelId}${providerId ? ` for provider: ${providerId}` : ''}`,
|
||||
providerId,
|
||||
modelId,
|
||||
cause
|
||||
)
|
||||
this.name = 'ImageModelResolutionError'
|
||||
}
|
||||
}
|
||||
@ -2,13 +2,21 @@
|
||||
* 运行时执行器
|
||||
* 专注于插件化的AI调用处理
|
||||
*/
|
||||
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import { generateObject, generateText, LanguageModel, streamObject, streamText } from 'ai'
|
||||
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import {
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
LanguageModel,
|
||||
streamObject,
|
||||
streamText
|
||||
} from 'ai'
|
||||
|
||||
import { type ProviderId } from '../../types'
|
||||
import { createModel, getProviderInfo } from '../models'
|
||||
import { createImageModel, createModel, getProviderInfo } from '../models'
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
|
||||
@ -42,6 +50,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
})
|
||||
}
|
||||
|
||||
createResolveImageModelPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveImageModel',
|
||||
enforce: 'post',
|
||||
|
||||
resolveModel: async (modelId: string) => {
|
||||
return await this.resolveImageModel(modelId)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
createConfigureContextPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_configureContext',
|
||||
@ -203,6 +222,48 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像 - 使用已创建的图像模型
|
||||
*/
|
||||
async generateImage(
|
||||
model: ImageModelV2,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateImage>>
|
||||
async generateImage(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateImage>>
|
||||
async generateImage(
|
||||
modelOrId: ImageModelV2 | string,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateImage>> {
|
||||
try {
|
||||
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
|
||||
|
||||
return await this.pluginEngine.executeImageWithPlugins(
|
||||
'generateImage',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) => {
|
||||
return await generateImage({ model, ...transformedParams })
|
||||
}
|
||||
)
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new ImageGenerationError(
|
||||
`Failed to generate image: ${error.message}`,
|
||||
this.config.providerId,
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
error
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// === 辅助方法 ===
|
||||
|
||||
/**
|
||||
@ -228,6 +289,27 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析图像模型:如果是字符串则创建图像模型,如果是模型则直接返回
|
||||
*/
|
||||
private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
|
||||
try {
|
||||
if (typeof modelOrId === 'string') {
|
||||
// 字符串modelId,需要创建图像模型
|
||||
return await createImageModel(this.config.providerId, modelOrId, this.config.providerSettings)
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
return modelOrId
|
||||
}
|
||||
} catch (error) {
|
||||
throw new ImageModelResolutionError(
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
this.config.providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取客户端信息
|
||||
*/
|
||||
|
||||
@ -100,6 +100,21 @@ export async function streamObject<T extends ProviderId>(
|
||||
return executor.streamObject(modelId, params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成图像 - 支持middlewares
|
||||
*/
|
||||
export async function generateImage<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateImage']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateImage(modelId, params, { middlewares })
|
||||
}
|
||||
|
||||
// === Agent 功能预留 ===
|
||||
// 未来将在 ../agents/ 文件夹中添加:
|
||||
// - AgentExecutor.ts
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
@ -113,6 +114,62 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行带插件的图像生成操作
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeImageWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(model, transformedParams)
|
||||
|
||||
// 5. 转换结果
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行流式调用的通用逻辑(支持流转换器)
|
||||
* 提供给AiExecutor使用
|
||||
|
||||
@ -17,6 +17,7 @@ import { createExecutor } from './core/runtime'
|
||||
export {
|
||||
createExecutor,
|
||||
createOpenAICompatibleExecutor,
|
||||
generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
streamText
|
||||
|
||||
13
packages/aiCore/vitest.config.ts
Normal file
13
packages/aiCore/vitest.config.ts
Normal file
@ -0,0 +1,13 @@
|
||||
import { defineConfig } from 'vitest/config'
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
environment: 'node',
|
||||
globals: true
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': './src'
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -12,6 +12,7 @@ import {
|
||||
AiCore,
|
||||
AiPlugin,
|
||||
createExecutor,
|
||||
generateImage,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap,
|
||||
@ -22,6 +23,7 @@ import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
@ -127,9 +129,11 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否为图像生成模型(暂时不支持)
|
||||
// 图像生成模型现在支持新的 AI SDK
|
||||
// (但需要确保 provider 是支持的
|
||||
|
||||
if (model && isDedicatedImageGenerationModel(model)) {
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
@ -211,6 +215,12 @@ export default class ModernAiProvider {
|
||||
middlewareConfig: AiSdkMiddlewareConfig
|
||||
): Promise<CompletionsResult> {
|
||||
console.log('completions', modelId, params, middlewareConfig)
|
||||
|
||||
// 检查是否为图像生成模型
|
||||
if (middlewareConfig.model && isDedicatedImageGenerationModel(middlewareConfig.model)) {
|
||||
return await this.modernImageGeneration(modelId, params, middlewareConfig)
|
||||
}
|
||||
|
||||
return await this.modernCompletions(modelId, params, middlewareConfig)
|
||||
}
|
||||
|
||||
@ -271,6 +281,109 @@ export default class ModernAiProvider {
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||
*/
|
||||
private async modernImageGeneration(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiSdkMiddlewareConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const { onChunk } = middlewareConfig
|
||||
|
||||
try {
|
||||
// 检查 messages 是否存在
|
||||
if (!params.messages || params.messages.length === 0) {
|
||||
throw new Error('No messages provided for image generation.')
|
||||
}
|
||||
|
||||
// 从最后一条用户消息中提取 prompt
|
||||
const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
|
||||
if (!lastUserMessage) {
|
||||
throw new Error('No user message found for image generation.')
|
||||
}
|
||||
|
||||
// 直接使用消息内容,避免类型转换问题
|
||||
const prompt =
|
||||
typeof lastUserMessage.content === 'string'
|
||||
? lastUserMessage.content
|
||||
: lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('No prompt found in user message.')
|
||||
}
|
||||
|
||||
// 发送图像生成开始事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||
}
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 构建图像生成参数
|
||||
const imageParams = {
|
||||
prompt,
|
||||
size: '1024x1024' as `${number}x${number}`, // 默认尺寸,使用正确的类型
|
||||
n: 1,
|
||||
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||
}
|
||||
|
||||
// 调用新 AI SDK 的图像生成功能
|
||||
const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams)
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
const imageType: 'url' | 'base64' = 'base64'
|
||||
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:image/png;base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 发送图像生成完成事件
|
||||
if (onChunk && images.length > 0) {
|
||||
onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: { type: imageType, images }
|
||||
})
|
||||
}
|
||||
|
||||
// 发送响应完成事件
|
||||
if (onChunk) {
|
||||
const usage = {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
getText: () => '' // 图像生成不返回文本
|
||||
}
|
||||
} catch (error) {
|
||||
// 发送错误事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.ERROR, error: error as any })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 代理其他方法到原有实现
|
||||
public async models() {
|
||||
return this.legacyProvider.models()
|
||||
@ -281,9 +394,51 @@ export default class ModernAiProvider {
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
// 如果支持新的 AI SDK,使用现代化实现
|
||||
if (isModernSdkSupported(this.actualProvider)) {
|
||||
try {
|
||||
const result = await this.modernGenerateImage(params)
|
||||
return result
|
||||
} catch (error) {
|
||||
console.warn('Modern AI SDK generateImage failed, falling back to legacy:', error)
|
||||
// fallback 到传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
}
|
||||
|
||||
// 直接使用传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现
|
||||
*/
|
||||
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
const { model, prompt, imageSize, batchSize, signal } = params
|
||||
|
||||
// 转换参数格式
|
||||
const aiSdkParams = {
|
||||
prompt,
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
n: batchSize || 1,
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
|
||||
const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams)
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:image/png;base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return images
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.legacyProvider.getBaseURL()
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user