feat: integrate image generation capabilities and enhance testing framework

- Added support for image generation in the `RuntimeExecutor` with a new `generateImage` method.
- Updated `aiCore` package to include `vitest` for testing, with new test scripts added.
- Enhanced type definitions to accommodate image model handling in plugins.
- Introduced new methods for resolving and executing image generation with plugins.
- Updated package dependencies in `package.json` to include `vitest` and ensure compatibility with new features.
This commit is contained in:
suyao 2025-08-01 10:45:31 +08:00
parent 7216e9943c
commit ecc08bd3f7
No known key found for this signature in database
12 changed files with 1873 additions and 17 deletions

View File

@ -0,0 +1,167 @@
/**
* Image Generation Example
* Demonstrates image generation using the aiCore package
*/
import { createExecutor, generateImage } from '../src/index'
/**
 * Entry point for the image-generation examples.
 *
 * Walks through four usage patterns of the aiCore API:
 *   1. an executor instance (createExecutor + generateImage)
 *   2. the direct functional generateImage API
 *   3. an alternative provider (Google Imagen), only when GOOGLE_API_KEY is set
 *   4. the plugin system (transformParams / transformResult hooks)
 *
 * Each pattern is wrapped in its own try/catch so one failure does not
 * prevent the remaining examples from running.
 */
async function main() {
  // Approach 1: use an executor instance
  console.log('📸 创建 OpenAI 图像生成执行器...')
  const executor = createExecutor('openai', {
    apiKey: process.env.OPENAI_API_KEY!
  })
  try {
    console.log('🎨 使用执行器生成图像...')
    const result1 = await executor.generateImage('dall-e-3', {
      prompt: 'A futuristic cityscape at sunset with flying cars',
      size: '1024x1024',
      n: 1
    })
    console.log('✅ 图像生成成功!')
    console.log('📊 结果:', {
      imagesCount: result1.images.length,
      mediaType: result1.image.mediaType,
      hasBase64: !!result1.image.base64,
      providerMetadata: result1.providerMetadata
    })
  } catch (error) {
    console.error('❌ 执行器生成失败:', error)
  }
  // Approach 2: call the functional API directly
  try {
    console.log('🎨 使用直接 API 生成图像...')
    const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
      prompt: 'A magical forest with glowing mushrooms and fairy lights',
      aspectRatio: '16:9',
      // Provider-specific options are namespaced under the provider id
      providerOptions: {
        openai: {
          quality: 'hd',
          style: 'vivid'
        }
      }
    })
    console.log('✅ 直接 API 生成成功!')
    console.log('📊 结果:', {
      imagesCount: result2.images.length,
      mediaType: result2.image.mediaType,
      hasBase64: !!result2.image.base64
    })
  } catch (error) {
    console.error('❌ 直接 API 生成失败:', error)
  }
  // Approach 3: other providers (Google Imagen) — only runs when a key is configured
  if (process.env.GOOGLE_API_KEY) {
    try {
      console.log('🎨 使用 Google Imagen 生成图像...')
      const googleExecutor = createExecutor('google', {
        apiKey: process.env.GOOGLE_API_KEY!
      })
      const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
        prompt: 'A serene mountain lake at dawn with mist rising from the water',
        aspectRatio: '1:1'
      })
      console.log('✅ Google Imagen 生成成功!')
      console.log('📊 结果:', {
        imagesCount: result3.images.length,
        mediaType: result3.image.mediaType,
        hasBase64: !!result3.image.base64
      })
    } catch (error) {
      console.error('❌ Google Imagen 生成失败:', error)
    }
  }
  // Approach 4: the plugin system
  const pluginExample = async () => {
    console.log('🔌 演示插件系统...')
    // Example plugin that rewrites the prompt before generation
    const promptEnhancerPlugin = {
      name: 'prompt-enhancer',
      // Sequential hook: returns the (possibly modified) request params
      transformParams: async (params: any) => {
        console.log('🔧 插件: 增强提示词...')
        return {
          ...params,
          prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
        }
      },
      // Sequential hook: decorates the generation result
      transformResult: async (result: any) => {
        console.log('🔧 插件: 处理结果...')
        return {
          ...result,
          enhanced: true
        }
      }
    }
    const executorWithPlugin = createExecutor(
      'openai',
      {
        apiKey: process.env.OPENAI_API_KEY!
      },
      [promptEnhancerPlugin]
    )
    try {
      const result4 = await executorWithPlugin.generateImage('dall-e-3', {
        prompt: 'A cute robot playing in a garden'
      })
      console.log('✅ 插件系统生成成功!')
      console.log('📊 结果:', {
        imagesCount: result4.images.length,
        // `enhanced` is added by the plugin's transformResult, hence the cast
        enhanced: (result4 as any).enhanced,
        mediaType: result4.image.mediaType
      })
    } catch (error) {
      console.error('❌ 插件系统生成失败:', error)
    }
  }
  await pluginExample()
}
// 错误处理演示
/**
 * Demonstrates structured error handling: an invalid API key should surface
 * an error carrying the provider id and model id alongside the message.
 */
