mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-05 04:19:02 +08:00
refactor: Move abort controller to utils and update imports
This commit is contained in:
parent
01bf84b8ca
commit
bb02ebe818
@ -24,10 +24,10 @@ import FileManager from '@renderer/services/FileManager'
|
|||||||
import { estimateTextTokens as estimateTxtTokens } from '@renderer/services/TokenService'
|
import { estimateTextTokens as estimateTxtTokens } from '@renderer/services/TokenService'
|
||||||
import { translateText } from '@renderer/services/TranslateService'
|
import { translateText } from '@renderer/services/TranslateService'
|
||||||
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
||||||
import { abortCompletion } from '@renderer/store/abortController'
|
|
||||||
import { setGenerating, setSearching } from '@renderer/store/runtime'
|
import { setGenerating, setSearching } from '@renderer/store/runtime'
|
||||||
import { Assistant, FileType, KnowledgeBase, Message, Model, Topic } from '@renderer/types'
|
import { Assistant, FileType, KnowledgeBase, Message, Model, Topic } from '@renderer/types'
|
||||||
import { classNames, delay, getFileExtension, uuid } from '@renderer/utils'
|
import { classNames, delay, getFileExtension, uuid } from '@renderer/utils'
|
||||||
|
import { abortCompletion } from '@renderer/utils/abortController'
|
||||||
import { documentExts, imageExts, textExts } from '@shared/config/constant'
|
import { documentExts, imageExts, textExts } from '@shared/config/constant'
|
||||||
import { Button, Popconfirm, Tooltip } from 'antd'
|
import { Button, Popconfirm, Tooltip } from 'antd'
|
||||||
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
|
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import i18n from '@renderer/i18n'
|
|||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||||
import { addAbortController, removeAbortController } from '@renderer/store/abortController'
|
|
||||||
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import { first, flatten, sum, takeRight } from 'lodash'
|
import { first, flatten, sum, takeRight } from 'lodash'
|
||||||
@ -107,13 +106,12 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
const abortController = new AbortController()
|
|
||||||
const { signal } = abortController
|
|
||||||
// 获取最后一条用户消息的 ID 作为 askId
|
|
||||||
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||||
if (lastUserMessage?.id) {
|
|
||||||
addAbortController(lastUserMessage.id, () => abortController.abort())
|
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
|
||||||
}
|
const { signal } = abortController
|
||||||
|
|
||||||
return new Promise<void>((resolve, reject) => {
|
return new Promise<void>((resolve, reject) => {
|
||||||
const stream = this.sdk.messages
|
const stream = this.sdk.messages
|
||||||
.stream({ ...body, stream: true }, { signal })
|
.stream({ ...body, stream: true }, { signal })
|
||||||
@ -152,11 +150,7 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
resolve()
|
resolve()
|
||||||
})
|
})
|
||||||
.on('error', (error) => reject(error))
|
.on('error', (error) => reject(error))
|
||||||
}).finally(() => {
|
}).finally(cleanup)
|
||||||
if (lastUserMessage?.id) {
|
|
||||||
removeAbortController(lastUserMessage.id)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
||||||
|
|||||||
@ -3,12 +3,13 @@ import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
|||||||
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
|
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
|
||||||
import { getKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
import { getKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import type { Assistant, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { delay, isJSON, parseJSON } from '@renderer/utils'
|
import { delay, isJSON, parseJSON } from '@renderer/utils'
|
||||||
|
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
import OpenAI from 'openai'
|
import type OpenAI from 'openai'
|
||||||
|
|
||||||
import { CompletionsParams } from '.'
|
import type { CompletionsParams } from '.'
|
||||||
|
|
||||||
export default abstract class BaseProvider {
|
export default abstract class BaseProvider {
|
||||||
protected provider: Provider
|
protected provider: Provider
|
||||||
@ -135,4 +136,21 @@ export default abstract class BaseProvider {
|
|||||||
}, {}) || {}
|
}, {}) || {}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected createAbortController(messageId?: string) {
|
||||||
|
const abortController = new AbortController()
|
||||||
|
|
||||||
|
if (messageId) {
|
||||||
|
addAbortController(messageId, () => abortController.abort())
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
abortController,
|
||||||
|
cleanup: () => {
|
||||||
|
if (messageId) {
|
||||||
|
removeAbortController(messageId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,7 +16,6 @@ import i18n from '@renderer/i18n'
|
|||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||||
import { addAbortController, removeAbortController } from '@renderer/store/abortController'
|
|
||||||
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import axios from 'axios'
|
import axios from 'axios'
|
||||||
@ -204,19 +203,11 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const abortController = new AbortController()
|
|
||||||
const { signal } = abortController
|
|
||||||
// 获取最后一条用户消息的 ID 作为 askId
|
|
||||||
const lastUserMessage = userMessages.findLast((m) => m.role === 'user')
|
const lastUserMessage = userMessages.findLast((m) => m.role === 'user')
|
||||||
if (lastUserMessage?.id) {
|
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
|
||||||
addAbortController(lastUserMessage.id, () => abortController.abort())
|
const { signal } = abortController
|
||||||
}
|
|
||||||
|
|
||||||
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(() => {
|
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(cleanup)
|
||||||
if (lastUserMessage?.id) {
|
|
||||||
removeAbortController(lastUserMessage.id)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
let time_first_token_millsec = 0
|
let time_first_token_millsec = 0
|
||||||
|
|
||||||
for await (const chunk of userMessagesStream.stream) {
|
for await (const chunk of userMessagesStream.stream) {
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import i18n from '@renderer/i18n'
|
|||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||||
import { addAbortController, removeAbortController } from '@renderer/store/abortController'
|
|
||||||
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import { takeRight } from 'lodash'
|
import { takeRight } from 'lodash'
|
||||||
@ -214,14 +213,9 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
let time_first_token_millsec = 0
|
let time_first_token_millsec = 0
|
||||||
let time_first_content_millsec = 0
|
let time_first_content_millsec = 0
|
||||||
const start_time_millsec = new Date().getTime()
|
const start_time_millsec = new Date().getTime()
|
||||||
const abortController = new AbortController()
|
|
||||||
const { signal } = abortController
|
|
||||||
|
|
||||||
// 获取最后一条用户消息的 ID 作为 askId
|
|
||||||
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||||
if (lastUserMessage?.id) {
|
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
|
||||||
addAbortController(lastUserMessage.id, () => abortController.abort())
|
const { signal } = abortController
|
||||||
}
|
|
||||||
|
|
||||||
const stream = await this.sdk.chat.completions
|
const stream = await this.sdk.chat.completions
|
||||||
// @ts-ignore key is not typed
|
// @ts-ignore key is not typed
|
||||||
@ -243,11 +237,7 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
signal
|
signal
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
.finally(() => {
|
.finally(cleanup)
|
||||||
if (lastUserMessage?.id) {
|
|
||||||
removeAbortController(lastUserMessage.id)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!isSupportStreamOutput()) {
|
if (!isSupportStreamOutput()) {
|
||||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { addAbortController } from '@renderer/store/abortController'
|
|
||||||
import { setGenerating } from '@renderer/store/runtime'
|
import { setGenerating } from '@renderer/store/runtime'
|
||||||
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
|
import { addAbortController } from '@renderer/utils/abortController'
|
||||||
import { formatMessageError } from '@renderer/utils/error'
|
import { formatMessageError } from '@renderer/utils/error'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
|
|
||||||
|
|||||||
@ -1,25 +0,0 @@
|
|||||||
export const abortMap = new Map<string, () => void>()
|
|
||||||
|
|
||||||
export const addAbortController = (messageId: string, abortFn: () => void) => {
|
|
||||||
let callback = abortFn
|
|
||||||
const existingCallback = abortMap.get(messageId)
|
|
||||||
if (existingCallback) {
|
|
||||||
callback = () => {
|
|
||||||
existingCallback?.()
|
|
||||||
abortFn()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
abortMap.set(messageId, callback)
|
|
||||||
}
|
|
||||||
|
|
||||||
export const removeAbortController = (messageId: string) => {
|
|
||||||
abortMap.delete(messageId)
|
|
||||||
}
|
|
||||||
|
|
||||||
export const abortCompletion = (messageId: string) => {
|
|
||||||
const abortFn = abortMap.get(messageId)
|
|
||||||
if (abortFn) {
|
|
||||||
abortFn()
|
|
||||||
removeAbortController(messageId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
25
src/renderer/src/utils/abortController.ts
Normal file
25
src/renderer/src/utils/abortController.ts
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
export const abortMap = new Map<string, () => void>()
|
||||||
|
|
||||||
|
export const addAbortController = (id: string, abortFn: () => void) => {
|
||||||
|
let callback = abortFn
|
||||||
|
const existingCallback = abortMap.get(id)
|
||||||
|
if (existingCallback) {
|
||||||
|
callback = () => {
|
||||||
|
existingCallback?.()
|
||||||
|
abortFn()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
abortMap.set(id, callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const removeAbortController = (id: string) => {
|
||||||
|
abortMap.delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const abortCompletion = (id: string) => {
|
||||||
|
const abortFn = abortMap.get(id)
|
||||||
|
if (abortFn) {
|
||||||
|
abortFn()
|
||||||
|
removeAbortController(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user