mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-26 11:44:28 +08:00
fix: enhance search functionality with optional HTTP options (#5765)
* feat: enhance search functionality with optional HTTP options - Updated the search method signatures in BaseWebSearchProvider, WebSearchEngineProvider, LocalSearchProvider, and WebSearchService to accept optional HTTP options. - Modified fetchWebContent and fetchWebContents to utilize the new HTTP options parameter for improved request handling. - Enhanced error handling in messageThunk to manage abort errors more effectively. * feat: implement abortable promises for web search and fetch operations - Added createAbortPromise utility to handle abort signals for promises. - Updated LocalSearchProvider and fetchWebContent to utilize abortable promises, allowing for better control over ongoing requests. - Enhanced error handling in ApiService to log errors consistently.
This commit is contained in:
parent
1f96b326b6
commit
9390b92ebb
@ -10,7 +10,11 @@ export default abstract class BaseWebSearchProvider {
|
||||
this.provider = provider
|
||||
this.apiKey = this.getApiKey()
|
||||
}
|
||||
abstract search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse>
|
||||
abstract search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse>
|
||||
|
||||
public getApiKey() {
|
||||
const keys = this.provider.apiKey?.split(',').map((key) => key.trim()) || []
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { WebSearchState } from '@renderer/store/websearch'
|
||||
import { WebSearchProvider, WebSearchProviderResponse, WebSearchProviderResult } from '@renderer/types'
|
||||
import { createAbortPromise } from '@renderer/utils/abortController'
|
||||
import { isAbortError } from '@renderer/utils/error'
|
||||
import { fetchWebContent, noContent } from '@renderer/utils/fetch'
|
||||
|
||||
import BaseWebSearchProvider from './BaseWebSearchProvider'
|
||||
@ -18,7 +20,11 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
public async search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
const uid = nanoid()
|
||||
try {
|
||||
if (!query.trim()) {
|
||||
@ -30,7 +36,13 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
|
||||
|
||||
const cleanedQuery = query.split('\r\n')[1] ?? query
|
||||
const url = this.provider.url.replace('%s', encodeURIComponent(cleanedQuery))
|
||||
const content = await window.api.searchService.openUrlInSearchWindow(uid, url)
|
||||
let content: string = ''
|
||||
const promisesToRace: [Promise<string>] = [window.api.searchService.openUrlInSearchWindow(uid, url)]
|
||||
if (httpOptions?.signal) {
|
||||
const abortPromise = createAbortPromise(httpOptions.signal, promisesToRace[0])
|
||||
promisesToRace.push(abortPromise)
|
||||
}
|
||||
content = await Promise.race(promisesToRace)
|
||||
|
||||
// Parse the content to extract URLs and metadata
|
||||
const searchItems = this.parseValidUrls(content).slice(0, websearch.maxResults)
|
||||
@ -43,7 +55,7 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
|
||||
// Fetch content for each URL concurrently
|
||||
const fetchPromises = validItems.map(async (item) => {
|
||||
// console.log(`Fetching content for ${item.url}...`)
|
||||
const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser)
|
||||
const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions)
|
||||
if (websearch.contentLimit && result.content.length > websearch.contentLimit) {
|
||||
result.content = result.content.slice(0, websearch.contentLimit) + '...'
|
||||
}
|
||||
@ -58,6 +70,9 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
|
||||
results: results.filter((result) => result.content != noContent)
|
||||
}
|
||||
} catch (error) {
|
||||
if (isAbortError(error)) {
|
||||
throw error
|
||||
}
|
||||
console.error('Local search failed:', error)
|
||||
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||
} finally {
|
||||
|
||||
@ -7,11 +7,16 @@ import WebSearchProviderFactory from './WebSearchProviderFactory'
|
||||
|
||||
export default class WebSearchEngineProvider {
|
||||
private sdk: BaseWebSearchProvider
|
||||
|
||||
constructor(provider: WebSearchProvider) {
|
||||
this.sdk = WebSearchProviderFactory.create(provider)
|
||||
}
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
const result = await this.sdk.search(query, websearch)
|
||||
public async search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
const result = await this.sdk.search(query, websearch, httpOptions)
|
||||
const filteredResult = await filterResultWithBlacklist(result, websearch)
|
||||
|
||||
return filteredResult
|
||||
|
||||
@ -145,8 +145,8 @@ async function fetchExternalTool(
|
||||
source: WebSearchSource.WEBSEARCH
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Web search failed:', error)
|
||||
if (isAbortError(error)) throw error
|
||||
console.error('Web search failed:', error)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -251,8 +251,8 @@ async function fetchExternalTool(
|
||||
|
||||
return { mcpTools }
|
||||
} catch (error) {
|
||||
console.error('Tool execution failed:', error)
|
||||
if (isAbortError(error)) throw error
|
||||
console.error('Tool execution failed:', error)
|
||||
|
||||
// 发送错误状态
|
||||
if (willUseTools) {
|
||||
|
||||
@ -97,7 +97,11 @@ class WebSearchService {
|
||||
* @param query 搜索查询
|
||||
* @returns 搜索响应
|
||||
*/
|
||||
public async search(provider: WebSearchProvider, query: string): Promise<WebSearchProviderResponse> {
|
||||
public async search(
|
||||
provider: WebSearchProvider,
|
||||
query: string,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
const websearch = this.getWebSearchState()
|
||||
const webSearchEngine = new WebSearchEngineProvider(provider)
|
||||
|
||||
@ -107,12 +111,12 @@ class WebSearchService {
|
||||
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
|
||||
}
|
||||
|
||||
try {
|
||||
return await webSearchEngine.search(formattedQuery, websearch)
|
||||
} catch (error) {
|
||||
console.error('Search failed:', error)
|
||||
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||
}
|
||||
// try {
|
||||
return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
|
||||
// } catch (error) {
|
||||
// console.error('Search failed:', error)
|
||||
// throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -136,42 +140,41 @@ class WebSearchService {
|
||||
webSearchProvider: WebSearchProvider,
|
||||
extractResults: ExtractResults
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
try {
|
||||
// 检查 websearch 和 question 是否有效
|
||||
if (!extractResults.websearch?.question || extractResults.websearch.question.length === 0) {
|
||||
console.log('No valid question found in extractResults.websearch')
|
||||
return { results: [] }
|
||||
}
|
||||
// 检查 websearch 和 question 是否有效
|
||||
if (!extractResults.websearch?.question || extractResults.websearch.question.length === 0) {
|
||||
console.log('No valid question found in extractResults.websearch')
|
||||
return { results: [] }
|
||||
}
|
||||
|
||||
const questions = extractResults.websearch.question
|
||||
const links = extractResults.websearch.links
|
||||
const firstQuestion = questions[0]
|
||||
|
||||
if (firstQuestion === 'summarize' && links && links.length > 0) {
|
||||
const contents = await fetchWebContents(links, undefined, undefined, this.signal)
|
||||
return {
|
||||
query: 'summaries',
|
||||
results: contents
|
||||
}
|
||||
}
|
||||
const searchPromises = questions.map((q) => this.search(webSearchProvider, q))
|
||||
const searchResults = await Promise.allSettled(searchPromises)
|
||||
const aggregatedResults: any[] = []
|
||||
|
||||
searchResults.forEach((result) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
if (result.value.results) {
|
||||
aggregatedResults.push(...result.value.results)
|
||||
}
|
||||
}
|
||||
const questions = extractResults.websearch.question
|
||||
const links = extractResults.websearch.links
|
||||
const firstQuestion = questions[0]
|
||||
if (firstQuestion === 'summarize' && links && links.length > 0) {
|
||||
const contents = await fetchWebContents(links, undefined, undefined, {
|
||||
signal: this.signal
|
||||
})
|
||||
return {
|
||||
query: questions.join(' | '),
|
||||
results: aggregatedResults
|
||||
query: 'summaries',
|
||||
results: contents
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to process enhanced search:', error)
|
||||
return { results: [] }
|
||||
}
|
||||
const searchPromises = questions.map((q) => this.search(webSearchProvider, q, { signal: this.signal }))
|
||||
const searchResults = await Promise.allSettled(searchPromises)
|
||||
const aggregatedResults: any[] = []
|
||||
|
||||
searchResults.forEach((result) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
if (result.value.results) {
|
||||
aggregatedResults.push(...result.value.results)
|
||||
}
|
||||
}
|
||||
if (result.status === 'rejected') {
|
||||
throw result.reason
|
||||
}
|
||||
})
|
||||
return {
|
||||
query: questions.join(' | '),
|
||||
results: aggregatedResults
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -251,7 +251,7 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
let mainTextBlockId: string | null = null
|
||||
const toolCallIdToBlockIdMap = new Map<string, string>()
|
||||
|
||||
const handleBlockTransition = (newBlock: MessageBlock, newBlockType: MessageBlockType) => {
|
||||
const handleBlockTransition = async (newBlock: MessageBlock, newBlockType: MessageBlockType) => {
|
||||
lastBlockId = newBlock.id
|
||||
lastBlockType = newBlockType
|
||||
if (newBlockType !== MessageBlockType.MAIN_TEXT) {
|
||||
@ -279,9 +279,12 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
const currentState = getState()
|
||||
const updatedMessage = currentState.messages.entities[assistantMsgId]
|
||||
if (updatedMessage) {
|
||||
saveUpdatesToDB(assistantMsgId, topicId, { blocks: updatedMessage.blocks, status: updatedMessage.status }, [
|
||||
newBlock
|
||||
])
|
||||
await saveUpdatesToDB(
|
||||
assistantMsgId,
|
||||
topicId,
|
||||
{ blocks: updatedMessage.blocks, status: updatedMessage.status },
|
||||
[newBlock]
|
||||
)
|
||||
} else {
|
||||
console.error(`[handleBlockTransition] Failed to get updated message ${assistantMsgId} from state for DB save.`)
|
||||
}
|
||||
@ -530,10 +533,11 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
|
||||
}
|
||||
},
|
||||
onError: (error) => {
|
||||
onError: async (error) => {
|
||||
console.dir(error, { depth: null })
|
||||
const isErrorTypeAbort = isAbortError(error)
|
||||
let pauseErrorLanguagePlaceholder = ''
|
||||
if (isAbortError(error)) {
|
||||
if (isErrorTypeAbort) {
|
||||
pauseErrorLanguagePlaceholder = 'pause_placeholder'
|
||||
}
|
||||
|
||||
@ -548,16 +552,16 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
if (lastBlockId) {
|
||||
// 更改上一个block的状态为ERROR
|
||||
const changes: Partial<MessageBlock> = {
|
||||
status: MessageBlockStatus.ERROR
|
||||
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
|
||||
}
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
}
|
||||
|
||||
const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
|
||||
handleBlockTransition(errorBlock, MessageBlockType.ERROR)
|
||||
await handleBlockTransition(errorBlock, MessageBlockType.ERROR)
|
||||
const messageErrorUpdate = {
|
||||
status: isAbortError(error) ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
|
||||
status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
|
||||
}
|
||||
dispatch(newMessagesActions.updateMessage({ topicId, messageId: assistantMsgId, updates: messageErrorUpdate }))
|
||||
|
||||
@ -566,7 +570,7 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, {
|
||||
id: assistantMsgId,
|
||||
topicId,
|
||||
status: isAbortError(error) ? 'pause' : 'error',
|
||||
status: isErrorTypeAbort ? 'pause' : 'error',
|
||||
error: error.message
|
||||
})
|
||||
},
|
||||
@ -838,14 +842,18 @@ export const resendMessageThunk =
|
||||
(m) => m.askId === userMessageToResend.id && m.role === 'assistant'
|
||||
)
|
||||
|
||||
const resetDataList: Message[] = []
|
||||
|
||||
if (assistantMessagesToReset.length === 0) {
|
||||
console.warn(
|
||||
`[resendMessageThunk] No assistant responses found for user message ${userMessageToResend.id}. Nothing to regenerate.`
|
||||
)
|
||||
return
|
||||
// 没有用户消息,就创建一个
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: userMessageToResend.id,
|
||||
model: assistant.model
|
||||
})
|
||||
resetDataList.push(assistantMessage)
|
||||
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
||||
}
|
||||
|
||||
const resetDataList: { resetMsg: Message }[] = []
|
||||
const allBlockIdsToDelete: string[] = []
|
||||
const messagesToUpdateInRedux: { topicId: string; messageId: string; updates: Partial<Message> }[] = []
|
||||
|
||||
@ -856,7 +864,7 @@ export const resendMessageThunk =
|
||||
...(assistantMessagesToReset.length === 1 ? { model: assistant.model } : {})
|
||||
})
|
||||
|
||||
resetDataList.push({ resetMsg })
|
||||
resetDataList.push(resetMsg)
|
||||
allBlockIdsToDelete.push(...blockIdsToDelete)
|
||||
messagesToUpdateInRedux.push({ topicId, messageId: resetMsg.id, updates: resetMsg })
|
||||
}
|
||||
@ -877,7 +885,7 @@ export const resendMessageThunk =
|
||||
}
|
||||
|
||||
const queue = getTopicQueue(topicId)
|
||||
for (const { resetMsg } of resetDataList) {
|
||||
for (const resetMsg of resetDataList) {
|
||||
const assistantConfigForThisRegen = {
|
||||
...assistant,
|
||||
...(resetMsg.model ? { model: resetMsg.model } : {})
|
||||
@ -1071,25 +1079,25 @@ export const initiateTranslationThunk =
|
||||
// --- Thunk to update the translation block with new content ---
|
||||
export const updateTranslationBlockThunk =
|
||||
(blockId: string, accumulatedText: string, isComplete: boolean = false) =>
|
||||
async (dispatch: AppDispatch) => {
|
||||
console.log(`[updateTranslationBlockThunk] 更新翻译块 ${blockId}, isComplete: ${isComplete}`)
|
||||
try {
|
||||
const status = isComplete ? MessageBlockStatus.SUCCESS : MessageBlockStatus.STREAMING
|
||||
const changes: Partial<MessageBlock> = {
|
||||
content: accumulatedText,
|
||||
status: status
|
||||
}
|
||||
|
||||
// 更新Redux状态
|
||||
dispatch(updateOneBlock({ id: blockId, changes }))
|
||||
|
||||
// 更新数据库
|
||||
await db.message_blocks.update(blockId, changes)
|
||||
console.log(`[updateTranslationBlockThunk] Successfully updated translation block ${blockId}.`)
|
||||
} catch (error) {
|
||||
console.error(`[updateTranslationBlockThunk] Failed to update translation block ${blockId}:`, error)
|
||||
async (dispatch: AppDispatch) => {
|
||||
console.log(`[updateTranslationBlockThunk] 更新翻译块 ${blockId}, isComplete: ${isComplete}`)
|
||||
try {
|
||||
const status = isComplete ? MessageBlockStatus.SUCCESS : MessageBlockStatus.STREAMING
|
||||
const changes: Partial<MessageBlock> = {
|
||||
content: accumulatedText,
|
||||
status: status
|
||||
}
|
||||
|
||||
// 更新Redux状态
|
||||
dispatch(updateOneBlock({ id: blockId, changes }))
|
||||
|
||||
// 更新数据库
|
||||
await db.message_blocks.update(blockId, changes)
|
||||
console.log(`[updateTranslationBlockThunk] Successfully updated translation block ${blockId}.`)
|
||||
} catch (error) {
|
||||
console.error(`[updateTranslationBlockThunk] Failed to update translation block ${blockId}:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Thunk to append a new assistant response (using a potentially different model)
|
||||
|
||||
@ -20,3 +20,23 @@ export const abortCompletion = (id: string) => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function createAbortPromise(signal: AbortSignal, finallyPromise: Promise<string>) {
|
||||
return new Promise<string>((_resolve, reject) => {
|
||||
if (signal.aborted) {
|
||||
reject(new DOMException('Operation aborted', 'AbortError'))
|
||||
return
|
||||
}
|
||||
|
||||
const abortHandler = (e: Event) => {
|
||||
console.log('abortHandler', e)
|
||||
reject(new DOMException('Operation aborted', 'AbortError'))
|
||||
}
|
||||
|
||||
signal.addEventListener('abort', abortHandler, { once: true })
|
||||
|
||||
finallyPromise.finally(() => {
|
||||
signal.removeEventListener('abort', abortHandler)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import { Readability } from '@mozilla/readability'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { WebSearchProviderResult } from '@renderer/types'
|
||||
import { createAbortPromise } from '@renderer/utils/abortController'
|
||||
import { isAbortError } from '@renderer/utils/error'
|
||||
import TurndownService from 'turndown'
|
||||
|
||||
const turndownService = new TurndownService()
|
||||
@ -24,10 +26,10 @@ export async function fetchWebContents(
|
||||
urls: string[],
|
||||
format: ResponseFormat = 'markdown',
|
||||
usingBrowser: boolean = false,
|
||||
signal: AbortSignal | null = null
|
||||
httpOptions: RequestInit = {}
|
||||
): Promise<WebSearchProviderResult[]> {
|
||||
// parallel using fetchWebContent
|
||||
const results = await Promise.allSettled(urls.map((url) => fetchWebContent(url, format, usingBrowser, signal)))
|
||||
const results = await Promise.allSettled(urls.map((url) => fetchWebContent(url, format, usingBrowser, httpOptions)))
|
||||
return results.map((result, index) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
return result.value
|
||||
@ -45,7 +47,7 @@ export async function fetchWebContent(
|
||||
url: string,
|
||||
format: ResponseFormat = 'markdown',
|
||||
usingBrowser: boolean = false,
|
||||
signal: AbortSignal | null = null
|
||||
httpOptions: RequestInit = {}
|
||||
): Promise<WebSearchProviderResult> {
|
||||
try {
|
||||
// Validate URL before attempting to fetch
|
||||
@ -53,19 +55,29 @@ export async function fetchWebContent(
|
||||
throw new Error(`Invalid URL format: ${url}`)
|
||||
}
|
||||
|
||||
// const controller = new AbortController()
|
||||
// const timeoutId = setTimeout(() => controller.abort(), 30000) // 30 second timeout
|
||||
|
||||
let html: string
|
||||
if (usingBrowser) {
|
||||
html = await window.api.searchService.openUrlInSearchWindow(`search-window-${nanoid()}`, url)
|
||||
const windowApiPromise = window.api.searchService.openUrlInSearchWindow(`search-window-${nanoid()}`, url)
|
||||
|
||||
const promisesToRace: [Promise<string>] = [windowApiPromise]
|
||||
|
||||
if (httpOptions?.signal) {
|
||||
const signal = httpOptions.signal
|
||||
const abortPromise = createAbortPromise(signal, windowApiPromise)
|
||||
promisesToRace.push(abortPromise)
|
||||
}
|
||||
|
||||
html = await Promise.race(promisesToRace)
|
||||
} else {
|
||||
const response = await fetch(url, {
|
||||
headers: {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
|
||||
},
|
||||
signal: signal ? AbortSignal.any([signal, AbortSignal.timeout(30000)]) : AbortSignal.timeout(30000)
|
||||
...httpOptions,
|
||||
signal: httpOptions?.signal
|
||||
? AbortSignal.any([httpOptions.signal, AbortSignal.timeout(30000)])
|
||||
: AbortSignal.timeout(30000)
|
||||
})
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error: ${response.status}`)
|
||||
@ -102,6 +114,10 @@ export async function fetchWebContent(
|
||||
}
|
||||
}
|
||||
} catch (e: unknown) {
|
||||
if (isAbortError(e)) {
|
||||
throw e
|
||||
}
|
||||
|
||||
console.error(`Failed to fetch ${url}`, e)
|
||||
return {
|
||||
title: url,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user