async function errorHandlingExample() {
  console.log('⚠️ 演示错误处理...')
  try {
    const badExecutor = createExecutor('openai', { apiKey: 'invalid-key' })
    await badExecutor.generateImage('dall-e-3', { prompt: 'Test image' })
  } catch (error: any) {
    // Report each piece of error context on its own line
    const report: Array<[string, unknown]> = [
      ['✅ 成功捕获错误:', error.constructor.name],
      ['📋 错误信息:', error.message],
      ['🏷️ 提供商ID:', error.providerId],
      ['🏷️ 模型ID:', error.modelId]
    ]
    for (const [label, value] of report) {
      console.log(label, value)
    }
  }
}
// 运行示例
// Run the examples only when this file is executed directly (not imported).
if (require.main === module) {
  ;(async () => {
    try {
      await main()
      console.log('🎉 所有示例完成!')
      await errorHandlingExample()
      console.log('🎯 示例程序结束')
      process.exit(0)
    } catch (error) {
      console.error('💥 程序执行出错:', error)
      process.exit(1)
    }
  })()
}

View File

@ -9,7 +9,9 @@
"scripts": {
"build": "tsdown",
"dev": "tsc -w",
"clean": "rm -rf dist"
"clean": "rm -rf dist",
"test": "vitest run",
"test:watch": "vitest"
},
"keywords": [
"ai",
@ -45,7 +47,8 @@
},
"devDependencies": {
"tsdown": "^0.12.9",
"typescript": "^5.0.0"
"typescript": "^5.0.0",
"vitest": "^1.0.0"
},
"sideEffects": false,
"engines": {

View File

@ -1,3 +1,4 @@
import type { ImageModelV2 } from '@ai-sdk/provider'
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
import { type ProviderId } from '../providers/types'
@ -32,7 +33,10 @@ export interface AiPlugin {
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => Promise<LanguageModel | null> | LanguageModel | null
resolveModel?: (
modelId: string,
context: AiRequestContext
) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换

View File

@ -0,0 +1,533 @@
import { ImageModelV2 } from '@ai-sdk/provider'
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createImageModel } from '../../models/ModelCreator'
import { type AiPlugin } from '../../plugins'
import { ImageGenerationError } from '../errors'
import { RuntimeExecutor } from '../executor'
// Mock dependencies
vi.mock('ai', () => ({
experimental_generateImage: vi.fn(),
NoImageGeneratedError: class NoImageGeneratedError extends Error {
static isInstance = vi.fn()
constructor() {
super('No image generated')
this.name = 'NoImageGeneratedError'
}
}
}))
vi.mock('../../models/ModelCreator', () => ({
createImageModel: vi.fn()
}))
describe('RuntimeExecutor.generateImage', () => {
let executor: RuntimeExecutor<'openai'>
let mockImageModel: ImageModelV2
let mockGenerateImageResult: any
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks()
// Create executor instance
executor = RuntimeExecutor.create('openai', {
apiKey: 'test-key'
})
// Mock image model
mockImageModel = {
modelId: 'dall-e-3',
provider: 'openai'
} as ImageModelV2
// Mock generateImage result
mockGenerateImageResult = {
image: {
base64: 'base64-encoded-image-data',
uint8Array: new Uint8Array([1, 2, 3]),
mediaType: 'image/png'
},
images: [
{
base64: 'base64-encoded-image-data',
uint8Array: new Uint8Array([1, 2, 3]),
mediaType: 'image/png'
}
],
warnings: [],
providerMetadata: {
openai: {
images: [{ revisedPrompt: 'A detailed prompt' }]
}
},
responses: []
}
// Setup mocks
vi.mocked(createImageModel).mockResolvedValue(mockImageModel)
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
})
describe('Basic functionality', () => {
it('should generate a single image with minimal parameters', async () => {
const result = await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape at sunset'
})
expect(createImageModel).toHaveBeenCalledWith('openai', 'dall-e-3', { apiKey: 'test-key' })
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A futuristic cityscape at sunset'
})
expect(result).toEqual(mockGenerateImageResult)
})
it('should generate image with pre-created model', async () => {
const result = await executor.generateImage(mockImageModel, {
prompt: 'A beautiful landscape'
})
// Note: createImageModel may still be called due to resolveImageModel logic
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful landscape'
})
expect(result).toEqual(mockGenerateImageResult)
})
it('should support multiple images generation', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape',
n: 3
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A futuristic cityscape',
n: 3
})
})
it('should support size specification', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A beautiful sunset',
size: '1024x1024'
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful sunset',
size: '1024x1024'
})
})
it('should support aspect ratio specification', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A mountain landscape',
aspectRatio: '16:9'
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A mountain landscape',
aspectRatio: '16:9'
})
})
it('should support seed for consistent output', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A cat in space',
seed: 1234567890
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A cat in space',
seed: 1234567890
})
})
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateImage('dall-e-3', {
prompt: 'A cityscape',
abortSignal: abortController.signal
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A cityscape',
abortSignal: abortController.signal
})
})
it('should support provider-specific options', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A space station',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A space station',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
})
it('should support custom headers', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A robot',
headers: {
'X-Custom-Header': 'test-value'
}
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A robot',
headers: {
'X-Custom-Header': 'test-value'
}
})
})
})
describe('Plugin integration', () => {
it('should execute plugins in correct order', async () => {
const pluginCallOrder: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCallOrder.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCallOrder.push('transformParams')
return { ...params, size: '512x512' }
}),
transformResult: vi.fn(async (result) => {
pluginCallOrder.push('transformResult')
return { ...result, processed: true }
}),
onRequestEnd: vi.fn(async () => {
pluginCallOrder.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[testPlugin]
)
const result = await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
expect(testPlugin.transformParams).toHaveBeenCalledWith(
{ prompt: 'A test image' },
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
size: '512x512' // Should be transformed by plugin
})
expect(result).toEqual({
...mockGenerateImageResult,
processed: true // Should be transformed by plugin
})
})
it('should handle model resolution through plugins', async () => {
const customImageModel = {
modelId: 'custom-model',
provider: 'openai'
} as ImageModelV2
const modelResolutionPlugin: AiPlugin = {
name: 'model-resolver',
resolveModel: vi.fn(async () => customImageModel)
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[modelResolutionPlugin]
)
await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
'dall-e-3',
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
expect(aiGenerateImage).toHaveBeenCalledWith({
model: customImageModel,
prompt: 'A test image'
})
})
it('should support recursive calls from plugins', async () => {
const recursivePlugin: AiPlugin = {
name: 'recursive-plugin',
transformParams: vi.fn(async (params, context) => {
if (!context.isRecursiveCall && params.prompt === 'original') {
// Make a recursive call with modified prompt
await context.recursiveCall({
prompt: 'modified'
})
}
return params
})
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[recursivePlugin]
)
await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'original'
})
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
})
})
describe('Error handling', () => {
it('should handle model creation errors', async () => {
const modelError = new Error('Failed to create image model')
vi.mocked(createImageModel).mockRejectedValue(modelError)
await expect(
executor.generateImage('invalid-model', {
prompt: 'A test image'
})
).rejects.toThrow(ImageGenerationError)
})
it('should handle image generation API errors', async () => {
const apiError = new Error('API request failed')
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image: API request failed')
})
it('should handle NoImageGeneratedError', async () => {
const noImageError = new NoImageGeneratedError({
cause: new Error('No image generated'),
responses: []
})
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image: No image generated')
})
it('should execute onError plugin hook on failure', async () => {
const error = new Error('Generation failed')
vi.mocked(aiGenerateImage).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[errorPlugin]
)
await expect(
executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image: Generation failed')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
})
it('should handle abort signal timeout', async () => {
const abortError = new Error('Operation was aborted')
abortError.name = 'AbortError'
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
const abortController = new AbortController()
setTimeout(() => abortController.abort(), 10)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image',
abortSignal: abortController.signal
})
).rejects.toThrow('Operation was aborted')
})
})
describe('Multiple providers support', () => {
it('should work with different providers', async () => {
const googleExecutor = RuntimeExecutor.create('google', {
apiKey: 'google-key'
})
await googleExecutor.generateImage('imagen-3.0-generate-002', {
prompt: 'A landscape'
})
expect(createImageModel).toHaveBeenCalledWith('google', 'imagen-3.0-generate-002', { apiKey: 'google-key' })
})
it('should support xAI Grok image models', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', {
apiKey: 'xai-key'
})
await xaiExecutor.generateImage('grok-2-image', {
prompt: 'A futuristic robot'
})
expect(createImageModel).toHaveBeenCalledWith('xai', 'grok-2-image', { apiKey: 'xai-key' })
})
})
describe('Advanced features', () => {
it('should support batch image generation with maxImagesPerCall', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A test image',
n: 10,
maxImagesPerCall: 5
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
n: 10,
maxImagesPerCall: 5
})
})
it('should support retries with maxRetries', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A test image',
maxRetries: 3
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
maxRetries: 3
})
})
it('should handle warnings from the model', async () => {
const resultWithWarnings = {
...mockGenerateImageResult,
warnings: [
{
type: 'unsupported-setting',
message: 'Size parameter not supported for this model'
}
]
}
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image',
size: '2048x2048' // Unsupported size
})
expect(result.warnings).toHaveLength(1)
expect(result.warnings[0].type).toBe('unsupported-setting')
})
it('should provide access to provider metadata', async () => {
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(result.providerMetadata).toBeDefined()
expect(result.providerMetadata.openai).toBeDefined()
})
it('should provide response metadata', async () => {
const resultWithMetadata = {
...mockGenerateImageResult,
responses: [
{
timestamp: new Date(),
modelId: 'dall-e-3',
headers: { 'x-request-id': 'test-123' }
}
]
}
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(result.responses).toHaveLength(1)
expect(result.responses[0].modelId).toBe('dall-e-3')
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
})
})
})

