mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 13:59:28 +08:00
feat(image): support grok-2-image image and gpt-4o-image (#4767)
* feat(image): support grok image * feat: add gpt-4o-image * feat: 添加 gpt-image-1 到生成图像模型列表
This commit is contained in:
parent
7d69c1274b
commit
4d06af69a6
@ -2145,7 +2145,13 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
|
|||||||
'stabilityai/stable-diffusion-xl-base-1.0'
|
'stabilityai/stable-diffusion-xl-base-1.0'
|
||||||
]
|
]
|
||||||
|
|
||||||
export const GENERATE_IMAGE_MODELS = ['gemini-2.0-flash-exp-image-generation', 'gemini-2.0-flash-exp']
|
export const GENERATE_IMAGE_MODELS = [
|
||||||
|
'gemini-2.0-flash-exp-image-generation',
|
||||||
|
'gemini-2.0-flash-exp',
|
||||||
|
'grok-2-image-1212',
|
||||||
|
'gpt-4o-image',
|
||||||
|
'gpt-image-1'
|
||||||
|
]
|
||||||
|
|
||||||
export const GEMINI_SEARCH_MODELS = [
|
export const GEMINI_SEARCH_MODELS = [
|
||||||
'gemini-2.0-flash',
|
'gemini-2.0-flash',
|
||||||
|
|||||||
@ -507,6 +507,10 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async generateImageByChat(): Promise<void> {
|
||||||
|
throw new Error('Method not implemented.')
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate suggestions
|
* Generate suggestions
|
||||||
* @returns The suggestions
|
* @returns The suggestions
|
||||||
|
|||||||
@ -39,6 +39,7 @@ export default abstract class BaseProvider {
|
|||||||
abstract check(model: Model): Promise<{ valid: boolean; error: Error | null }>
|
abstract check(model: Model): Promise<{ valid: boolean; error: Error | null }>
|
||||||
abstract models(): Promise<OpenAI.Models.Model[]>
|
abstract models(): Promise<OpenAI.Models.Model[]>
|
||||||
abstract generateImage(params: GenerateImageParams): Promise<string[]>
|
abstract generateImage(params: GenerateImageParams): Promise<string[]>
|
||||||
|
abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
||||||
abstract getEmbeddingDimensions(model: Model): Promise<number>
|
abstract getEmbeddingDimensions(model: Model): Promise<number>
|
||||||
|
|
||||||
public getBaseURL(): string {
|
public getBaseURL(): string {
|
||||||
|
|||||||
@ -718,4 +718,8 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
})
|
})
|
||||||
return data.embeddings?.[0]?.values?.length || 0
|
return data.embeddings?.[0]?.values?.length || 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public generateImageByChat(): Promise<void> {
|
||||||
|
throw new Error('Method not implemented.')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -307,6 +307,10 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
* @returns The completions
|
* @returns The completions
|
||||||
*/
|
*/
|
||||||
async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||||
|
if (assistant.enableGenerateImage) {
|
||||||
|
await this.generateImageByChat({ messages, assistant, onChunk } as CompletionsParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||||
@ -893,4 +897,33 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
const { token } = await window.api.copilot.getToken(defaultHeaders)
|
const { token } = await window.api.copilot.getToken(defaultHeaders)
|
||||||
this.sdk.apiKey = token
|
this.sdk.apiKey = token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise<void> {
|
||||||
|
const defaultModel = getDefaultModel()
|
||||||
|
const model = assistant.model || defaultModel
|
||||||
|
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||||
|
const { abortController, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
|
||||||
|
const { signal } = abortController
|
||||||
|
const response = await this.sdk.images.generate(
|
||||||
|
{
|
||||||
|
model: model.id,
|
||||||
|
prompt: lastUserMessage?.content || ''
|
||||||
|
},
|
||||||
|
{
|
||||||
|
signal
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await signalPromise?.promise?.catch((error) => {
|
||||||
|
throw error
|
||||||
|
})
|
||||||
|
|
||||||
|
return onChunk({
|
||||||
|
text: '',
|
||||||
|
generateImage: {
|
||||||
|
type: 'url',
|
||||||
|
images: response.data.map((item) => item.url).filter((url): url is string => url !== undefined)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -109,6 +109,15 @@ export default class AiProvider {
|
|||||||
return this.sdk.generateImage(params)
|
return this.sdk.generateImage(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async generateImageByChat({
|
||||||
|
messages,
|
||||||
|
assistant,
|
||||||
|
onChunk,
|
||||||
|
onFilterMessages
|
||||||
|
}: CompletionsParams): Promise<void> {
|
||||||
|
return this.sdk.generateImageByChat({ messages, assistant, onChunk, onFilterMessages })
|
||||||
|
}
|
||||||
|
|
||||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
return this.sdk.getEmbeddingDimensions(model)
|
return this.sdk.getEmbeddingDimensions(model)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user