From bb02ebe8187ca8c6589b7f5bcc5dbe0030e6fd32 Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Fri, 21 Feb 2025 12:28:08 +0800 Subject: [PATCH] refactor: Move abort controller to utils and update imports --- .../src/pages/home/Inputbar/Inputbar.tsx | 2 +- .../src/providers/AnthropicProvider.ts | 18 +++++-------- src/renderer/src/providers/BaseProvider.ts | 24 +++++++++++++++--- src/renderer/src/providers/GeminiProvider.ts | 15 +++-------- src/renderer/src/providers/OpenAIProvider.ts | 16 +++--------- src/renderer/src/services/ApiService.ts | 2 +- src/renderer/src/store/abortController.ts | 25 ------------------- src/renderer/src/utils/abortController.ts | 25 +++++++++++++++++++ 8 files changed, 60 insertions(+), 67 deletions(-) delete mode 100644 src/renderer/src/store/abortController.ts create mode 100644 src/renderer/src/utils/abortController.ts diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 9edbf52b25..f4fca7209a 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -24,10 +24,10 @@ import FileManager from '@renderer/services/FileManager' import { estimateTextTokens as estimateTxtTokens } from '@renderer/services/TokenService' import { translateText } from '@renderer/services/TranslateService' import store, { useAppDispatch, useAppSelector } from '@renderer/store' -import { abortCompletion } from '@renderer/store/abortController' import { setGenerating, setSearching } from '@renderer/store/runtime' import { Assistant, FileType, KnowledgeBase, Message, Model, Topic } from '@renderer/types' import { classNames, delay, getFileExtension, uuid } from '@renderer/utils' +import { abortCompletion } from '@renderer/utils/abortController' import { documentExts, imageExts, textExts } from '@shared/config/constant' import { Button, Popconfirm, Tooltip } from 'antd' import TextArea, { TextAreaRef } from 'antd/es/input/TextArea' diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index d844a19a9b..cac42ac3cc 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -6,7 +6,6 @@ import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' -import { addAbortController, removeAbortController } from '@renderer/store/abortController' import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import { first, flatten, sum, takeRight } from 'lodash' @@ -107,13 +106,12 @@ export default class AnthropicProvider extends BaseProvider { } }) } - const abortController = new AbortController() - const { signal } = abortController - // 获取最后一条用户消息的 ID 作为 askId + const lastUserMessage = _messages.findLast((m) => m.role === 'user') - if (lastUserMessage?.id) { - addAbortController(lastUserMessage.id, () => abortController.abort()) - } + + const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) + const { signal } = abortController + return new Promise((resolve, reject) => { const stream = this.sdk.messages .stream({ ...body, stream: true }, { signal }) @@ -152,11 +150,7 @@ export default class AnthropicProvider extends BaseProvider { resolve() }) .on('error', (error) => reject(error)) - }).finally(() => { - if (lastUserMessage?.id) { - removeAbortController(lastUserMessage.id) - } - }) + }).finally(cleanup) } public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { diff --git a/src/renderer/src/providers/BaseProvider.ts b/src/renderer/src/providers/BaseProvider.ts index 619b2d68c6..8618286f84 100644 --- a/src/renderer/src/providers/BaseProvider.ts +++ b/src/renderer/src/providers/BaseProvider.ts @@ -3,12 +3,13 @@ import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama' import { getKnowledgeReferences } from '@renderer/services/KnowledgeService' import store from '@renderer/store' -import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types' +import type { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types' import { delay, isJSON, parseJSON } from '@renderer/utils' +import { addAbortController, removeAbortController } from '@renderer/utils/abortController' import { t } from 'i18next' -import OpenAI from 'openai' +import type OpenAI from 'openai' -import { CompletionsParams } from '.' +import type { CompletionsParams } from '.' export default abstract class BaseProvider { protected provider: Provider @@ -135,4 +136,21 @@ export default abstract class BaseProvider { }, {}) || {} ) } + + protected createAbortController(messageId?: string) { + const abortController = new AbortController() + + if (messageId) { + addAbortController(messageId, () => abortController.abort()) + } + + return { + abortController, + cleanup: () => { + if (messageId) { + removeAbortController(messageId) + } + } + } + } } diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index b2e38b90a3..d712c3d978 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -16,7 +16,6 @@ import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' -import { addAbortController, removeAbortController } from '@renderer/store/abortController' import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import axios from 'axios' @@ -204,19 +203,11 @@ export default class GeminiProvider extends BaseProvider { return } - const abortController = new AbortController() - const { signal } = abortController - // 获取最后一条用户消息的 ID 作为 askId const lastUserMessage = userMessages.findLast((m) => m.role === 'user') - if (lastUserMessage?.id) { - addAbortController(lastUserMessage.id, () => abortController.abort()) - } + const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) + const { signal } = abortController - const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(() => { - if (lastUserMessage?.id) { - removeAbortController(lastUserMessage.id) - } - }) + const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(cleanup) let time_first_token_millsec = 0 for await (const chunk of userMessagesStream.stream) { diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 76cc490970..74d1319481 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -4,7 +4,6 @@ import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' -import { addAbortController, removeAbortController } from '@renderer/store/abortController' import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import { takeRight } from 'lodash' @@ -214,14 +213,9 @@ export default class OpenAIProvider extends BaseProvider { let time_first_token_millsec = 0 let time_first_content_millsec = 0 const start_time_millsec = new Date().getTime() - const abortController = new AbortController() - const { signal } = abortController - - // 获取最后一条用户消息的 ID 作为 askId const lastUserMessage = _messages.findLast((m) => m.role === 'user') - if (lastUserMessage?.id) { - addAbortController(lastUserMessage.id, () => abortController.abort()) - } + const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) + const { signal } = abortController const stream = await this.sdk.chat.completions // @ts-ignore key is not typed @@ -243,11 +237,7 @@ export default class OpenAIProvider extends BaseProvider { signal } ) - .finally(() => { - if (lastUserMessage?.id) { - removeAbortController(lastUserMessage.id) - } - }) + .finally(cleanup) if (!isSupportStreamOutput()) { const time_completion_millsec = new Date().getTime() - start_time_millsec diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 88f53051c3..6c77fba83e 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -1,8 +1,8 @@ import i18n from '@renderer/i18n' import store from '@renderer/store' -import { addAbortController } from '@renderer/store/abortController' import { setGenerating } from '@renderer/store/runtime' import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types' +import { addAbortController } from '@renderer/utils/abortController' import { formatMessageError } from '@renderer/utils/error' import { isEmpty } from 'lodash' diff --git a/src/renderer/src/store/abortController.ts b/src/renderer/src/store/abortController.ts deleted file mode 100644 index 2978e8f24b..0000000000 --- a/src/renderer/src/store/abortController.ts +++ /dev/null @@ -1,25 +0,0 @@ -export const abortMap = new Map void>() - -export const addAbortController = (messageId: string, abortFn: () => void) => { - let callback = abortFn - const existingCallback = abortMap.get(messageId) - if (existingCallback) { - callback = () => { - existingCallback?.() - abortFn() - } - } - abortMap.set(messageId, callback) -} - -export const removeAbortController = (messageId: string) => { - abortMap.delete(messageId) -} - -export const abortCompletion = (messageId: string) => { - const abortFn = abortMap.get(messageId) - if (abortFn) { - abortFn() - removeAbortController(messageId) - } -} diff --git a/src/renderer/src/utils/abortController.ts b/src/renderer/src/utils/abortController.ts new file mode 100644 index 0000000000..98195d801e --- /dev/null +++ b/src/renderer/src/utils/abortController.ts @@ -0,0 +1,25 @@ +export const abortMap = new Map void>() + +export const addAbortController = (id: string, abortFn: () => void) => { + let callback = abortFn + const existingCallback = abortMap.get(id) + if (existingCallback) { + callback = () => { + existingCallback?.() + abortFn() + } + } + abortMap.set(id, callback) +} + +export const removeAbortController = (id: string) => { + abortMap.delete(id) +} + +export const abortCompletion = (id: string) => { + const abortFn = abortMap.get(id) + if (abortFn) { + abortFn() + removeAbortController(id) + } +}