View File

@ -0,0 +1,38 @@
/**
* Error classes for runtime operations
*/
/**
 * Error thrown when image generation fails.
 *
 * Carries the provider id, model id and underlying cause so callers can
 * report which provider/model combination failed and why.
 */
export class ImageGenerationError extends Error {
  public providerId?: string
  public modelId?: string
  public cause?: Error

  constructor(message: string, providerId?: string, modelId?: string, cause?: Error) {
    super(message)
    this.name = 'ImageGenerationError'
    this.providerId = providerId
    this.modelId = modelId
    this.cause = cause
    // Keep this constructor frame out of the stack trace on V8 runtimes
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, ImageGenerationError)
    }
  }
}
/**
 * Error thrown when an image model cannot be resolved for a provider
 * during image generation.
 */
export class ImageModelResolutionError extends ImageGenerationError {
  constructor(modelId: string, providerId?: string, cause?: Error) {
    const providerSuffix = providerId ? ` for provider: ${providerId}` : ''
    super(`Failed to resolve image model: ${modelId}${providerSuffix}`, providerId, modelId, cause)
    this.name = 'ImageModelResolutionError'
  }
}

View File

@ -2,13 +2,21 @@
*
* AI调用处理
*/
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { generateObject, generateText, LanguageModel, streamObject, streamText } from 'ai'
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import {
experimental_generateImage as generateImage,
generateObject,
generateText,
LanguageModel,
streamObject,
streamText
} from 'ai'
import { type ProviderId } from '../../types'
import { createModel, getProviderInfo } from '../models'
import { createImageModel, createModel, getProviderInfo } from '../models'
import { type ModelConfig } from '../models/types'
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { ImageGenerationError, ImageModelResolutionError } from './errors'
import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types'
@ -42,6 +50,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
})
}
/**
 * Builds the internal fallback plugin that resolves image models.
 *
 * Registered with `enforce: 'post'` so user-supplied `resolveModel` hooks
 * run first; since resolveModel is a first-match hook, this one only
 * supplies a model when no earlier plugin returned one.
 */
