diff --git a/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts b/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts index 9d7ced3389..031ad88de8 100644 --- a/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts @@ -10,7 +10,11 @@ export default abstract class BaseWebSearchProvider { this.provider = provider this.apiKey = this.getApiKey() } - abstract search(query: string, websearch: WebSearchState): Promise + abstract search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise public getApiKey() { const keys = this.provider.apiKey?.split(',').map((key) => key.trim()) || [] diff --git a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts index bdec900573..20d041cae3 100644 --- a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts @@ -1,6 +1,8 @@ import { nanoid } from '@reduxjs/toolkit' import { WebSearchState } from '@renderer/store/websearch' import { WebSearchProvider, WebSearchProviderResponse, WebSearchProviderResult } from '@renderer/types' +import { createAbortPromise } from '@renderer/utils/abortController' +import { isAbortError } from '@renderer/utils/error' import { fetchWebContent, noContent } from '@renderer/utils/fetch' import BaseWebSearchProvider from './BaseWebSearchProvider' @@ -18,7 +20,11 @@ export default class LocalSearchProvider extends BaseWebSearchProvider { super(provider) } - public async search(query: string, websearch: WebSearchState): Promise { + public async search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise { const uid = nanoid() try { if (!query.trim()) { @@ -30,7 +36,13 @@ export default class LocalSearchProvider extends BaseWebSearchProvider { const cleanedQuery = query.split('\r\n')[1] ?? query const url = this.provider.url.replace('%s', encodeURIComponent(cleanedQuery)) - const content = await window.api.searchService.openUrlInSearchWindow(uid, url) + let content: string = '' + const promisesToRace: [Promise] = [window.api.searchService.openUrlInSearchWindow(uid, url)] + if (httpOptions?.signal) { + const abortPromise = createAbortPromise(httpOptions.signal, promisesToRace[0]) + promisesToRace.push(abortPromise) + } + content = await Promise.race(promisesToRace) // Parse the content to extract URLs and metadata const searchItems = this.parseValidUrls(content).slice(0, websearch.maxResults) @@ -43,7 +55,7 @@ export default class LocalSearchProvider extends BaseWebSearchProvider { // Fetch content for each URL concurrently const fetchPromises = validItems.map(async (item) => { // console.log(`Fetching content for ${item.url}...`) - const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser) + const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions) if (websearch.contentLimit && result.content.length > websearch.contentLimit) { result.content = result.content.slice(0, websearch.contentLimit) + '...' } @@ -58,6 +70,9 @@ export default class LocalSearchProvider extends BaseWebSearchProvider { results: results.filter((result) => result.content != noContent) } } catch (error) { + if (isAbortError(error)) { + throw error + } console.error('Local search failed:', error) throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`) } finally { diff --git a/src/renderer/src/providers/WebSearchProvider/index.ts b/src/renderer/src/providers/WebSearchProvider/index.ts index dd6d414d48..9391d8d70e 100644 --- a/src/renderer/src/providers/WebSearchProvider/index.ts +++ b/src/renderer/src/providers/WebSearchProvider/index.ts @@ -7,11 +7,16 @@ import WebSearchProviderFactory from './WebSearchProviderFactory' export default class WebSearchEngineProvider { private sdk: BaseWebSearchProvider + constructor(provider: WebSearchProvider) { this.sdk = WebSearchProviderFactory.create(provider) } - public async search(query: string, websearch: WebSearchState): Promise { - const result = await this.sdk.search(query, websearch) + public async search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise { + const result = await this.sdk.search(query, websearch, httpOptions) const filteredResult = await filterResultWithBlacklist(result, websearch) return filteredResult diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index c801a50b89..670eeb8a22 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -145,8 +145,8 @@ async function fetchExternalTool( source: WebSearchSource.WEBSEARCH } } catch (error) { - console.error('Web search failed:', error) if (isAbortError(error)) throw error + console.error('Web search failed:', error) return } } @@ -251,8 +251,8 @@ async function fetchExternalTool( return { mcpTools } } catch (error) { - console.error('Tool execution failed:', error) if (isAbortError(error)) throw error + console.error('Tool execution failed:', error) // 发送错误状态 if (willUseTools) { diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index acdade4c4b..3f3aeddae6 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -97,7 +97,11 @@ class WebSearchService { * @param query 搜索查询 * @returns 搜索响应 */ - public async search(provider: WebSearchProvider, query: string): Promise { + public async search( + provider: WebSearchProvider, + query: string, + httpOptions?: RequestInit + ): Promise { const websearch = this.getWebSearchState() const webSearchEngine = new WebSearchEngineProvider(provider) @@ -107,12 +111,12 @@ class WebSearchService { formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}` } - try { - return await webSearchEngine.search(formattedQuery, websearch) - } catch (error) { - console.error('Search failed:', error) - throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`) - } + // try { + return await webSearchEngine.search(formattedQuery, websearch, httpOptions) + // } catch (error) { + // console.error('Search failed:', error) + // throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`) + // } } /** @@ -136,42 +140,41 @@ class WebSearchService { webSearchProvider: WebSearchProvider, extractResults: ExtractResults ): Promise { - try { - // 检查 websearch 和 question 是否有效 - if (!extractResults.websearch?.question || extractResults.websearch.question.length === 0) { - console.log('No valid question found in extractResults.websearch') - return { results: [] } - } + // 检查 websearch 和 question 是否有效 + if (!extractResults.websearch?.question || extractResults.websearch.question.length === 0) { + console.log('No valid question found in extractResults.websearch') + return { results: [] } + } - const questions = extractResults.websearch.question - const links = extractResults.websearch.links - const firstQuestion = questions[0] - - if (firstQuestion === 'summarize' && links && links.length > 0) { - const contents = await fetchWebContents(links, undefined, undefined, this.signal) - return { - query: 'summaries', - results: contents - } - } - const searchPromises = questions.map((q) => this.search(webSearchProvider, q)) - const searchResults = await Promise.allSettled(searchPromises) - const aggregatedResults: any[] = [] - - searchResults.forEach((result) => { - if (result.status === 'fulfilled') { - if (result.value.results) { - aggregatedResults.push(...result.value.results) - } - } + const questions = extractResults.websearch.question + const links = extractResults.websearch.links + const firstQuestion = questions[0] + if (firstQuestion === 'summarize' && links && links.length > 0) { + const contents = await fetchWebContents(links, undefined, undefined, { + signal: this.signal }) return { - query: questions.join(' | '), - results: aggregatedResults + query: 'summaries', + results: contents } - } catch (error) { - console.error('Failed to process enhanced search:', error) - return { results: [] } + } + const searchPromises = questions.map((q) => this.search(webSearchProvider, q, { signal: this.signal })) + const searchResults = await Promise.allSettled(searchPromises) + const aggregatedResults: any[] = [] + + searchResults.forEach((result) => { + if (result.status === 'fulfilled') { + if (result.value.results) { + aggregatedResults.push(...result.value.results) + } + } + if (result.status === 'rejected') { + throw result.reason + } + }) + return { + query: questions.join(' | '), + results: aggregatedResults } } } diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index 20de6d804b..8094a5ff29 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -251,7 +251,7 @@ const fetchAndProcessAssistantResponseImpl = async ( let mainTextBlockId: string | null = null const toolCallIdToBlockIdMap = new Map() - const handleBlockTransition = (newBlock: MessageBlock, newBlockType: MessageBlockType) => { + const handleBlockTransition = async (newBlock: MessageBlock, newBlockType: MessageBlockType) => { lastBlockId = newBlock.id lastBlockType = newBlockType if (newBlockType !== MessageBlockType.MAIN_TEXT) { @@ -279,9 +279,12 @@ const fetchAndProcessAssistantResponseImpl = async ( const currentState = getState() const updatedMessage = currentState.messages.entities[assistantMsgId] if (updatedMessage) { - saveUpdatesToDB(assistantMsgId, topicId, { blocks: updatedMessage.blocks, status: updatedMessage.status }, [ - newBlock - ]) + await saveUpdatesToDB( + assistantMsgId, + topicId, + { blocks: updatedMessage.blocks, status: updatedMessage.status }, + [newBlock] + ) } else { console.error(`[handleBlockTransition] Failed to get updated message ${assistantMsgId} from state for DB save.`) } @@ -530,10 +533,11 @@ const fetchAndProcessAssistantResponseImpl = async ( console.error('[onImageGenerated] Last block was not an Image block or ID is missing.') } }, - onError: (error) => { + onError: async (error) => { console.dir(error, { depth: null }) + const isErrorTypeAbort = isAbortError(error) let pauseErrorLanguagePlaceholder = '' - if (isAbortError(error)) { + if (isErrorTypeAbort) { pauseErrorLanguagePlaceholder = 'pause_placeholder' } @@ -548,16 +552,16 @@ const fetchAndProcessAssistantResponseImpl = async ( if (lastBlockId) { // 更改上一个block的状态为ERROR const changes: Partial = { - status: MessageBlockStatus.ERROR + status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR } dispatch(updateOneBlock({ id: lastBlockId, changes })) saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) } const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS }) - handleBlockTransition(errorBlock, MessageBlockType.ERROR) + await handleBlockTransition(errorBlock, MessageBlockType.ERROR) const messageErrorUpdate = { - status: isAbortError(error) ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR + status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR } dispatch(newMessagesActions.updateMessage({ topicId, messageId: assistantMsgId, updates: messageErrorUpdate })) @@ -566,7 +570,7 @@ const fetchAndProcessAssistantResponseImpl = async ( EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { id: assistantMsgId, topicId, - status: isAbortError(error) ? 'pause' : 'error', + status: isErrorTypeAbort ? 'pause' : 'error', error: error.message }) }, @@ -838,14 +842,18 @@ export const resendMessageThunk = (m) => m.askId === userMessageToResend.id && m.role === 'assistant' ) + const resetDataList: Message[] = [] + if (assistantMessagesToReset.length === 0) { - console.warn( - `[resendMessageThunk] No assistant responses found for user message ${userMessageToResend.id}. Nothing to regenerate.` - ) - return + // 没有用户消息,就创建一个 + const assistantMessage = createAssistantMessage(assistant.id, topicId, { + askId: userMessageToResend.id, + model: assistant.model + }) + resetDataList.push(assistantMessage) + dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage })) } - const resetDataList: { resetMsg: Message }[] = [] const allBlockIdsToDelete: string[] = [] const messagesToUpdateInRedux: { topicId: string; messageId: string; updates: Partial }[] = [] @@ -856,7 +864,7 @@ export const resendMessageThunk = ...(assistantMessagesToReset.length === 1 ? { model: assistant.model } : {}) }) - resetDataList.push({ resetMsg }) + resetDataList.push(resetMsg) allBlockIdsToDelete.push(...blockIdsToDelete) messagesToUpdateInRedux.push({ topicId, messageId: resetMsg.id, updates: resetMsg }) } @@ -877,7 +885,7 @@ export const resendMessageThunk = } const queue = getTopicQueue(topicId) - for (const { resetMsg } of resetDataList) { + for (const resetMsg of resetDataList) { const assistantConfigForThisRegen = { ...assistant, ...(resetMsg.model ? { model: resetMsg.model } : {}) @@ -1071,25 +1079,25 @@ export const initiateTranslationThunk = // --- Thunk to update the translation block with new content --- export const updateTranslationBlockThunk = (blockId: string, accumulatedText: string, isComplete: boolean = false) => - async (dispatch: AppDispatch) => { - console.log(`[updateTranslationBlockThunk] 更新翻译块 ${blockId}, isComplete: ${isComplete}`) - try { - const status = isComplete ? MessageBlockStatus.SUCCESS : MessageBlockStatus.STREAMING - const changes: Partial = { - content: accumulatedText, - status: status - } - - // 更新Redux状态 - dispatch(updateOneBlock({ id: blockId, changes })) - - // 更新数据库 - await db.message_blocks.update(blockId, changes) - console.log(`[updateTranslationBlockThunk] Successfully updated translation block ${blockId}.`) - } catch (error) { - console.error(`[updateTranslationBlockThunk] Failed to update translation block ${blockId}:`, error) + async (dispatch: AppDispatch) => { + console.log(`[updateTranslationBlockThunk] 更新翻译块 ${blockId}, isComplete: ${isComplete}`) + try { + const status = isComplete ? MessageBlockStatus.SUCCESS : MessageBlockStatus.STREAMING + const changes: Partial = { + content: accumulatedText, + status: status } + + // 更新Redux状态 + dispatch(updateOneBlock({ id: blockId, changes })) + + // 更新数据库 + await db.message_blocks.update(blockId, changes) + console.log(`[updateTranslationBlockThunk] Successfully updated translation block ${blockId}.`) + } catch (error) { + console.error(`[updateTranslationBlockThunk] Failed to update translation block ${blockId}:`, error) } + } /** * Thunk to append a new assistant response (using a potentially different model) diff --git a/src/renderer/src/utils/abortController.ts b/src/renderer/src/utils/abortController.ts index 8d3f59bd70..be2387626c 100644 --- a/src/renderer/src/utils/abortController.ts +++ b/src/renderer/src/utils/abortController.ts @@ -20,3 +20,23 @@ export const abortCompletion = (id: string) => { } } } + +export function createAbortPromise(signal: AbortSignal, finallyPromise: Promise) { + return new Promise((_resolve, reject) => { + if (signal.aborted) { + reject(new DOMException('Operation aborted', 'AbortError')) + return + } + + const abortHandler = (e: Event) => { + console.log('abortHandler', e) + reject(new DOMException('Operation aborted', 'AbortError')) + } + + signal.addEventListener('abort', abortHandler, { once: true }) + + finallyPromise.finally(() => { + signal.removeEventListener('abort', abortHandler) + }) + }) +} diff --git a/src/renderer/src/utils/fetch.ts b/src/renderer/src/utils/fetch.ts index a01b34027d..ff88973f31 100644 --- a/src/renderer/src/utils/fetch.ts +++ b/src/renderer/src/utils/fetch.ts @@ -1,6 +1,8 @@ import { Readability } from '@mozilla/readability' import { nanoid } from '@reduxjs/toolkit' import { WebSearchProviderResult } from '@renderer/types' +import { createAbortPromise } from '@renderer/utils/abortController' +import { isAbortError } from '@renderer/utils/error' import TurndownService from 'turndown' const turndownService = new TurndownService() @@ -24,10 +26,10 @@ export async function fetchWebContents( urls: string[], format: ResponseFormat = 'markdown', usingBrowser: boolean = false, - signal: AbortSignal | null = null + httpOptions: RequestInit = {} ): Promise { // parallel using fetchWebContent - const results = await Promise.allSettled(urls.map((url) => fetchWebContent(url, format, usingBrowser, signal))) + const results = await Promise.allSettled(urls.map((url) => fetchWebContent(url, format, usingBrowser, httpOptions))) return results.map((result, index) => { if (result.status === 'fulfilled') { return result.value @@ -45,7 +47,7 @@ export async function fetchWebContent( url: string, format: ResponseFormat = 'markdown', usingBrowser: boolean = false, - signal: AbortSignal | null = null + httpOptions: RequestInit = {} ): Promise { try { // Validate URL before attempting to fetch @@ -53,19 +55,29 @@ export async function fetchWebContent( throw new Error(`Invalid URL format: ${url}`) } - // const controller = new AbortController() - // const timeoutId = setTimeout(() => controller.abort(), 30000) // 30 second timeout - let html: string if (usingBrowser) { - html = await window.api.searchService.openUrlInSearchWindow(`search-window-${nanoid()}`, url) + const windowApiPromise = window.api.searchService.openUrlInSearchWindow(`search-window-${nanoid()}`, url) + + const promisesToRace: [Promise] = [windowApiPromise] + + if (httpOptions?.signal) { + const signal = httpOptions.signal + const abortPromise = createAbortPromise(signal, windowApiPromise) + promisesToRace.push(abortPromise) + } + + html = await Promise.race(promisesToRace) } else { const response = await fetch(url, { headers: { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36' }, - signal: signal ? AbortSignal.any([signal, AbortSignal.timeout(30000)]) : AbortSignal.timeout(30000) + ...httpOptions, + signal: httpOptions?.signal + ? AbortSignal.any([httpOptions.signal, AbortSignal.timeout(30000)]) + : AbortSignal.timeout(30000) }) if (!response.ok) { throw new Error(`HTTP error: ${response.status}`) @@ -102,6 +114,10 @@ export async function fetchWebContent( } } } catch (e: unknown) { + if (isAbortError(e)) { + throw e + } + console.error(`Failed to fetch ${url}`, e) return { title: url,