Mirror of https://github.com/CherryHQ/cherry-studio.git, synced 2026-01-01 01:30:51 +08:00
fix(MessageOperations): Improve message pause functionality and error handling
- Update pauseMessage method to handle both askId and messageId
- Add loading state reset when pausing messages
- Enhance error handling in providers with abort error detection
- Modify ApiService to handle aborted requests gracefully
- Add comprehensive isAbortError utility function
parent 87ae293636
commit e2e8e1efe7
@@ -10,6 +10,7 @@ import {
   selectTopicLoading,
   selectTopicMessages,
   setStreamMessage,
+  setTopicLoading,
   updateMessage,
   updateMessages
 } from '@renderer/store/messages'
@@ -155,14 +156,18 @@ export function useMessageOperations(topic: Topic) {
    * Pause message generation
    */
   const pauseMessage = useCallback(
-    async (messageId: string) => {
+    // What is stored is the user message's id, i.e. the assistant message's askId
+    async (askId: string, messageId: string) => {
       // 1. Call abort
-      abortCompletion(messageId)
+      abortCompletion(askId)
+      console.log('messageId', messageId)
       // 2. Update the message status
       await editMessage(messageId, { status: 'paused' })
 
-      // 3. Clear the stream message
+      // 3. Reset the loading state
+      dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
+
+      // 4. Clear the stream message
       clearStreamMessageAction(messageId)
     },
     [editMessage, clearStreamMessageAction]
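
For reference, a minimal sketch of what a call site looks like after this change. Only pauseMessage's two-argument signature comes from the diff; the reduced Message shape and the usePauseButton/usePauseMessage wrappers below are illustrative assumptions, not repository code.

// Illustrative call site for the new two-argument pauseMessage signature.
import { useCallback } from 'react'

interface Message {
  id: string
  // id of the user message that produced this assistant message
  askId?: string
}

// Hypothetical accessor standing in for useMessageOperations(topic).pauseMessage
declare function usePauseMessage(): (askId: string, messageId: string) => Promise<void>

export function usePauseButton(message: Message) {
  const pauseMessage = usePauseMessage()

  return useCallback(async () => {
    // Before: pauseMessage(message.id)
    // After: the abort is keyed by the user-message id (askId), while the
    // assistant-message id is still needed to mark the message as paused
    // and clear its stream state.
    if (message.askId) {
      await pauseMessage(message.askId, message.id)
    }
  }, [message.askId, message.id, pauseMessage])
}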
@@ -173,15 +178,13 @@ export function useMessageOperations(topic: Topic) {
     const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id]
     if (streamMessages) {
       // Collect the askId of every stream message
-      const askIds = new Set(
-        Object.values(streamMessages)
-          .map((msg) => msg.askId)
-          .filter(Boolean)
-      )
+      const askIds = Object.values(streamMessages)
+        .map((msg) => [msg.askId, msg.id])
+        .filter(([askId, id]) => askId && id)
 
       // Pause generation for each askId
-      for (const askId of askIds) {
-        await pauseMessage(askId)
+      for (const [askId, id] of askIds) {
+        await pauseMessage(askId, id)
       }
     }
   }, [topic.id, pauseMessage])
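
The switch from a Set of askIds to [askId, id] pairs follows from the new pauseMessage signature: each pause now needs both identifiers. A standalone sketch of the same pairing logic, with StreamMessage reduced to the two fields the loop relies on (an assumption, not the repository's full type):

// Standalone version of the pairing logic used in pauseMessages above.
interface StreamMessage {
  id: string
  askId?: string
}

type PauseFn = (askId: string, messageId: string) => Promise<void>

export async function pauseAllStreams(
  streamMessages: Record<string, StreamMessage>,
  pauseMessage: PauseFn
): Promise<void> {
  // A Set of askIds is no longer enough: keep [askId, id] pairs together so
  // both the abort key and the assistant-message id are available per pause.
  const pairs = Object.values(streamMessages)
    .map((msg) => [msg.askId, msg.id] as const)
    .filter((pair): pair is readonly [string, string] => Boolean(pair[0] && pair[1]))

  for (const [askId, id] of pairs) {
    await pauseMessage(askId, id)
  }
}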
@@ -208,7 +208,7 @@ export default class AnthropicProvider extends BaseProvider {
     const { signal } = abortController
     const toolResponses: MCPToolResponse[] = []
 
-    const processStream = async (body: MessageCreateParamsNonStreaming) => {
+    const processStream = (body: MessageCreateParamsNonStreaming) => {
       return new Promise<void>((resolve, reject) => {
         const toolCalls: ToolUseBlock[] = []
         let hasThinkingContent = false
@@ -326,7 +326,12 @@ export default class AnthropicProvider extends BaseProvider {
       })
     }
 
-    await processStream(body).finally(cleanup)
+    await processStream(body)
+      .catch((error) => {
+        // Without this, the error is not re-thrown
+        throw error
+      })
+      .finally(cleanup)
   }
 
   public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
@@ -160,13 +160,20 @@ export default abstract class BaseProvider {
       addAbortController(messageId, () => abortController.abort())
     }
 
+    const cleanup = () => {
+      if (messageId) {
+        removeAbortController(messageId)
+      }
+    }
+
+    abortController.signal.addEventListener('abort', () => {
+      // For compatibility
+      cleanup()
+    })
+
     return {
       abortController,
-      cleanup: () => {
-        if (messageId) {
-          removeAbortController(messageId)
-        }
-      }
+      cleanup
     }
   }
 }
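
The BaseProvider change registers the cleanup twice: it is returned to the caller and also attached to the signal's 'abort' event, presumably so an abort triggered from elsewhere (e.g. abortCompletion) deregisters the handler as well. A simplified, self-contained sketch of that bookkeeping; the Map-based registry and the abortCompletion body are assumptions standing in for the real addAbortController/removeAbortController helpers:

// Simplified sketch of the abort bookkeeping shown in the BaseProvider hunk above.
const abortHandlers = new Map<string, () => void>()

function addAbortController(id: string, abort: () => void): void {
  abortHandlers.set(id, abort)
}

function removeAbortController(id: string): void {
  abortHandlers.delete(id)
}

// Plausible counterpart to the abortCompletion(askId) call used in pauseMessage.
export function abortCompletion(id: string): void {
  abortHandlers.get(id)?.()
}

export function createAbortController(messageId?: string) {
  const abortController = new AbortController()

  if (messageId) {
    addAbortController(messageId, () => abortController.abort())
  }

  const cleanup = () => {
    if (messageId) {
      removeAbortController(messageId)
    }
  }

  // Also deregister when the signal fires, so an abort triggered through the
  // registry does not leave a stale handler behind.
  abortController.signal.addEventListener('abort', () => cleanup())

  return { abortController, cleanup }
}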
@@ -197,9 +197,10 @@ export default class GeminiProvider extends BaseProvider {
     const messageContents = await this.getMessageContents(userLastMessage!)
 
     const start_time_millsec = new Date().getTime()
-
+    const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
+    const { signal } = abortController
     if (!streamOutput) {
-      const { response } = await chat.sendMessage(messageContents.parts)
+      const { response } = await chat.sendMessage(messageContents.parts, { signal })
       const time_completion_millsec = new Date().getTime() - start_time_millsec
       onChunk({
         text: response.candidates?.[0].content.parts[0].text,
@@ -218,13 +219,8 @@ export default class GeminiProvider extends BaseProvider {
       return
     }
 
-    const lastUserMessage = userMessages.findLast((m) => m.role === 'user')
-    const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
-    const { signal } = abortController
-
     const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
     let time_first_token_millsec = 0
 
     const processStream = async (stream: GenerateContentStreamResult) => {
       for await (const chunk of stream.stream) {
         if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
@@ -297,7 +293,6 @@ export default class GeminiProvider extends BaseProvider {
         })
       }
     }
-
     await processStream(userMessagesStream).finally(cleanup)
   }
 
@@ -3,7 +3,7 @@ import i18n from '@renderer/i18n'
 import store from '@renderer/store'
 import { setGenerating } from '@renderer/store/runtime'
 import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
-import { formatMessageError } from '@renderer/utils/error'
+import { formatMessageError, isAbortError } from '@renderer/utils/error'
 import { cloneDeep, findLast, isEmpty } from 'lodash'
 
 import AiProvider from '../providers/AiProvider'
@@ -116,12 +116,18 @@ export async function fetchChatCompletion({
       // Set metrics.completion_tokens
       if (message.metrics && message?.usage?.completion_tokens) {
         if (!message.metrics?.completion_tokens) {
-          message.metrics.completion_tokens = message.usage.completion_tokens
+          message = {
+            ...message,
+            metrics: {
+              ...message.metrics,
+              completion_tokens: message.usage.completion_tokens
+            }
+          }
         }
       }
     }
   } catch (error: any) {
-    console.log('error', error)
+    if (isAbortError(error)) return
     message.status = 'error'
     message.error = formatMessageError(error)
   }
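
The metrics update now builds a new message object instead of assigning into message.metrics, presumably because the message can originate from store state that should not be mutated in place. A reduced sketch of the same non-mutating update; the trimmed Metrics/Usage/ChatMessage shapes are assumptions, not the repository's types.

// Reduced sketch of the non-mutating metrics update in fetchChatCompletion above.
interface Metrics {
  completion_tokens?: number
}

interface Usage {
  completion_tokens?: number
}

interface ChatMessage {
  metrics?: Metrics
  usage?: Usage
}

export function withCompletionTokens(message: ChatMessage): ChatMessage {
  if (message.metrics && message.usage?.completion_tokens && !message.metrics.completion_tokens) {
    // Build a fresh object instead of writing into message.metrics, so the
    // original (possibly store-owned) object is never mutated.
    return {
      ...message,
      metrics: {
        ...message.metrics,
        completion_tokens: message.usage.completion_tokens
      }
    }
  }
  return message
}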
@@ -345,7 +345,6 @@ export const sendMessage =
       onResponse: async (msg) => {
         // Maintain the latest message state outside the callback; update this object on every call, but only dispatch it to Redux through the throttled function
         const updateMessage = { ...msg, status: msg.status || 'pending', content: msg.content || '' }
-        // Create a throttled function to limit the Redux update frequency
         // Update Redux via the throttled function
         throttledDispatch(
           assistant,
@@ -62,3 +62,30 @@ export function formatMessageError(error: any): Record<string, any> {
 export function getErrorMessage(error: any): string {
   return error?.message || error?.toString() || ''
 }
+
+export const isAbortError = (error: any): boolean => {
+  // Check the error message
+  if (error?.message === 'Request was aborted.') {
+    return true
+  }
+
+  // Check for a DOMException-style abort error
+  if (error instanceof DOMException && error.name === 'AbortError') {
+    return true
+  }
+  console.log(
+    typeof error === 'object',
+    error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason')
+  )
+  // Check the OpenAI-specific error shape
+  if (
+    (error &&
+      typeof error === 'object' &&
+      (error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason'))) ||
+    error.stack?.includes('OpenAI.makeRequest')
+  ) {
+    return true
+  }
+
+  return false
+}
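
A hedged sketch of how isAbortError is meant to be used, mirroring the fetchChatCompletion change above: an abort raised by pausing is treated as a normal stop rather than recorded as an error. The runCompletion wrapper and its arguments are illustrative stand-ins, not repository code; only the import path and formatMessageError come from the diff.

// Illustrative use of isAbortError in a completion error handler.
import { formatMessageError, isAbortError } from '@renderer/utils/error'

interface CompletionMessage {
  status: string
  error?: Record<string, any>
}

export async function runCompletion(message: CompletionMessage, complete: () => Promise<void>): Promise<void> {
  try {
    await complete()
  } catch (error: any) {
    // A user-initiated pause surfaces here as an abort; treat it as a normal
    // stop so the message is not flagged as failed.
    if (isAbortError(error)) return
    message.status = 'error'
    message.error = formatMessageError(error)
  }
}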