createResolveImageModelPlugin() {
  return definePlugin({
    name: '_internal_resolveImageModel',
    enforce: 'post',
    resolveModel: async (modelId: string) => {
      return await this.resolveImageModel(modelId)
    }
  })
}
createConfigureContextPlugin() {
return definePlugin({
name: '_internal_configureContext',
@ -203,6 +222,48 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
)
}
/**
 * Generates images via the AI SDK's `experimental_generateImage`.
 *
 * Accepts either a pre-built `ImageModelV2` or a model id string; model
 * resolution goes through the plugin pipeline (`resolveModel` first-match
 * hook, with an internal fallback that calls `createImageModel`).
 *
 * @param modelOrId image model instance, or a model id for this provider
 * @param params    generation params (prompt, size, n, ...) minus `model`
 * @returns the AI SDK generate-image result (image/images, warnings, metadata)
 * @throws ImageGenerationError wrapping any underlying Error with provider/model context
 *
 * NOTE(review): the string-id overload declares an `options.middlewares`
 * parameter, but the implementation signature never receives it, so
 * middlewares are currently ignored for image generation — confirm intent.
 */
async generateImage(
  model: ImageModelV2,
  params: Omit<Parameters<typeof generateImage>[0], 'model'>
): Promise<ReturnType<typeof generateImage>>
async generateImage(
  modelId: string,
  params: Omit<Parameters<typeof generateImage>[0], 'model'>,
  options?: {
    middlewares?: LanguageModelV2Middleware[]
  }
): Promise<ReturnType<typeof generateImage>>
async generateImage(
  modelOrId: ImageModelV2 | string,
  params: Omit<Parameters<typeof generateImage>[0], 'model'>
): Promise<ReturnType<typeof generateImage>> {
  try {
    // Install the internal resolver (post) and context plugins for this call
    this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
    return await this.pluginEngine.executeImageWithPlugins(
      'generateImage',
      typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
      params,
      async (model, transformedParams) => {
        return await generateImage({ model, ...transformedParams })
      }
    )
  } catch (error) {
    if (error instanceof Error) {
      // Wrap with provider/model context so callers can attribute the failure
      throw new ImageGenerationError(
        `Failed to generate image: ${error.message}`,
        this.config.providerId,
        typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
        error
      )
    }
    throw error
  }
}
// === 辅助方法 ===
/**
@ -228,6 +289,27 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
}
}
/**
 * Resolves an image model reference to a concrete `ImageModelV2`.
 *
 * A string id is turned into a model via `createImageModel` using this
 * executor's provider settings; an already-constructed model is returned
 * unchanged. Creation failures are wrapped in `ImageModelResolutionError`.
 */
private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
  // An already-built model needs no work (and cannot fail)
  if (typeof modelOrId !== 'string') {
    return modelOrId
  }
  try {
    return await createImageModel(this.config.providerId, modelOrId, this.config.providerSettings)
  } catch (error) {
    throw new ImageModelResolutionError(modelOrId, this.config.providerId, error instanceof Error ? error : undefined)
  }
}
/**
*
*/

