From aaf396f83aa23a619207c9d494b01ddf8e7932e4 Mon Sep 17 00:00:00 2001 From: one Date: Wed, 30 Apr 2025 22:25:32 +0800 Subject: [PATCH] feat: support streaming for model health check (#5546) --- .../providers/AiProvider/AnthropicProvider.ts | 39 +++++++++++++---- .../src/providers/AiProvider/BaseProvider.ts | 2 +- .../providers/AiProvider/GeminiProvider.ts | 42 ++++++++++++++----- .../providers/AiProvider/OpenAIProvider.ts | 29 +++++++++---- .../src/providers/AiProvider/index.ts | 4 +- src/renderer/src/services/ModelService.ts | 9 +++- 6 files changed, 97 insertions(+), 28 deletions(-) diff --git a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts index 560de81225..e06787d35c 100644 --- a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts +++ b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts @@ -531,25 +531,50 @@ export default class AnthropicProvider extends BaseProvider { /** * Check if the model is valid * @param model - The model + * @param stream - Whether to use streaming interface * @returns The validity of the model */ - public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> { + public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { if (!model) { return { valid: false, error: new Error('No model found') } } const body = { model: model.id, - messages: [{ role: 'user', content: 'hi' }], + messages: [{ role: 'user' as const, content: 'hi' }], max_tokens: 100, - stream: false + stream } try { - const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming) - return { - valid: message.content.length > 0, - error: null + if (!stream) { + const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming) + return { + valid: message.content.length > 0, + error: null + } + } else { + return await new Promise((resolve, reject) => { + let hasContent = false + this.sdk.messages + .stream(body) + .on('text', (text) => { + if (!hasContent && text) { + hasContent = true + resolve({ valid: true, error: null }) + } + }) + .on('finalMessage', (message) => { + if (!hasContent && message.content && message.content.length > 0) { + hasContent = true + resolve({ valid: true, error: null }) + } + if (!hasContent) { + reject(new Error('Empty streaming response')) + } + }) + .on('error', (error) => reject(error)) + }) } } catch (error: any) { return { diff --git a/src/renderer/src/providers/AiProvider/BaseProvider.ts b/src/renderer/src/providers/AiProvider/BaseProvider.ts index 9e52fa2def..dfca89790d 100644 --- a/src/renderer/src/providers/AiProvider/BaseProvider.ts +++ b/src/renderer/src/providers/AiProvider/BaseProvider.ts @@ -43,7 +43,7 @@ export default abstract class BaseProvider { abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise abstract suggestions(messages: Message[], assistant: Assistant): Promise abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise - abstract check(model: Model): Promise<{ valid: boolean; error: Error | null }> + abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }> abstract models(): Promise abstract generateImage(params: GenerateImageParams): Promise abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise diff --git a/src/renderer/src/providers/AiProvider/GeminiProvider.ts b/src/renderer/src/providers/AiProvider/GeminiProvider.ts index 09beed9b9b..7897bd762d 100644 --- a/src/renderer/src/providers/AiProvider/GeminiProvider.ts +++ b/src/renderer/src/providers/AiProvider/GeminiProvider.ts @@ -740,25 +740,47 @@ export default class GeminiProvider extends BaseProvider { /** * Check if the model is valid * @param model - The model + * @param stream - Whether to use streaming interface * @returns The validity of the model */ - public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> { + public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { if (!model) { return { valid: false, error: new Error('No model found') } } try { - const result = await this.sdk.models.generateContent({ - model: model.id, - contents: [{ role: 'user', parts: [{ text: 'hi' }] }], - config: { - maxOutputTokens: 100 + if (!stream) { + const result = await this.sdk.models.generateContent({ + model: model.id, + contents: [{ role: 'user', parts: [{ text: 'hi' }] }], + config: { + maxOutputTokens: 100 + } + }) + if (isEmpty(result.text)) { + throw new Error('Empty response') + } + } else { + const response = await this.sdk.models.generateContentStream({ + model: model.id, + contents: [{ role: 'user', parts: [{ text: 'hi' }] }], + config: { + maxOutputTokens: 100 + } + }) + // 等待整个流式响应结束 + let hasContent = false + for await (const chunk of response) { + if (chunk.text && chunk.text.length > 0) { + hasContent = true + break + } + } + if (!hasContent) { + throw new Error('Empty streaming response') } - }) - return { - valid: !isEmpty(result.text), - error: null } + return { valid: true, error: null } } catch (error: any) { return { valid: false, diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts index fb918f5ce7..df9799b966 100644 --- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts +++ b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts @@ -962,26 +962,41 @@ export default class OpenAIProvider extends BaseProvider { /** * Check if the model is valid * @param model - The model + * @param stream - Whether to use streaming interface * @returns The validity of the model */ - public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> { + public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { if (!model) { return { valid: false, error: new Error('No model found') } } const body = { model: model.id, messages: [{ role: 'user', content: 'hi' }], - stream: false + stream } try { await this.checkIsCopilot() console.debug('[checkModel] body', model.id, body) - const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) - - return { - valid: Boolean(response?.choices[0].message), - error: null + if (!stream) { + const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) + if (!response?.choices[0].message) { + throw new Error('Empty response') + } + return { valid: true, error: null } + } else { + const response: any = await this.sdk.chat.completions.create(body as any) + // 等待整个流式响应结束 + let hasContent = false + for await (const chunk of response) { + if (chunk.choices?.[0]?.delta?.content) { + hasContent = true + } + } + if (hasContent) { + return { valid: true, error: null } + } + throw new Error('Empty streaming response') } } catch (error: any) { return { diff --git a/src/renderer/src/providers/AiProvider/index.ts b/src/renderer/src/providers/AiProvider/index.ts index d8fca2fa1f..954ab68b9c 100644 --- a/src/renderer/src/providers/AiProvider/index.ts +++ b/src/renderer/src/providers/AiProvider/index.ts @@ -59,8 +59,8 @@ export default class AiProvider { return this.sdk.generateText({ prompt, content }) } - public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> { - return this.sdk.check(model) + public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { + return this.sdk.check(model, stream) } public async models(): Promise { diff --git a/src/renderer/src/services/ModelService.ts b/src/renderer/src/services/ModelService.ts index 5520159864..04867e1d8e 100644 --- a/src/renderer/src/services/ModelService.ts +++ b/src/renderer/src/services/ModelService.ts @@ -82,7 +82,14 @@ export async function checkModel(provider: Provider, model: Model) { return performModelCheck( provider, model, - (ai, model) => ai.check(model), + async (ai, model) => { + const result = await ai.check(model, false) + if (result.valid && !result.error) { + return result + } + // Try streaming check + return ai.check(model, true) + }, ({ valid, error }) => ({ valid, error: error || null }) ) }