mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 18:50:56 +08:00
feat(image): support grok-2-image image and gpt-4o-image (#4767)
* feat(image): support grok image * feat: add gpt-4o-image * feat: 添加 gpt-image-1 到生成图像模型列表
This commit is contained in:
parent
7d69c1274b
commit
4d06af69a6
@ -2145,7 +2145,13 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
|
||||
'stabilityai/stable-diffusion-xl-base-1.0'
|
||||
]
|
||||
|
||||
export const GENERATE_IMAGE_MODELS = ['gemini-2.0-flash-exp-image-generation', 'gemini-2.0-flash-exp']
|
||||
export const GENERATE_IMAGE_MODELS = [
|
||||
'gemini-2.0-flash-exp-image-generation',
|
||||
'gemini-2.0-flash-exp',
|
||||
'grok-2-image-1212',
|
||||
'gpt-4o-image',
|
||||
'gpt-image-1'
|
||||
]
|
||||
|
||||
export const GEMINI_SEARCH_MODELS = [
|
||||
'gemini-2.0-flash',
|
||||
|
||||
@ -507,6 +507,10 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
return []
|
||||
}
|
||||
|
||||
public async generateImageByChat(): Promise<void> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate suggestions
|
||||
* @returns The suggestions
|
||||
|
||||
@ -39,6 +39,7 @@ export default abstract class BaseProvider {
|
||||
abstract check(model: Model): Promise<{ valid: boolean; error: Error | null }>
|
||||
abstract models(): Promise<OpenAI.Models.Model[]>
|
||||
abstract generateImage(params: GenerateImageParams): Promise<string[]>
|
||||
abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
||||
abstract getEmbeddingDimensions(model: Model): Promise<number>
|
||||
|
||||
public getBaseURL(): string {
|
||||
|
||||
@ -718,4 +718,8 @@ export default class GeminiProvider extends BaseProvider {
|
||||
})
|
||||
return data.embeddings?.[0]?.values?.length || 0
|
||||
}
|
||||
|
||||
public generateImageByChat(): Promise<void> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
}
|
||||
|
||||
@ -307,6 +307,10 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
* @returns The completions
|
||||
*/
|
||||
async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||
if (assistant.enableGenerateImage) {
|
||||
await this.generateImageByChat({ messages, assistant, onChunk } as CompletionsParams)
|
||||
return
|
||||
}
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
@ -893,4 +897,33 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
const { token } = await window.api.copilot.getToken(defaultHeaders)
|
||||
this.sdk.apiKey = token
|
||||
}
|
||||
|
||||
public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
const { abortController, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
|
||||
const { signal } = abortController
|
||||
const response = await this.sdk.images.generate(
|
||||
{
|
||||
model: model.id,
|
||||
prompt: lastUserMessage?.content || ''
|
||||
},
|
||||
{
|
||||
signal
|
||||
}
|
||||
)
|
||||
|
||||
await signalPromise?.promise?.catch((error) => {
|
||||
throw error
|
||||
})
|
||||
|
||||
return onChunk({
|
||||
text: '',
|
||||
generateImage: {
|
||||
type: 'url',
|
||||
images: response.data.map((item) => item.url).filter((url): url is string => url !== undefined)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -109,6 +109,15 @@ export default class AiProvider {
|
||||
return this.sdk.generateImage(params)
|
||||
}
|
||||
|
||||
public async generateImageByChat({
|
||||
messages,
|
||||
assistant,
|
||||
onChunk,
|
||||
onFilterMessages
|
||||
}: CompletionsParams): Promise<void> {
|
||||
return this.sdk.generateImageByChat({ messages, assistant, onChunk, onFilterMessages })
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
return this.sdk.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user