View File

@ -100,6 +100,21 @@ export async function streamObject<T extends ProviderId>(
return executor.streamObject(modelId, params, { middlewares })
}
/**
 * One-shot image generation: builds a throwaway executor and delegates to
 * `RuntimeExecutor.generateImage`.
 *
 * @param providerId  target provider (e.g. 'openai', 'google')
 * @param options     provider settings (API key, etc.) plus optional `mode`
 * @param modelId     image model id for the provider
 * @param params      generation params (prompt, size, n, ...)
 * @param plugins     optional plugins installed on the temporary executor
 * @param middlewares forwarded to the executor call
 */
export async function generateImage<T extends ProviderId>(
  providerId: T,
  options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
  modelId: string,
  params: Parameters<RuntimeExecutor<T>['generateImage']>[1],
  plugins?: AiPlugin[],
  middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
  const executor = createExecutor(providerId, options, plugins)
  return executor.generateImage(modelId, params, { middlewares })
}
// === Agent 功能预留 ===
// 未来将在 ../agents/ 文件夹中添加:
// - AgentExecutor.ts

View File

@ -1,3 +1,4 @@
import { ImageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai'
import { type AiPlugin, createContext, PluginManager } from '../plugins'
@ -113,6 +114,62 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
}
/**
 * Runs an image-generation call through the full plugin pipeline.
 *
 * Hook order: configureContext → onRequestStart → resolveModel (first match)
 * → transformParams (sequential) → executor → transformResult (sequential)
 * → onRequestEnd. Any error triggers the parallel `onError` hooks and is
 * rethrown to the caller.
 *
 * @param methodName logical name of the call (e.g. 'generateImage')
 * @param modelId    model id handed to the resolveModel hooks
 * @param params     raw request params, before transformParams
 * @param executor   performs the actual API call with the resolved model
 * @param _context   internal: reused context when re-entered recursively
 */
