fix: enhance search functionality with optional HTTP options (#5765)

* feat: enhance search functionality with optional HTTP options

- Updated the search method signatures in BaseWebSearchProvider, WebSearchEngineProvider, LocalSearchProvider, and WebSearchService to accept optional HTTP options.
- Modified fetchWebContent and fetchWebContents to utilize the new HTTP options parameter for improved request handling.
- Enhanced error handling in messageThunk to manage abort errors more effectively.

* feat: implement abortable promises for web search and fetch operations

- Added createAbortPromise utility to handle abort signals for promises.
- Updated LocalSearchProvider and fetchWebContent to utilize abortable promises, allowing for better control over ongoing requests.
- Enhanced error handling in ApiService to log errors consistently.
This commit is contained in:
MyPrototypeWhat 2025-05-08 18:03:19 +08:00 committed by GitHub
parent 1f96b326b6
commit 9390b92ebb
8 changed files with 160 additions and 89 deletions

View File

@ -10,7 +10,11 @@ export default abstract class BaseWebSearchProvider {
this.provider = provider
this.apiKey = this.getApiKey()
}
abstract search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse>
abstract search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit
): Promise<WebSearchProviderResponse>
public getApiKey() {
const keys = this.provider.apiKey?.split(',').map((key) => key.trim()) || []

View File

@ -1,6 +1,8 @@
import { nanoid } from '@reduxjs/toolkit'
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse, WebSearchProviderResult } from '@renderer/types'
import { createAbortPromise } from '@renderer/utils/abortController'
import { isAbortError } from '@renderer/utils/error'
import { fetchWebContent, noContent } from '@renderer/utils/fetch'
import BaseWebSearchProvider from './BaseWebSearchProvider'
@ -18,7 +20,11 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
super(provider)
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
public async search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit
): Promise<WebSearchProviderResponse> {
const uid = nanoid()
try {
if (!query.trim()) {
@ -30,7 +36,13 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
const cleanedQuery = query.split('\r\n')[1] ?? query
const url = this.provider.url.replace('%s', encodeURIComponent(cleanedQuery))
const content = await window.api.searchService.openUrlInSearchWindow(uid, url)
let content: string = ''
const promisesToRace: [Promise<string>] = [window.api.searchService.openUrlInSearchWindow(uid, url)]
if (httpOptions?.signal) {
const abortPromise = createAbortPromise(httpOptions.signal, promisesToRace[0])
promisesToRace.push(abortPromise)
}
content = await Promise.race(promisesToRace)
// Parse the content to extract URLs and metadata
const searchItems = this.parseValidUrls(content).slice(0, websearch.maxResults)
@ -43,7 +55,7 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
// Fetch content for each URL concurrently
const fetchPromises = validItems.map(async (item) => {
// console.log(`Fetching content for ${item.url}...`)
const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser)
const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions)
if (websearch.contentLimit && result.content.length > websearch.contentLimit) {
result.content = result.content.slice(0, websearch.contentLimit) + '...'
}
@ -58,6 +70,9 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
results: results.filter((result) => result.content != noContent)
}
} catch (error) {
if (isAbortError(error)) {
throw error
}
console.error('Local search failed:', error)
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
} finally {

View File

@ -7,11 +7,16 @@ import WebSearchProviderFactory from './WebSearchProviderFactory'
export default class WebSearchEngineProvider {
private sdk: BaseWebSearchProvider
constructor(provider: WebSearchProvider) {
this.sdk = WebSearchProviderFactory.create(provider)
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
const result = await this.sdk.search(query, websearch)
public async search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit
): Promise<WebSearchProviderResponse> {
const result = await this.sdk.search(query, websearch, httpOptions)
const filteredResult = await filterResultWithBlacklist(result, websearch)
return filteredResult

View File

@ -145,8 +145,8 @@ async function fetchExternalTool(
source: WebSearchSource.WEBSEARCH
}
} catch (error) {
console.error('Web search failed:', error)
if (isAbortError(error)) throw error
console.error('Web search failed:', error)
return
}
}
@ -251,8 +251,8 @@ async function fetchExternalTool(
return { mcpTools }
} catch (error) {
console.error('Tool execution failed:', error)
if (isAbortError(error)) throw error
console.error('Tool execution failed:', error)
// 发送错误状态
if (willUseTools) {

View File

@ -97,7 +97,11 @@ class WebSearchService {
* @param query
* @returns
*/
public async search(provider: WebSearchProvider, query: string): Promise<WebSearchProviderResponse> {
public async search(
provider: WebSearchProvider,
query: string,
httpOptions?: RequestInit
): Promise<WebSearchProviderResponse> {
const websearch = this.getWebSearchState()
const webSearchEngine = new WebSearchEngineProvider(provider)
@ -107,12 +111,12 @@ class WebSearchService {
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
}
try {
return await webSearchEngine.search(formattedQuery, websearch)
} catch (error) {
console.error('Search failed:', error)
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
}
// try {
return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
// } catch (error) {
// console.error('Search failed:', error)
// throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
// }
}
/**
@ -136,42 +140,41 @@ class WebSearchService {
webSearchProvider: WebSearchProvider,
extractResults: ExtractResults
): Promise<WebSearchProviderResponse> {
try {
// 检查 websearch 和 question 是否有效
if (!extractResults.websearch?.question || extractResults.websearch.question.length === 0) {
console.log('No valid question found in extractResults.websearch')
return { results: [] }
}
// 检查 websearch 和 question 是否有效
if (!extractResults.websearch?.question || extractResults.websearch.question.length === 0) {
console.log('No valid question found in extractResults.websearch')
return { results: [] }
}
const questions = extractResults.websearch.question
const links = extractResults.websearch.links
const firstQuestion = questions[0]
if (firstQuestion === 'summarize' && links && links.length > 0) {
const contents = await fetchWebContents(links, undefined, undefined, this.signal)
return {
query: 'summaries',
results: contents
}
}
const searchPromises = questions.map((q) => this.search(webSearchProvider, q))
const searchResults = await Promise.allSettled(searchPromises)
const aggregatedResults: any[] = []
searchResults.forEach((result) => {
if (result.status === 'fulfilled') {
if (result.value.results) {
aggregatedResults.push(...result.value.results)
}
}
const questions = extractResults.websearch.question
const links = extractResults.websearch.links
const firstQuestion = questions[0]
if (firstQuestion === 'summarize' && links && links.length > 0) {
const contents = await fetchWebContents(links, undefined, undefined, {
signal: this.signal
})
return {
query: questions.join(' | '),
results: aggregatedResults
query: 'summaries',
results: contents
}
} catch (error) {
console.error('Failed to process enhanced search:', error)
return { results: [] }
}
const searchPromises = questions.map((q) => this.search(webSearchProvider, q, { signal: this.signal }))
const searchResults = await Promise.allSettled(searchPromises)
const aggregatedResults: any[] = []
searchResults.forEach((result) => {
if (result.status === 'fulfilled') {
if (result.value.results) {
aggregatedResults.push(...result.value.results)
}
}
if (result.status === 'rejected') {
throw result.reason
}
})
return {
query: questions.join(' | '),
results: aggregatedResults
}
}
}

View File

@ -251,7 +251,7 @@ const fetchAndProcessAssistantResponseImpl = async (
let mainTextBlockId: string | null = null
const toolCallIdToBlockIdMap = new Map<string, string>()
const handleBlockTransition = (newBlock: MessageBlock, newBlockType: MessageBlockType) => {
const handleBlockTransition = async (newBlock: MessageBlock, newBlockType: MessageBlockType) => {
lastBlockId = newBlock.id
lastBlockType = newBlockType
if (newBlockType !== MessageBlockType.MAIN_TEXT) {
@ -279,9 +279,12 @@ const fetchAndProcessAssistantResponseImpl = async (
const currentState = getState()
const updatedMessage = currentState.messages.entities[assistantMsgId]
if (updatedMessage) {
saveUpdatesToDB(assistantMsgId, topicId, { blocks: updatedMessage.blocks, status: updatedMessage.status }, [
newBlock
])
await saveUpdatesToDB(
assistantMsgId,
topicId,
{ blocks: updatedMessage.blocks, status: updatedMessage.status },
[newBlock]
)
} else {
console.error(`[handleBlockTransition] Failed to get updated message ${assistantMsgId} from state for DB save.`)
}
@ -530,10 +533,11 @@ const fetchAndProcessAssistantResponseImpl = async (
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
}
},
onError: (error) => {
onError: async (error) => {
console.dir(error, { depth: null })
const isErrorTypeAbort = isAbortError(error)
let pauseErrorLanguagePlaceholder = ''
if (isAbortError(error)) {
if (isErrorTypeAbort) {
pauseErrorLanguagePlaceholder = 'pause_placeholder'
}
@ -548,16 +552,16 @@ const fetchAndProcessAssistantResponseImpl = async (
if (lastBlockId) {
// 更改上一个block的状态为ERROR
const changes: Partial<MessageBlock> = {
status: MessageBlockStatus.ERROR
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
}
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
}
const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
handleBlockTransition(errorBlock, MessageBlockType.ERROR)
await handleBlockTransition(errorBlock, MessageBlockType.ERROR)
const messageErrorUpdate = {
status: isAbortError(error) ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
}
dispatch(newMessagesActions.updateMessage({ topicId, messageId: assistantMsgId, updates: messageErrorUpdate }))
@ -566,7 +570,7 @@ const fetchAndProcessAssistantResponseImpl = async (
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, {
id: assistantMsgId,
topicId,
status: isAbortError(error) ? 'pause' : 'error',
status: isErrorTypeAbort ? 'pause' : 'error',
error: error.message
})
},
@ -838,14 +842,18 @@ export const resendMessageThunk =
(m) => m.askId === userMessageToResend.id && m.role === 'assistant'
)
const resetDataList: Message[] = []
if (assistantMessagesToReset.length === 0) {
console.warn(
`[resendMessageThunk] No assistant responses found for user message ${userMessageToResend.id}. Nothing to regenerate.`
)
return
// 没有用户消息,就创建一个
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
askId: userMessageToResend.id,
model: assistant.model
})
resetDataList.push(assistantMessage)
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
}
const resetDataList: { resetMsg: Message }[] = []
const allBlockIdsToDelete: string[] = []
const messagesToUpdateInRedux: { topicId: string; messageId: string; updates: Partial<Message> }[] = []
@ -856,7 +864,7 @@ export const resendMessageThunk =
...(assistantMessagesToReset.length === 1 ? { model: assistant.model } : {})
})
resetDataList.push({ resetMsg })
resetDataList.push(resetMsg)
allBlockIdsToDelete.push(...blockIdsToDelete)
messagesToUpdateInRedux.push({ topicId, messageId: resetMsg.id, updates: resetMsg })
}
@ -877,7 +885,7 @@ export const resendMessageThunk =
}
const queue = getTopicQueue(topicId)
for (const { resetMsg } of resetDataList) {
for (const resetMsg of resetDataList) {
const assistantConfigForThisRegen = {
...assistant,
...(resetMsg.model ? { model: resetMsg.model } : {})
@ -1071,25 +1079,25 @@ export const initiateTranslationThunk =
// --- Thunk to update the translation block with new content ---
export const updateTranslationBlockThunk =
(blockId: string, accumulatedText: string, isComplete: boolean = false) =>
async (dispatch: AppDispatch) => {
console.log(`[updateTranslationBlockThunk] 更新翻译块 ${blockId}, isComplete: ${isComplete}`)
try {
const status = isComplete ? MessageBlockStatus.SUCCESS : MessageBlockStatus.STREAMING
const changes: Partial<MessageBlock> = {
content: accumulatedText,
status: status
}
// 更新Redux状态
dispatch(updateOneBlock({ id: blockId, changes }))
// 更新数据库
await db.message_blocks.update(blockId, changes)
console.log(`[updateTranslationBlockThunk] Successfully updated translation block ${blockId}.`)
} catch (error) {
console.error(`[updateTranslationBlockThunk] Failed to update translation block ${blockId}:`, error)
async (dispatch: AppDispatch) => {
console.log(`[updateTranslationBlockThunk] 更新翻译块 ${blockId}, isComplete: ${isComplete}`)
try {
const status = isComplete ? MessageBlockStatus.SUCCESS : MessageBlockStatus.STREAMING
const changes: Partial<MessageBlock> = {
content: accumulatedText,
status: status
}
// 更新Redux状态
dispatch(updateOneBlock({ id: blockId, changes }))
// 更新数据库
await db.message_blocks.update(blockId, changes)
console.log(`[updateTranslationBlockThunk] Successfully updated translation block ${blockId}.`)
} catch (error) {
console.error(`[updateTranslationBlockThunk] Failed to update translation block ${blockId}:`, error)
}
}
/**
* Thunk to append a new assistant response (using a potentially different model)

View File

@ -20,3 +20,23 @@ export const abortCompletion = (id: string) => {
}
}
}
export function createAbortPromise(signal: AbortSignal, finallyPromise: Promise<string>) {
return new Promise<string>((_resolve, reject) => {
if (signal.aborted) {
reject(new DOMException('Operation aborted', 'AbortError'))
return
}
const abortHandler = (e: Event) => {
console.log('abortHandler', e)
reject(new DOMException('Operation aborted', 'AbortError'))
}
signal.addEventListener('abort', abortHandler, { once: true })
finallyPromise.finally(() => {
signal.removeEventListener('abort', abortHandler)
})
})
}

View File

@ -1,6 +1,8 @@
import { Readability } from '@mozilla/readability'
import { nanoid } from '@reduxjs/toolkit'
import { WebSearchProviderResult } from '@renderer/types'
import { createAbortPromise } from '@renderer/utils/abortController'
import { isAbortError } from '@renderer/utils/error'
import TurndownService from 'turndown'
const turndownService = new TurndownService()
@ -24,10 +26,10 @@ export async function fetchWebContents(
urls: string[],
format: ResponseFormat = 'markdown',
usingBrowser: boolean = false,
signal: AbortSignal | null = null
httpOptions: RequestInit = {}
): Promise<WebSearchProviderResult[]> {
// parallel using fetchWebContent
const results = await Promise.allSettled(urls.map((url) => fetchWebContent(url, format, usingBrowser, signal)))
const results = await Promise.allSettled(urls.map((url) => fetchWebContent(url, format, usingBrowser, httpOptions)))
return results.map((result, index) => {
if (result.status === 'fulfilled') {
return result.value
@ -45,7 +47,7 @@ export async function fetchWebContent(
url: string,
format: ResponseFormat = 'markdown',
usingBrowser: boolean = false,
signal: AbortSignal | null = null
httpOptions: RequestInit = {}
): Promise<WebSearchProviderResult> {
try {
// Validate URL before attempting to fetch
@ -53,19 +55,29 @@ export async function fetchWebContent(
throw new Error(`Invalid URL format: ${url}`)
}
// const controller = new AbortController()
// const timeoutId = setTimeout(() => controller.abort(), 30000) // 30 second timeout
let html: string
if (usingBrowser) {
html = await window.api.searchService.openUrlInSearchWindow(`search-window-${nanoid()}`, url)
const windowApiPromise = window.api.searchService.openUrlInSearchWindow(`search-window-${nanoid()}`, url)
const promisesToRace: [Promise<string>] = [windowApiPromise]
if (httpOptions?.signal) {
const signal = httpOptions.signal
const abortPromise = createAbortPromise(signal, windowApiPromise)
promisesToRace.push(abortPromise)
}
html = await Promise.race(promisesToRace)
} else {
const response = await fetch(url, {
headers: {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
},
signal: signal ? AbortSignal.any([signal, AbortSignal.timeout(30000)]) : AbortSignal.timeout(30000)
...httpOptions,
signal: httpOptions?.signal
? AbortSignal.any([httpOptions.signal, AbortSignal.timeout(30000)])
: AbortSignal.timeout(30000)
})
if (!response.ok) {
throw new Error(`HTTP error: ${response.status}`)
@ -102,6 +114,10 @@ export async function fetchWebContent(
}
}
} catch (e: unknown) {
if (isAbortError(e)) {
throw e
}
console.error(`Failed to fetch ${url}`, e)
return {
title: url,