async executeImageWithPlugins<TParams, TResult>(
  methodName: string,
  modelId: string,
  params: TParams,
  executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
  _context?: ReturnType<typeof createContext>
): Promise<TResult> {
  // Reuse the caller-provided context (recursive call) or build a fresh one
  const context = _context ? _context : createContext(this.providerId, modelId, params)
  // Give plugins the ability to re-enter the full pipeline with new params
  context.recursiveCall = async (newParams: any): Promise<TResult> => {
    // NOTE(review): the flag is reset to false unconditionally afterwards,
    // so nested recursive calls would lose it — confirm depth-1 usage only.
    context.isRecursiveCall = true
    const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context)
    context.isRecursiveCall = false
    return result
  }
  try {
    // 0. Let plugins configure the shared request context
    await this.pluginManager.executeConfigureContext(context)
    // 1. Fire request-start hooks (parallel)
    await this.pluginManager.executeParallel('onRequestStart', context)
    // 2. Resolve the image model (first plugin to return non-null wins)
    const model = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
    if (!model) {
      throw new Error(`Failed to resolve image model: ${modelId}`)
    }
    // 3. Transform request params (sequential chain)
    const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
    // 4. Perform the actual API call
    const result = await executor(model, transformedParams)
    // 5. Transform the result (sequential chain)
    const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
    // 6. Fire request-end hooks (parallel)
    await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
    return transformedResult
  } catch (error) {
    // 7. Fire error hooks, then rethrow to the caller
    await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
    throw error
  }
}
/**
*
* AiExecutor使用

View File

@ -17,6 +17,7 @@ import { createExecutor } from './core/runtime'
export {
createExecutor,
createOpenAICompatibleExecutor,
generateImage,
generateObject,
generateText,
streamText

View File

@ -0,0 +1,13 @@
import { defineConfig } from 'vitest/config'
export default defineConfig({
test: {
environment: 'node',
globals: true
},
resolve: {
alias: {
'@': './src'
}
}
})

View File

@ -12,6 +12,7 @@ import {
AiCore,
AiPlugin,
createExecutor,
generateImage,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap,
@ -22,6 +23,7 @@ import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash'
@ -127,9 +129,11 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
return false
}
// 检查是否为图像生成模型(暂时不支持)
// 图像生成模型现在支持新的 AI SDK
// (但需要确保 provider 是支持的
if (model && isDedicatedImageGenerationModel(model)) {
return false
return true
}
return true
@ -211,6 +215,12 @@ export default class ModernAiProvider {
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
console.log('completions', modelId, params, middlewareConfig)
// 检查是否为图像生成模型
if (middlewareConfig.model && isDedicatedImageGenerationModel(middlewareConfig.model)) {
return await this.modernImageGeneration(modelId, params, middlewareConfig)
}
return await this.modernCompletions(modelId, params, middlewareConfig)
}
@ -271,6 +281,109 @@ export default class ModernAiProvider {
// }
}
/**
 * Image generation for chat-style requests using the modern AI SDK.
 *
 * Extracts a prompt from the last user message, calls the functional
 * `generateImage` API, and reports progress through the `onChunk` callback
 * (IMAGE_CREATED → IMAGE_COMPLETE → LLM_RESPONSE_COMPLETE, or ERROR).
 *
 * @param modelId          image model id
 * @param params           chat-style params; only `messages` and `abortSignal` are used
 * @param middlewareConfig supplies the `onChunk` stream callback
 * @returns a CompletionsResult whose text is always empty (images only)
 */
private async modernImageGeneration(
  modelId: string,
  params: StreamTextParams,
  middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
  const { onChunk } = middlewareConfig
  try {
    // Bail out early when there is nothing to build a prompt from
    if (!params.messages || params.messages.length === 0) {
      throw new Error('No messages provided for image generation.')
    }
    // The prompt comes from the most recent user message
    const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
    if (!lastUserMessage) {
      throw new Error('No user message found for image generation.')
    }
    // Use message content directly; concatenate text parts when structured
    const prompt =
      typeof lastUserMessage.content === 'string'
        ? lastUserMessage.content
        : lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
    if (!prompt) {
      throw new Error('No prompt found in user message.')
    }
    // Signal that image generation has started
    if (onChunk) {
      onChunk({ type: ChunkType.IMAGE_CREATED })
    }
    const startTime = Date.now()
    // Build the image generation parameters
    const imageParams = {
      prompt,
      size: '1024x1024' as `${number}x${number}`, // default size, typed to match the SDK
      n: 1,
      ...(params.abortSignal && { abortSignal: params.abortSignal })
    }
    // Call the new AI SDK's image generation
    const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams)
    // Convert returned images into data URLs
    const images: string[] = []
    const imageType: 'url' | 'base64' = 'base64'
    if (result.images) {
      for (const image of result.images) {
        if ('base64' in image && image.base64) {
          // assumes PNG output — TODO confirm against each provider's mediaType
          images.push(`data:image/png;base64,${image.base64}`)
        }
      }
    }
    // Emit the completed images (skipped when nothing decodable came back)
    if (onChunk && images.length > 0) {
      onChunk({
        type: ChunkType.IMAGE_COMPLETE,
        image: { type: imageType, images }
      })
    }
    // Emit a synthetic completion event; image APIs report no token usage here
    if (onChunk) {
      const usage = {
        prompt_tokens: 0,
        completion_tokens: 0,
        total_tokens: 0
      }
      onChunk({
        type: ChunkType.LLM_RESPONSE_COMPLETE,
        response: {
          usage,
          metrics: {
            completion_tokens: usage.completion_tokens,
            time_first_token_millsec: 0,
            time_completion_millsec: Date.now() - startTime
          }
        }
      })
    }
    return {
      getText: () => '' // image generation yields no text
    }
  } catch (error) {
    // Surface the failure to the stream consumer, then rethrow
    if (onChunk) {
      onChunk({ type: ChunkType.ERROR, error: error as any })
    }
    throw error
  }
}
// 代理其他方法到原有实现
public async models() {
return this.legacyProvider.models()
@ -281,9 +394,51 @@ export default class ModernAiProvider {
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
// 如果支持新的 AI SDK使用现代化实现
if (isModernSdkSupported(this.actualProvider)) {
try {
const result = await this.modernGenerateImage(params)
return result
} catch (error) {
console.warn('Modern AI SDK generateImage failed, falling back to legacy:', error)
// fallback 到传统实现
return this.legacyProvider.generateImage(params)
}
}
// 直接使用传统实现
return this.legacyProvider.generateImage(params)
}
/**
 * Direct image generation through the modern AI SDK.
 * Returns the generated images as base64 data URLs.
 */
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
  const { model, prompt, imageSize, batchSize, signal } = params
  // Map our GenerateImageParams onto the AI SDK parameter shape
  const aiSdkParams = {
    prompt,
    size: (imageSize || '1024x1024') as `${number}x${number}`,
    n: batchSize || 1,
    ...(signal && { abortSignal: signal })
  }
  const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams)
  // Keep only images exposing base64 data and wrap them as data URLs
  const generated = result.images ?? []
  return generated
    .filter((image) => 'base64' in image && !!image.base64)
    .map((image) => `data:image/png;base64,${image.base64}`)
}
public getBaseURL(): string {
return this.legacyProvider.getBaseURL()
}

806
yarn.lock

File diff suppressed because it is too large Load Diff