mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-23 18:10:26 +08:00
Merge branch 'main' into fix/next-release-bugs
This commit is contained in:
commit
f4963888e8
15
README.md
15
README.md
@ -19,7 +19,7 @@ Cherry Studio is a desktop client that supports for multiple LLM providers, avai
|
||||
|
||||
# 📖 Guide
|
||||
|
||||
https://docs.cherry-ai.com
|
||||
<https://docs.cherry-ai.com>
|
||||
|
||||
# 🌠 Screenshot
|
||||
|
||||
@ -82,17 +82,18 @@ https://docs.cherry-ai.com
|
||||
|
||||
# 🌈 Theme
|
||||
|
||||
- Theme Gallery: https://cherrycss.com
|
||||
- Aero Theme: https://github.com/hakadao/CherryStudio-Aero
|
||||
- PaperMaterial Theme: https://github.com/rainoffallingstar/CherryStudio-PaperMaterial
|
||||
- Claude dynamic-style: https://github.com/bjl101501/CherryStudio-Claudestyle-dynamic
|
||||
- Maple Neon Theme: https://github.com/BoningtonChen/CherryStudio_themes
|
||||
- Theme Gallery: <https://cherrycss.com>
|
||||
- Aero Theme: <https://github.com/hakadao/CherryStudio-Aero>
|
||||
- PaperMaterial Theme: <https://github.com/rainoffallingstar/CherryStudio-PaperMaterial>
|
||||
- Claude dynamic-style: <https://github.com/bjl101501/CherryStudio-Claudestyle-dynamic>
|
||||
- Maple Neon Theme: <https://github.com/BoningtonChen/CherryStudio_themes>
|
||||
|
||||
Welcome PR for more themes
|
||||
|
||||
# 🖥️ Develop
|
||||
|
||||
Refer to the [development documentation](docs/dev.md)
|
||||
Refer to the [Architecture overview documentation](https://deepwiki.com/CherryHQ/cherry-studio)
|
||||
|
||||
# 🤝 Contributing
|
||||
|
||||
@ -144,7 +145,7 @@ Thank you for your support and contributions!
|
||||
|
||||
# ✉️ Contact
|
||||
|
||||
yinsenho@cherry-ai.com
|
||||
<yinsenho@cherry-ai.com>
|
||||
|
||||
# ⭐️ Star History
|
||||
|
||||
|
||||
@ -17,6 +17,10 @@ export default abstract class BaseReranker {
|
||||
* Get Rerank Request Url
|
||||
*/
|
||||
protected getRerankUrl() {
|
||||
if (this.base.rerankModelProvider === 'dashscope') {
|
||||
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
}
|
||||
|
||||
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
||||
? this.base.rerankBaseURL.slice(0, -1)
|
||||
: this.base.rerankBaseURL
|
||||
@ -28,6 +32,56 @@ export default abstract class BaseReranker {
|
||||
return `${baseURL}/rerank`
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Rerank Request Body
|
||||
*/
|
||||
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
|
||||
const provider = this.base.rerankModelProvider
|
||||
const documents = searchResults.map((doc) => doc.pageContent)
|
||||
const topN = this.base.topN || 5
|
||||
|
||||
if (provider === 'voyageai') {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents,
|
||||
top_k: topN
|
||||
}
|
||||
} else if (provider === 'dashscope') {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
input: {
|
||||
query,
|
||||
documents
|
||||
},
|
||||
parameters: {
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents,
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract Rerank Result
|
||||
*/
|
||||
protected extractRerankResult(data: any) {
|
||||
const provider = this.base.rerankModelProvider
|
||||
if (provider === 'dashscope') {
|
||||
return data.output.results
|
||||
} else if (provider === 'voyageai') {
|
||||
return data.data
|
||||
} else {
|
||||
return data.results
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Rerank Result
|
||||
* @param searchResults
|
||||
|
||||
@ -1,58 +0,0 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import axiosProxy from '@main/services/AxiosProxy'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
interface DashscopeRerankResultItem {
|
||||
document: {
|
||||
text: string
|
||||
}
|
||||
index: number
|
||||
relevance_score: number
|
||||
}
|
||||
|
||||
interface DashscopeRerankResponse {
|
||||
output: {
|
||||
results: DashscopeRerankResultItem[]
|
||||
}
|
||||
usage: {
|
||||
total_tokens: number
|
||||
}
|
||||
request_id: string
|
||||
}
|
||||
|
||||
export default class DashscopeReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
input: {
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent)
|
||||
},
|
||||
parameters: {
|
||||
return_documents: true, // Recommended to be true to get document details if needed, though scores are primary
|
||||
top_n: this.base.topN || 5 // Default to 5 if topN is not specified, as per API example
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axiosProxy.axios.post<DashscopeRerankResponse>(url, requestBody, {
|
||||
headers: this.defaultHeaders()
|
||||
})
|
||||
|
||||
const rerankResults = data.output.results
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
console.error('Dashscope Reranker API 错误:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,14 +0,0 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class DefaultReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
async rerank(): Promise<ExtractChunkData[]> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
}
|
||||
@ -4,7 +4,7 @@ import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class JinaReranker extends BaseReranker {
|
||||
export default class GeneralReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
@ -12,21 +12,15 @@ export default class JinaReranker extends BaseReranker {
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = this.getRerankUrl()
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_n: this.base.topN
|
||||
}
|
||||
const requestBody = this.getRerankRequestBody(query, searchResults)
|
||||
|
||||
try {
|
||||
const { data } = await AxiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||
|
||||
const rerankResults = data.results
|
||||
const rerankResults = this.extractRerankResult(data)
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
console.error('Jina Reranker API Error:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
@ -1,13 +1,12 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import RerankerFactory from './RerankerFactory'
|
||||
import GeneralReranker from './GeneralReranker'
|
||||
|
||||
export default class Reranker {
|
||||
private sdk: BaseReranker
|
||||
private sdk: GeneralReranker
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
this.sdk = RerankerFactory.create(base)
|
||||
this.sdk = new GeneralReranker(base)
|
||||
}
|
||||
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
|
||||
return this.sdk.rerank(query, searchResults)
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import DashscopeReranker from './DashscopeReranker'
|
||||
import DefaultReranker from './DefaultReranker'
|
||||
import JinaReranker from './JinaReranker'
|
||||
import SiliconFlowReranker from './SiliconFlowReranker'
|
||||
import VoyageReranker from './VoyageReranker'
|
||||
|
||||
export default class RerankerFactory {
|
||||
static create(base: KnowledgeBaseParams): BaseReranker {
|
||||
if (base.rerankModelProvider === 'silicon') {
|
||||
return new SiliconFlowReranker(base)
|
||||
} else if (base.rerankModelProvider === 'jina') {
|
||||
return new JinaReranker(base)
|
||||
} else if (base.rerankModelProvider === 'voyageai') {
|
||||
return new VoyageReranker(base)
|
||||
} else if (base.rerankModelProvider === 'dashscope') {
|
||||
return new DashscopeReranker(base)
|
||||
}
|
||||
return new DefaultReranker(base)
|
||||
}
|
||||
}
|
||||
@ -1,36 +0,0 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import axiosProxy from '@main/services/AxiosProxy'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class SiliconFlowReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = this.getRerankUrl()
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_n: this.base.topN,
|
||||
max_chunks_per_doc: this.base.chunkSize,
|
||||
overlap_tokens: this.base.chunkOverlap
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||
|
||||
const rerankResults = data.results
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
|
||||
console.error('SiliconFlow Reranker API 错误:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,40 +0,0 @@
|
||||
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import axiosProxy from '@main/services/AxiosProxy'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class VoyageReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = this.getRerankUrl()
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_k: this.base.topN,
|
||||
return_documents: false,
|
||||
truncation: true
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axiosProxy.axios.post(url, requestBody, {
|
||||
headers: {
|
||||
...this.defaultHeaders()
|
||||
}
|
||||
})
|
||||
|
||||
const rerankResults = data.data
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
|
||||
console.error('Voyage Reranker API Error:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
BIN
src/renderer/src/assets/images/search/bocha.webp
Normal file
BIN
src/renderer/src/assets/images/search/bocha.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.7 KiB |
41
src/renderer/src/components/Popups/SelectModelPopup/hook.ts
Normal file
41
src/renderer/src/components/Popups/SelectModelPopup/hook.ts
Normal file
@ -0,0 +1,41 @@
|
||||
import { useMemo, useReducer } from 'react'
|
||||
|
||||
import { initialScrollState, scrollReducer } from './reducer'
|
||||
import { FlatListItem, ScrollTrigger } from './types'
|
||||
|
||||
/**
|
||||
* 管理滚动和焦点状态的 hook
|
||||
*/
|
||||
export function useScrollState() {
|
||||
const [state, dispatch] = useReducer(scrollReducer, initialScrollState)
|
||||
|
||||
const actions = useMemo(
|
||||
() => ({
|
||||
setFocusedItemKey: (key: string) => dispatch({ type: 'SET_FOCUSED_ITEM_KEY', payload: key }),
|
||||
setScrollTrigger: (trigger: ScrollTrigger) => dispatch({ type: 'SET_SCROLL_TRIGGER', payload: trigger }),
|
||||
setLastScrollOffset: (offset: number) => dispatch({ type: 'SET_LAST_SCROLL_OFFSET', payload: offset }),
|
||||
setStickyGroup: (group: FlatListItem | null) => dispatch({ type: 'SET_STICKY_GROUP', payload: group }),
|
||||
setIsMouseOver: (isMouseOver: boolean) => dispatch({ type: 'SET_IS_MOUSE_OVER', payload: isMouseOver }),
|
||||
focusNextItem: (modelItems: FlatListItem[], step: number) =>
|
||||
dispatch({ type: 'FOCUS_NEXT_ITEM', payload: { modelItems, step } }),
|
||||
focusPage: (modelItems: FlatListItem[], currentIndex: number, step: number) =>
|
||||
dispatch({ type: 'FOCUS_PAGE', payload: { modelItems, currentIndex, step } }),
|
||||
searchChanged: (searchText: string) => dispatch({ type: 'SEARCH_CHANGED', payload: { searchText } }),
|
||||
updateOnListChange: (modelItems: FlatListItem[]) =>
|
||||
dispatch({ type: 'UPDATE_ON_LIST_CHANGE', payload: { modelItems } }),
|
||||
initScroll: () => dispatch({ type: 'INIT_SCROLL' })
|
||||
}),
|
||||
[]
|
||||
)
|
||||
|
||||
return {
|
||||
// 状态
|
||||
focusedItemKey: state.focusedItemKey,
|
||||
scrollTrigger: state.scrollTrigger,
|
||||
lastScrollOffset: state.lastScrollOffset,
|
||||
stickyGroup: state.stickyGroup,
|
||||
isMouseOver: state.isMouseOver,
|
||||
// 操作
|
||||
...actions
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
import { SelectModelPopup } from './popup'
|
||||
|
||||
export default SelectModelPopup
|
||||
@ -1,7 +1,9 @@
|
||||
import { PushpinOutlined } from '@ant-design/icons'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import ModelTagsWithLabel from '@renderer/components/ModelTagsWithLabel'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { getModelLogo, isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import db from '@renderer/databases'
|
||||
import { usePinnedModels } from '@renderer/hooks/usePinnedModels'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { Model } from '@renderer/types'
|
||||
@ -15,101 +17,61 @@ import { useTranslation } from 'react-i18next'
|
||||
import { FixedSizeList } from 'react-window'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import { HStack } from '../Layout'
|
||||
import ModelTagsWithLabel from '../ModelTagsWithLabel'
|
||||
import { useScrollState } from './hook'
|
||||
import { FlatListItem } from './types'
|
||||
|
||||
const PAGE_SIZE = 9
|
||||
const ITEM_HEIGHT = 36
|
||||
|
||||
// 列表项类型,组名也作为列表项
|
||||
type ListItemType = 'group' | 'model'
|
||||
|
||||
// 滚动触发来源类型
|
||||
type ScrollTrigger = 'initial' | 'search' | 'keyboard' | 'none'
|
||||
|
||||
// 扁平化列表项接口
|
||||
interface FlatListItem {
|
||||
key: string
|
||||
type: ListItemType
|
||||
icon?: React.ReactNode
|
||||
name: React.ReactNode
|
||||
tags?: React.ReactNode
|
||||
model?: Model
|
||||
isPinned?: boolean
|
||||
isSelected?: boolean
|
||||
}
|
||||
|
||||
interface Props {
|
||||
interface PopupParams {
|
||||
model?: Model
|
||||
}
|
||||
|
||||
interface PopupContainerProps extends Props {
|
||||
interface Props extends PopupParams {
|
||||
resolve: (value: Model | undefined) => void
|
||||
}
|
||||
|
||||
const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
const PopupContainer: React.FC<Props> = ({ model, resolve }) => {
|
||||
const { t } = useTranslation()
|
||||
const { providers } = useProviders()
|
||||
const { pinnedModels, togglePinnedModel, loading: loadingPinnedModels } = usePinnedModels()
|
||||
const [open, setOpen] = useState(true)
|
||||
const inputRef = useRef<InputRef>(null)
|
||||
const listRef = useRef<FixedSizeList>(null)
|
||||
const [_searchText, setSearchText] = useState('')
|
||||
const searchText = useDeferredValue(_searchText)
|
||||
const [isMouseOver, setIsMouseOver] = useState(false)
|
||||
const [pinnedModels, setPinnedModels] = useState<string[]>([])
|
||||
const [_focusedItemKey, setFocusedItemKey] = useState<string>('')
|
||||
const focusedItemKey = useDeferredValue(_focusedItemKey)
|
||||
const [_stickyGroup, setStickyGroup] = useState<FlatListItem | null>(null)
|
||||
const stickyGroup = useDeferredValue(_stickyGroup)
|
||||
const firstGroupRef = useRef<FlatListItem | null>(null)
|
||||
const scrollTriggerRef = useRef<ScrollTrigger>('initial')
|
||||
const lastScrollOffsetRef = useRef(0)
|
||||
|
||||
// 当前选中的模型ID
|
||||
const currentModelId = model ? getModelUniqId(model) : ''
|
||||
|
||||
// 加载置顶模型列表
|
||||
useEffect(() => {
|
||||
const loadPinnedModels = async () => {
|
||||
const setting = await db.settings.get('pinned:models')
|
||||
const savedPinnedModels = setting?.value || []
|
||||
// 管理滚动和焦点状态
|
||||
const {
|
||||
focusedItemKey,
|
||||
scrollTrigger,
|
||||
lastScrollOffset,
|
||||
stickyGroup: _stickyGroup,
|
||||
isMouseOver,
|
||||
setFocusedItemKey,
|
||||
setScrollTrigger,
|
||||
setLastScrollOffset,
|
||||
setStickyGroup,
|
||||
setIsMouseOver,
|
||||
focusNextItem,
|
||||
focusPage,
|
||||
searchChanged,
|
||||
updateOnListChange,
|
||||
initScroll
|
||||
} = useScrollState()
|
||||
|
||||
// Filter out invalid pinned models
|
||||
const allModelIds = providers.flatMap((p) => p.models || []).map((m) => getModelUniqId(m))
|
||||
const validPinnedModels = savedPinnedModels.filter((id) => allModelIds.includes(id))
|
||||
|
||||
// Update storage if there were invalid models
|
||||
if (validPinnedModels.length !== savedPinnedModels.length) {
|
||||
await db.settings.put({ id: 'pinned:models', value: validPinnedModels })
|
||||
}
|
||||
|
||||
setPinnedModels(sortBy(validPinnedModels))
|
||||
}
|
||||
|
||||
try {
|
||||
loadPinnedModels()
|
||||
} catch (error) {
|
||||
console.error('Failed to load pinned models', error)
|
||||
setPinnedModels([])
|
||||
}
|
||||
}, [providers])
|
||||
const stickyGroup = useDeferredValue(_stickyGroup)
|
||||
const firstGroupRef = useRef<FlatListItem | null>(null)
|
||||
|
||||
const togglePin = useCallback(
|
||||
async (modelId: string) => {
|
||||
const newPinnedModels = pinnedModels.includes(modelId)
|
||||
? pinnedModels.filter((id) => id !== modelId)
|
||||
: [...pinnedModels, modelId]
|
||||
|
||||
try {
|
||||
await db.settings.put({ id: 'pinned:models', value: newPinnedModels })
|
||||
setPinnedModels(sortBy(newPinnedModels))
|
||||
// Pin操作不触发滚动
|
||||
scrollTriggerRef.current = 'none'
|
||||
} catch (error) {
|
||||
console.error('Failed to update pinned models', error)
|
||||
}
|
||||
await togglePinnedModel(modelId)
|
||||
setScrollTrigger('none') // pin操作不触发滚动
|
||||
},
|
||||
[pinnedModels]
|
||||
[togglePinnedModel, setScrollTrigger]
|
||||
)
|
||||
|
||||
// 根据输入的文本筛选模型
|
||||
@ -222,6 +184,16 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
return items
|
||||
}, [providers, getFilteredModels, pinnedModels, searchText, t, createModelItem])
|
||||
|
||||
// 获取可选择的模型项(过滤掉分组标题)
|
||||
const modelItems = useMemo(() => {
|
||||
return listItems.filter((item) => item.type === 'model')
|
||||
}, [listItems])
|
||||
|
||||
// 当搜索文本变化时更新滚动触发器
|
||||
useEffect(() => {
|
||||
searchChanged(searchText)
|
||||
}, [searchText, searchChanged])
|
||||
|
||||
// 基于滚动位置更新sticky分组标题
|
||||
const updateStickyGroup = useCallback(
|
||||
(scrollOffset?: number) => {
|
||||
@ -231,7 +203,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
}
|
||||
|
||||
// 基于滚动位置计算当前可见的第一个项的索引
|
||||
const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffsetRef.current) / ITEM_HEIGHT)
|
||||
const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffset) / ITEM_HEIGHT)
|
||||
|
||||
// 从该索引向前查找最近的分组标题
|
||||
for (let i = estimatedIndex - 1; i >= 0; i--) {
|
||||
@ -242,9 +214,9 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
}
|
||||
|
||||
// 找不到则使用第一个分组标题
|
||||
setStickyGroup(firstGroupRef.current ?? null)
|
||||
setStickyGroup(firstGroupRef.current)
|
||||
},
|
||||
[listItems]
|
||||
[listItems, lastScrollOffset, setStickyGroup]
|
||||
)
|
||||
|
||||
// 在listItems变化时更新sticky group
|
||||
@ -255,67 +227,46 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
// 处理列表滚动事件,更新lastScrollOffset并更新sticky分组
|
||||
const handleScroll = useCallback(
|
||||
({ scrollOffset }) => {
|
||||
lastScrollOffsetRef.current = scrollOffset
|
||||
setLastScrollOffset(scrollOffset)
|
||||
updateStickyGroup(scrollOffset)
|
||||
},
|
||||
[updateStickyGroup]
|
||||
[updateStickyGroup, setLastScrollOffset]
|
||||
)
|
||||
|
||||
// 获取可选择的模型项(过滤掉分组标题)
|
||||
const modelItems = useMemo(() => {
|
||||
return listItems.filter((item) => item.type === 'model')
|
||||
}, [listItems])
|
||||
|
||||
// 搜索文本变化时设置滚动来源
|
||||
// 在列表项更新时,更新焦点项
|
||||
useEffect(() => {
|
||||
if (searchText.trim() !== '') {
|
||||
scrollTriggerRef.current = 'search'
|
||||
setFocusedItemKey('')
|
||||
}
|
||||
}, [searchText])
|
||||
|
||||
// 设置初始聚焦项以触发滚动
|
||||
useEffect(() => {
|
||||
if (scrollTriggerRef.current === 'initial' || scrollTriggerRef.current === 'search') {
|
||||
const selectedItem = modelItems.find((item) => item.isSelected)
|
||||
if (selectedItem) {
|
||||
setFocusedItemKey(selectedItem.key)
|
||||
} else if (scrollTriggerRef.current === 'initial' && modelItems.length > 0) {
|
||||
setFocusedItemKey(modelItems[0].key)
|
||||
}
|
||||
// 其余情况不设置focusedItemKey
|
||||
}
|
||||
}, [modelItems])
|
||||
updateOnListChange(modelItems)
|
||||
}, [modelItems, updateOnListChange])
|
||||
|
||||
// 滚动到聚焦项
|
||||
useEffect(() => {
|
||||
if (scrollTriggerRef.current === 'none' || !focusedItemKey) return
|
||||
if (scrollTrigger === 'none' || !focusedItemKey) return
|
||||
|
||||
const index = listItems.findIndex((item) => item.key === focusedItemKey)
|
||||
if (index < 0) return
|
||||
|
||||
// 根据触发源决定滚动对齐方式
|
||||
const alignment = scrollTriggerRef.current === 'keyboard' ? 'auto' : 'center'
|
||||
const alignment = scrollTrigger === 'keyboard' ? 'auto' : 'center'
|
||||
listRef.current?.scrollToItem(index, alignment)
|
||||
|
||||
// 滚动后重置触发器
|
||||
scrollTriggerRef.current = 'none'
|
||||
}, [focusedItemKey, listItems])
|
||||
setScrollTrigger('none')
|
||||
}, [focusedItemKey, scrollTrigger, listItems, setScrollTrigger])
|
||||
|
||||
const handleItemClick = useCallback(
|
||||
(item: FlatListItem) => {
|
||||
if (item.type === 'model') {
|
||||
scrollTriggerRef.current = 'none'
|
||||
setScrollTrigger('initial')
|
||||
resolve(item.model)
|
||||
setOpen(false)
|
||||
}
|
||||
},
|
||||
[resolve]
|
||||
[resolve, setScrollTrigger]
|
||||
)
|
||||
|
||||
// 处理键盘导航
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (e: KeyboardEvent) => {
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
if (!open) return
|
||||
|
||||
if (modelItems.length === 0) {
|
||||
@ -329,43 +280,21 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
setIsMouseOver(false)
|
||||
}
|
||||
|
||||
const getCurrentIndex = (currentKey: string) => {
|
||||
const currentIndex = modelItems.findIndex((item) => item.key === currentKey)
|
||||
return currentIndex < 0 ? 0 : currentIndex
|
||||
}
|
||||
const currentIndex = modelItems.findIndex((item) => item.key === focusedItemKey)
|
||||
const normalizedIndex = currentIndex < 0 ? 0 : currentIndex
|
||||
|
||||
switch (e.key) {
|
||||
case 'ArrowUp':
|
||||
scrollTriggerRef.current = 'keyboard'
|
||||
setFocusedItemKey((prev) => {
|
||||
const currentIndex = getCurrentIndex(prev)
|
||||
const nextIndex = (currentIndex - 1 + modelItems.length) % modelItems.length
|
||||
return modelItems[nextIndex].key
|
||||
})
|
||||
focusNextItem(modelItems, -1)
|
||||
break
|
||||
case 'ArrowDown':
|
||||
scrollTriggerRef.current = 'keyboard'
|
||||
setFocusedItemKey((prev) => {
|
||||
const currentIndex = getCurrentIndex(prev)
|
||||
const nextIndex = (currentIndex + 1) % modelItems.length
|
||||
return modelItems[nextIndex].key
|
||||
})
|
||||
focusNextItem(modelItems, 1)
|
||||
break
|
||||
case 'PageUp':
|
||||
scrollTriggerRef.current = 'keyboard'
|
||||
setFocusedItemKey((prev) => {
|
||||
const currentIndex = getCurrentIndex(prev)
|
||||
const nextIndex = Math.max(currentIndex - PAGE_SIZE, 0)
|
||||
return modelItems[nextIndex].key
|
||||
})
|
||||
focusPage(modelItems, normalizedIndex, -PAGE_SIZE)
|
||||
break
|
||||
case 'PageDown':
|
||||
scrollTriggerRef.current = 'keyboard'
|
||||
setFocusedItemKey((prev) => {
|
||||
const currentIndex = getCurrentIndex(prev)
|
||||
const nextIndex = Math.min(currentIndex + PAGE_SIZE, modelItems.length - 1)
|
||||
return modelItems[nextIndex].key
|
||||
})
|
||||
focusPage(modelItems, normalizedIndex, PAGE_SIZE)
|
||||
break
|
||||
case 'Enter':
|
||||
if (focusedItemKey) {
|
||||
@ -377,34 +306,47 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
break
|
||||
case 'Escape':
|
||||
e.preventDefault()
|
||||
scrollTriggerRef.current = 'none'
|
||||
setScrollTrigger('none')
|
||||
setOpen(false)
|
||||
resolve(undefined)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown)
|
||||
return () => window.removeEventListener('keydown', handleKeyDown)
|
||||
}, [focusedItemKey, modelItems, handleItemClick, open, resolve])
|
||||
|
||||
const onCancel = useCallback(() => {
|
||||
scrollTriggerRef.current = 'none'
|
||||
setOpen(false)
|
||||
}, [])
|
||||
|
||||
const onClose = useCallback(async () => {
|
||||
scrollTriggerRef.current = 'none'
|
||||
resolve(undefined)
|
||||
SelectModelPopup.hide()
|
||||
}, [resolve])
|
||||
},
|
||||
[
|
||||
focusedItemKey,
|
||||
modelItems,
|
||||
handleItemClick,
|
||||
open,
|
||||
resolve,
|
||||
setIsMouseOver,
|
||||
focusNextItem,
|
||||
focusPage,
|
||||
setScrollTrigger
|
||||
]
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) return
|
||||
window.addEventListener('keydown', handleKeyDown)
|
||||
return () => window.removeEventListener('keydown', handleKeyDown)
|
||||
}, [handleKeyDown])
|
||||
|
||||
const onCancel = useCallback(() => {
|
||||
setScrollTrigger('initial')
|
||||
setOpen(false)
|
||||
}, [setScrollTrigger])
|
||||
|
||||
const onClose = useCallback(async () => {
|
||||
setScrollTrigger('initial')
|
||||
resolve(undefined)
|
||||
SelectModelPopup.hide()
|
||||
}, [resolve, setScrollTrigger])
|
||||
|
||||
// 初始化焦点和滚动位置
|
||||
useEffect(() => {
|
||||
if (!open || loadingPinnedModels) return
|
||||
setTimeout(() => inputRef.current?.focus(), 0)
|
||||
scrollTriggerRef.current = 'initial'
|
||||
lastScrollOffsetRef.current = 0
|
||||
}, [open])
|
||||
initScroll()
|
||||
}, [open, initScroll, loadingPinnedModels])
|
||||
|
||||
const RowData = useMemo(
|
||||
(): VirtualizedRowData => ({
|
||||
@ -415,7 +357,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
handleItemClick,
|
||||
togglePin
|
||||
}),
|
||||
[stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin]
|
||||
[stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin, setFocusedItemKey]
|
||||
)
|
||||
|
||||
const listHeight = useMemo(() => {
|
||||
@ -470,7 +412,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
|
||||
<Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} />
|
||||
|
||||
{listItems.length > 0 ? (
|
||||
<ListContainer onMouseMove={() => setIsMouseOver(true)}>
|
||||
<ListContainer onMouseMove={() => !isMouseOver && setIsMouseOver(true)}>
|
||||
{/* Sticky Group Banner,它会替换第一个分组名称 */}
|
||||
<StickyGroupBanner>{stickyGroup?.name}</StickyGroupBanner>
|
||||
<FixedSizeList
|
||||
@ -685,14 +627,26 @@ const PinIconWrapper = styled.div.attrs({ className: 'pin-icon' })<{ $isPinned?:
|
||||
}
|
||||
`
|
||||
|
||||
export default class SelectModelPopup {
|
||||
const TopViewKey = 'SelectModelPopup'
|
||||
|
||||
export class SelectModelPopup {
|
||||
static topviewId = 0
|
||||
static hide() {
|
||||
TopView.hide('SelectModelPopup')
|
||||
TopView.hide(TopViewKey)
|
||||
}
|
||||
|
||||
static show(params: Props) {
|
||||
static show(params: PopupParams) {
|
||||
return new Promise<Model | undefined>((resolve) => {
|
||||
TopView.show(<PopupContainer {...params} resolve={resolve} />, 'SelectModelPopup')
|
||||
TopView.show(
|
||||
<PopupContainer
|
||||
{...params}
|
||||
resolve={(v) => {
|
||||
resolve(v)
|
||||
TopView.hide(TopViewKey)
|
||||
}}
|
||||
/>,
|
||||
TopViewKey
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
109
src/renderer/src/components/Popups/SelectModelPopup/reducer.ts
Normal file
109
src/renderer/src/components/Popups/SelectModelPopup/reducer.ts
Normal file
@ -0,0 +1,109 @@
|
||||
import { ScrollAction, ScrollState } from './types'
|
||||
|
||||
/**
|
||||
* 初始状态
|
||||
*/
|
||||
export const initialScrollState: ScrollState = {
|
||||
focusedItemKey: '',
|
||||
scrollTrigger: 'initial',
|
||||
lastScrollOffset: 0,
|
||||
stickyGroup: null,
|
||||
isMouseOver: false
|
||||
}
|
||||
|
||||
/**
|
||||
* 滚动状态的 reducer,用于避免复杂依赖可能带来的状态更新问题
|
||||
* @param state 当前状态
|
||||
* @param action 动作
|
||||
* @returns 新的状态
|
||||
*/
|
||||
export const scrollReducer = (state: ScrollState, action: ScrollAction): ScrollState => {
|
||||
switch (action.type) {
|
||||
case 'SET_FOCUSED_ITEM_KEY':
|
||||
return { ...state, focusedItemKey: action.payload }
|
||||
|
||||
case 'SET_SCROLL_TRIGGER':
|
||||
return { ...state, scrollTrigger: action.payload }
|
||||
|
||||
case 'SET_LAST_SCROLL_OFFSET':
|
||||
return { ...state, lastScrollOffset: action.payload }
|
||||
|
||||
case 'SET_STICKY_GROUP':
|
||||
return { ...state, stickyGroup: action.payload }
|
||||
|
||||
case 'SET_IS_MOUSE_OVER':
|
||||
return { ...state, isMouseOver: action.payload }
|
||||
|
||||
case 'FOCUS_NEXT_ITEM': {
|
||||
const { modelItems, step } = action.payload
|
||||
|
||||
if (modelItems.length === 0) {
|
||||
return {
|
||||
...state,
|
||||
focusedItemKey: '',
|
||||
scrollTrigger: 'keyboard'
|
||||
}
|
||||
}
|
||||
|
||||
const currentIndex = modelItems.findIndex((item) => item.key === state.focusedItemKey)
|
||||
const nextIndex = (currentIndex < 0 ? 0 : currentIndex + step + modelItems.length) % modelItems.length
|
||||
|
||||
return {
|
||||
...state,
|
||||
focusedItemKey: modelItems[nextIndex].key,
|
||||
scrollTrigger: 'keyboard'
|
||||
}
|
||||
}
|
||||
|
||||
case 'FOCUS_PAGE': {
|
||||
const { modelItems, currentIndex, step } = action.payload
|
||||
const nextIndex = Math.max(0, Math.min(currentIndex + step, modelItems.length - 1))
|
||||
|
||||
return {
|
||||
...state,
|
||||
focusedItemKey: modelItems.length > 0 ? modelItems[nextIndex].key : '',
|
||||
scrollTrigger: 'keyboard'
|
||||
}
|
||||
}
|
||||
|
||||
case 'SEARCH_CHANGED':
|
||||
return {
|
||||
...state,
|
||||
scrollTrigger: action.payload.searchText ? 'search' : 'initial'
|
||||
}
|
||||
|
||||
case 'UPDATE_ON_LIST_CHANGE': {
|
||||
const { modelItems } = action.payload
|
||||
|
||||
// 在列表变化时尝试聚焦一个模型:
|
||||
// - 如果是 initial 状态,先尝试聚焦当前选中的模型
|
||||
// - 如果是 search 状态,尝试聚焦第一个模型
|
||||
let newFocusedKey = ''
|
||||
if (state.scrollTrigger === 'initial' || state.scrollTrigger === 'search') {
|
||||
const selectedItem = modelItems.find((item) => item.isSelected)
|
||||
if (selectedItem && state.scrollTrigger === 'initial') {
|
||||
newFocusedKey = selectedItem.key
|
||||
} else if (modelItems.length > 0) {
|
||||
newFocusedKey = modelItems[0].key
|
||||
}
|
||||
} else {
|
||||
newFocusedKey = state.focusedItemKey
|
||||
}
|
||||
|
||||
return {
|
||||
...state,
|
||||
focusedItemKey: newFocusedKey
|
||||
}
|
||||
}
|
||||
|
||||
case 'INIT_SCROLL':
|
||||
return {
|
||||
...state,
|
||||
scrollTrigger: 'initial',
|
||||
lastScrollOffset: 0
|
||||
}
|
||||
|
||||
default:
|
||||
return state
|
||||
}
|
||||
}
|
||||
42
src/renderer/src/components/Popups/SelectModelPopup/types.ts
Normal file
42
src/renderer/src/components/Popups/SelectModelPopup/types.ts
Normal file
@ -0,0 +1,42 @@
|
||||
import { Model } from '@renderer/types'
|
||||
import { ReactNode } from 'react'
|
||||
|
||||
// 列表项类型,组名也作为列表项
|
||||
export type ListItemType = 'group' | 'model'
|
||||
|
||||
// 滚动触发来源类型
|
||||
export type ScrollTrigger = 'initial' | 'search' | 'keyboard' | 'none'
|
||||
|
||||
// 扁平化列表项接口
|
||||
export interface FlatListItem {
|
||||
key: string
|
||||
type: ListItemType
|
||||
icon?: ReactNode
|
||||
name: ReactNode
|
||||
tags?: ReactNode
|
||||
model?: Model
|
||||
isPinned?: boolean
|
||||
isSelected?: boolean
|
||||
}
|
||||
|
||||
// 滚动和焦点相关的状态类型
|
||||
export interface ScrollState {
|
||||
focusedItemKey: string
|
||||
scrollTrigger: ScrollTrigger
|
||||
lastScrollOffset: number
|
||||
stickyGroup: FlatListItem | null
|
||||
isMouseOver: boolean
|
||||
}
|
||||
|
||||
// 滚动和焦点相关的 action 类型
|
||||
export type ScrollAction =
|
||||
| { type: 'SET_FOCUSED_ITEM_KEY'; payload: string }
|
||||
| { type: 'SET_SCROLL_TRIGGER'; payload: ScrollTrigger }
|
||||
| { type: 'SET_LAST_SCROLL_OFFSET'; payload: number }
|
||||
| { type: 'SET_STICKY_GROUP'; payload: FlatListItem | null }
|
||||
| { type: 'SET_IS_MOUSE_OVER'; payload: boolean }
|
||||
| { type: 'FOCUS_NEXT_ITEM'; payload: { modelItems: FlatListItem[]; step: number } }
|
||||
| { type: 'FOCUS_PAGE'; payload: { modelItems: FlatListItem[]; currentIndex: number; step: number } }
|
||||
| { type: 'SEARCH_CHANGED'; payload: { searchText: string } }
|
||||
| { type: 'UPDATE_ON_LIST_CHANGE'; payload: { modelItems: FlatListItem[] } }
|
||||
| { type: 'INIT_SCROLL'; payload?: void }
|
||||
@ -237,23 +237,24 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
|
||||
)
|
||||
|
||||
export function isFunctionCallingModel(model: Model): boolean {
|
||||
if (model.type?.includes('function_calling')) {
|
||||
return true
|
||||
}
|
||||
if (!model) return false
|
||||
if (model.type) {
|
||||
return model.type.includes('function_calling')
|
||||
} else {
|
||||
if (isEmbeddingModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (isEmbeddingModel(model)) {
|
||||
return false
|
||||
}
|
||||
if (model.provider === 'qiniu') {
|
||||
return ['deepseek-v3-tool', 'deepseek-v3-0324', 'qwq-32b', 'qwen2.5-72b-instruct'].includes(model.id)
|
||||
}
|
||||
|
||||
if (model.provider === 'qiniu') {
|
||||
return ['deepseek-v3-tool', 'deepseek-v3-0324', 'qwq-32b', 'qwen2.5-72b-instruct'].includes(model.id)
|
||||
}
|
||||
if (['deepseek', 'anthropic'].includes(model.provider)) {
|
||||
return true
|
||||
}
|
||||
|
||||
if (['deepseek', 'anthropic'].includes(model.provider)) {
|
||||
return true
|
||||
return FUNCTION_CALLING_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
return FUNCTION_CALLING_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
export function getModelLogo(modelId: string) {
|
||||
@ -2188,20 +2189,23 @@ export function isEmbeddingModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
if (model.type) {
|
||||
return model.type.includes('embedding')
|
||||
} else {
|
||||
if (['anthropic'].includes(model?.provider)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (['anthropic'].includes(model?.provider)) {
|
||||
return false
|
||||
if (model.provider === 'doubao') {
|
||||
return EMBEDDING_REGEX.test(model.name)
|
||||
}
|
||||
|
||||
if (isRerankModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return EMBEDDING_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
if (model.provider === 'doubao') {
|
||||
return EMBEDDING_REGEX.test(model.name)
|
||||
}
|
||||
|
||||
if (isRerankModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
|
||||
}
|
||||
|
||||
export function isRerankModel(model: Model): boolean {
|
||||
@ -2212,16 +2216,20 @@ export function isVisionModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
// 新添字段 copilot-vision-request 后可使用 vision
|
||||
// if (model.provider === 'copilot') {
|
||||
// return false
|
||||
// }
|
||||
if (model.type) {
|
||||
return model.type.includes('vision')
|
||||
} else {
|
||||
// 新添字段 copilot-vision-request 后可使用 vision
|
||||
// if (model.provider === 'copilot') {
|
||||
// return false
|
||||
// }
|
||||
|
||||
if (model.provider === 'doubao') {
|
||||
return VISION_REGEX.test(model.name) || model.type?.includes('vision') || false
|
||||
if (model.provider === 'doubao') {
|
||||
return VISION_REGEX.test(model.name)
|
||||
}
|
||||
|
||||
return VISION_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
return VISION_REGEX.test(model.id) || model.type?.includes('vision') || false
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(model: Model): boolean {
|
||||
@ -2355,23 +2363,26 @@ export function isReasoningModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
if (model.type) {
|
||||
return model.type.includes('reasoning')
|
||||
} else {
|
||||
if (model.provider === 'doubao') {
|
||||
return REASONING_REGEX.test(model.name)
|
||||
}
|
||||
|
||||
if (model.provider === 'doubao') {
|
||||
return REASONING_REGEX.test(model.name) || model.type?.includes('reasoning') || false
|
||||
if (
|
||||
isClaudeReasoningModel(model) ||
|
||||
isOpenAIReasoningModel(model) ||
|
||||
isGeminiReasoningModel(model) ||
|
||||
isQwenReasoningModel(model) ||
|
||||
isGrokReasoningModel(model) ||
|
||||
model.id.includes('glm-z1')
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return REASONING_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
if (
|
||||
isClaudeReasoningModel(model) ||
|
||||
isOpenAIReasoningModel(model) ||
|
||||
isGeminiReasoningModel(model) ||
|
||||
isQwenReasoningModel(model) ||
|
||||
isGrokReasoningModel(model) ||
|
||||
model.id.includes('glm-z1')
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return REASONING_REGEX.test(model.id) || model.type?.includes('reasoning') || false
|
||||
}
|
||||
|
||||
export function isSupportedModel(model: OpenAI.Models.Model): boolean {
|
||||
@ -2386,89 +2397,86 @@ export function isWebSearchModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (model.type) {
|
||||
if (model.type.includes('web_search')) {
|
||||
return true
|
||||
return model.type.includes('web_search')
|
||||
} else {
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
if (!provider) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
const isEmbedding = isEmbeddingModel(model)
|
||||
|
||||
if (!provider) {
|
||||
return false
|
||||
}
|
||||
if (isEmbedding) {
|
||||
return false
|
||||
}
|
||||
|
||||
const isEmbedding = isEmbeddingModel(model)
|
||||
if (model.id.includes('claude')) {
|
||||
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
if (isEmbedding) {
|
||||
return false
|
||||
}
|
||||
if (provider.type === 'openai') {
|
||||
if (
|
||||
isOpenAILLMModel(model) &&
|
||||
!isTextToImageModel(model) &&
|
||||
!isOpenAIReasoningModel(model) &&
|
||||
!GENERATE_IMAGE_MODELS.includes(model.id)
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
if (model.id.includes('claude')) {
|
||||
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(model.id)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if (provider.type === 'openai') {
|
||||
if (
|
||||
isOpenAILLMModel(model) &&
|
||||
!isTextToImageModel(model) &&
|
||||
!isOpenAIReasoningModel(model) &&
|
||||
!GENERATE_IMAGE_MODELS.includes(model.id)
|
||||
) {
|
||||
if (provider.id === 'perplexity') {
|
||||
return PERPLEXITY_SEARCH_MODELS.includes(model?.id)
|
||||
}
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
if (
|
||||
isOpenAILLMModel(model) &&
|
||||
!isTextToImageModel(model) &&
|
||||
!isOpenAIReasoningModel(model) &&
|
||||
!GENERATE_IMAGE_MODELS.includes(model.id)
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
const models = ['gemini-2.0-flash-search', 'gemini-2.0-flash-exp-search', 'gemini-2.0-pro-exp-02-05-search']
|
||||
return models.includes(model?.id)
|
||||
}
|
||||
|
||||
if (provider?.type === 'openai-compatible') {
|
||||
if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearch(model)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if (provider.id === 'gemini' || provider?.type === 'gemini') {
|
||||
return GEMINI_SEARCH_MODELS.includes(model?.id)
|
||||
}
|
||||
|
||||
if (provider.id === 'hunyuan') {
|
||||
return model?.id !== 'hunyuan-lite'
|
||||
}
|
||||
|
||||
if (provider.id === 'zhipu') {
|
||||
return model?.id?.startsWith('glm-4-')
|
||||
}
|
||||
|
||||
if (provider.id === 'dashscope') {
|
||||
const models = ['qwen-turbo', 'qwen-max', 'qwen-plus', 'qwq']
|
||||
// matches id like qwen-max-0919, qwen-max-latest
|
||||
return models.some((i) => model.id.startsWith(i))
|
||||
}
|
||||
|
||||
if (provider.id === 'openrouter') {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
if (provider.id === 'perplexity') {
|
||||
return PERPLEXITY_SEARCH_MODELS.includes(model?.id)
|
||||
}
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
if (
|
||||
isOpenAILLMModel(model) &&
|
||||
!isTextToImageModel(model) &&
|
||||
!isOpenAIReasoningModel(model) &&
|
||||
!GENERATE_IMAGE_MODELS.includes(model.id)
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
const models = ['gemini-2.0-flash-search', 'gemini-2.0-flash-exp-search', 'gemini-2.0-pro-exp-02-05-search']
|
||||
return models.includes(model?.id)
|
||||
}
|
||||
|
||||
if (provider?.type === 'openai-compatible') {
|
||||
if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearch(model)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if (provider.id === 'gemini' || provider?.type === 'gemini') {
|
||||
return GEMINI_SEARCH_MODELS.includes(model?.id)
|
||||
}
|
||||
|
||||
if (provider.id === 'hunyuan') {
|
||||
return model?.id !== 'hunyuan-lite'
|
||||
}
|
||||
|
||||
if (provider.id === 'zhipu') {
|
||||
return model?.id?.startsWith('glm-4-')
|
||||
}
|
||||
|
||||
if (provider.id === 'dashscope') {
|
||||
const models = ['qwen-turbo', 'qwen-max', 'qwen-plus', 'qwq']
|
||||
// matches id like qwen-max-0919, qwen-max-latest
|
||||
return models.some((i) => model.id.startsWith(i))
|
||||
}
|
||||
|
||||
if (provider.id === 'openrouter') {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
export function isGenerateImageModel(model: Model): boolean {
|
||||
|
||||
@ -95,7 +95,8 @@ export function getProviderLogo(providerId: string) {
|
||||
return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP]
|
||||
}
|
||||
|
||||
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
|
||||
// export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
|
||||
export const NOT_SUPPORTED_REANK_PROVIDERS = ['ollama']
|
||||
|
||||
export const PROVIDER_CONFIG = {
|
||||
openai: {
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import BochaLogo from '@renderer/assets/images/search/bocha.webp'
|
||||
import ExaLogo from '@renderer/assets/images/search/exa.png'
|
||||
import SearxngLogo from '@renderer/assets/images/search/searxng.svg'
|
||||
import TavilyLogo from '@renderer/assets/images/search/tavily.png'
|
||||
|
||||
export function getWebSearchProviderLogo(providerId: string) {
|
||||
switch (providerId) {
|
||||
case 'tavily':
|
||||
@ -9,6 +11,8 @@ export function getWebSearchProviderLogo(providerId: string) {
|
||||
return SearxngLogo
|
||||
case 'exa':
|
||||
return ExaLogo
|
||||
case 'bocha':
|
||||
return BochaLogo
|
||||
default:
|
||||
return undefined
|
||||
}
|
||||
@ -32,6 +36,12 @@ export const WEB_SEARCH_PROVIDER_CONFIG = {
|
||||
apiKey: 'https://dashboard.exa.ai/api-keys'
|
||||
}
|
||||
},
|
||||
bocha: {
|
||||
websites: {
|
||||
official: 'https://bochaai.com',
|
||||
apiKey: 'https://open.bochaai.com/overview'
|
||||
}
|
||||
},
|
||||
'local-google': {
|
||||
websites: {
|
||||
official: 'https://www.google.com'
|
||||
|
||||
63
src/renderer/src/hooks/usePinnedModels.ts
Normal file
63
src/renderer/src/hooks/usePinnedModels.ts
Normal file
@ -0,0 +1,63 @@
|
||||
import db from '@renderer/databases'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { sortBy } from 'lodash'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
|
||||
import { useProviders } from './useProvider'
|
||||
|
||||
export const usePinnedModels = () => {
|
||||
const [pinnedModels, setPinnedModels] = useState<string[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const { providers } = useProviders()
|
||||
|
||||
useEffect(() => {
|
||||
const loadPinnedModels = async () => {
|
||||
setLoading(true)
|
||||
const setting = await db.settings.get('pinned:models')
|
||||
const savedPinnedModels = setting?.value || []
|
||||
|
||||
// Filter out invalid pinned models
|
||||
const allModelIds = providers.flatMap((p) => p.models || []).map((m) => getModelUniqId(m))
|
||||
const validPinnedModels = savedPinnedModels.filter((id) => allModelIds.includes(id))
|
||||
|
||||
// Update storage if there were invalid models
|
||||
if (validPinnedModels.length !== savedPinnedModels.length) {
|
||||
await db.settings.put({ id: 'pinned:models', value: validPinnedModels })
|
||||
}
|
||||
|
||||
setPinnedModels(sortBy(validPinnedModels))
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
loadPinnedModels().catch((error) => {
|
||||
console.error('Failed to load pinned models', error)
|
||||
setPinnedModels([])
|
||||
setLoading(false)
|
||||
})
|
||||
}, [providers])
|
||||
|
||||
const updatePinnedModels = useCallback(async (models: string[]) => {
|
||||
await db.settings.put({ id: 'pinned:models', value: models })
|
||||
setPinnedModels(sortBy(models))
|
||||
}, [])
|
||||
|
||||
/**
|
||||
* Toggle a single pinned model
|
||||
* @param modelId - The ID string of the model to toggle
|
||||
*/
|
||||
const togglePinnedModel = useCallback(
|
||||
async (modelId: string) => {
|
||||
try {
|
||||
const newPinnedModels = pinnedModels.includes(modelId)
|
||||
? pinnedModels.filter((id) => id !== modelId)
|
||||
: [...pinnedModels, modelId]
|
||||
await updatePinnedModels(newPinnedModels)
|
||||
} catch (error) {
|
||||
console.error('Failed to toggle pinned model', error)
|
||||
}
|
||||
},
|
||||
[pinnedModels, updatePinnedModels]
|
||||
)
|
||||
|
||||
return { pinnedModels, updatePinnedModels, togglePinnedModel, loading }
|
||||
}
|
||||
@ -705,6 +705,7 @@
|
||||
"pinned": "Pinned",
|
||||
"rerank_model": "Reordering Model",
|
||||
"rerank_model_support_provider": "Currently, the reordering model only supports some providers ({{provider}})",
|
||||
"rerank_model_not_support_provider": "Currently, the reordering model does not support this provider ({{provider}})",
|
||||
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.",
|
||||
"search": "Search models...",
|
||||
"stream_output": "Stream output",
|
||||
@ -1343,6 +1344,7 @@
|
||||
"providerPlaceholder": "Provider name",
|
||||
"advancedSettings": "Advanced Settings"
|
||||
},
|
||||
"messages.prompt": "Show prompt",
|
||||
"messages.divider": "Show divider between messages",
|
||||
"messages.grid_columns": "Message grid display columns",
|
||||
"messages.grid_popover_trigger": "Grid detail trigger",
|
||||
|
||||
@ -719,7 +719,8 @@
|
||||
"text": "テキスト",
|
||||
"vision": "画像",
|
||||
"websearch": "ウェブ検索"
|
||||
}
|
||||
},
|
||||
"rerank_model_not_support_provider": "現在、並べ替えモデルはこのプロバイダー ({{provider}}) をサポートしていません。"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "ダイアログを展開",
|
||||
@ -1341,6 +1342,7 @@
|
||||
"providerPlaceholder": "プロバイダー名",
|
||||
"advancedSettings": "詳細設定"
|
||||
},
|
||||
"messages.prompt": "プロンプト表示",
|
||||
"messages.divider": "メッセージ間に区切り線を表示",
|
||||
"messages.grid_columns": "メッセージグリッドの表示列数",
|
||||
"messages.grid_popover_trigger": "グリッド詳細トリガー",
|
||||
|
||||
@ -719,7 +719,8 @@
|
||||
"text": "Текст",
|
||||
"vision": "Визуальные",
|
||||
"websearch": "Веб-поисковые"
|
||||
}
|
||||
},
|
||||
"rerank_model_not_support_provider": "В настоящее время модель переупорядочивания не поддерживает этого провайдера ({{provider}})"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "Развернуть диалоговое окно",
|
||||
@ -1341,6 +1342,7 @@
|
||||
"providerPlaceholder": "Имя провайдера",
|
||||
"advancedSettings": "Расширенные настройки"
|
||||
},
|
||||
"messages.prompt": "Показывать подсказки",
|
||||
"messages.divider": "Показывать разделитель между сообщениями",
|
||||
"messages.grid_columns": "Количество столбцов сетки сообщений",
|
||||
"messages.grid_popover_trigger": "Триггер для отображения подробной информации в сетке",
|
||||
|
||||
@ -705,6 +705,7 @@
|
||||
"pinned": "已固定",
|
||||
"rerank_model": "重排模型",
|
||||
"rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})",
|
||||
"rerank_model_not_support_provider": "目前重排序模型不支持该服务商 ({{provider}})",
|
||||
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
|
||||
"search": "搜索模型...",
|
||||
"stream_output": "流式输出",
|
||||
@ -1343,6 +1344,7 @@
|
||||
"providerPlaceholder": "提供者名称",
|
||||
"advancedSettings": "高级设置"
|
||||
},
|
||||
"messages.prompt": "提示词显示",
|
||||
"messages.divider": "消息分割线",
|
||||
"messages.grid_columns": "消息网格展示列数",
|
||||
"messages.grid_popover_trigger": "网格详情触发",
|
||||
|
||||
@ -719,7 +719,8 @@
|
||||
"text": "文字",
|
||||
"vision": "視覺",
|
||||
"websearch": "網路搜尋"
|
||||
}
|
||||
},
|
||||
"rerank_model_not_support_provider": "目前,重新排序模型不支援此提供者({{provider}})"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "伸縮對話框",
|
||||
@ -1342,6 +1343,7 @@
|
||||
"providerPlaceholder": "提供者名稱",
|
||||
"advancedSettings": "高級設定"
|
||||
},
|
||||
"messages.prompt": "提示詞顯示",
|
||||
"messages.divider": "訊息間顯示分隔線",
|
||||
"messages.grid_columns": "訊息網格展示列數",
|
||||
"messages.grid_popover_trigger": "網格詳細資訊觸發",
|
||||
|
||||
@ -42,7 +42,7 @@ interface MessagesProps {
|
||||
|
||||
const Messages: React.FC<MessagesProps> = ({ assistant, topic, setActiveTopic }) => {
|
||||
const { t } = useTranslation()
|
||||
const { showTopics, topicPosition, showAssistants, messageNavigation } = useSettings()
|
||||
const { showPrompt, showTopics, topicPosition, showAssistants, messageNavigation } = useSettings()
|
||||
const { updateTopic, addTopic } = useAssistant(assistant.id)
|
||||
const dispatch = useAppDispatch()
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
@ -254,7 +254,7 @@ const Messages: React.FC<MessagesProps> = ({ assistant, topic, setActiveTopic })
|
||||
)}
|
||||
</ScrollContainer>
|
||||
</InfiniteScroll>
|
||||
<Prompt assistant={assistant} key={assistant.prompt} topic={topic} />
|
||||
{showPrompt && <Prompt assistant={assistant} key={assistant.prompt} topic={topic} />}
|
||||
</NarrowLayout>
|
||||
{messageNavigation === 'anchor' && <MessageAnchorLine messages={displayMessages} />}
|
||||
{messageNavigation === 'buttons' && <ChatNavigation containerId="messages" />}
|
||||
|
||||
@ -37,6 +37,7 @@ import {
|
||||
setPasteLongTextThreshold,
|
||||
setRenderInputMessageAsMarkdown,
|
||||
setShowInputEstimatedTokens,
|
||||
setShowPrompt,
|
||||
setShowMessageDivider,
|
||||
setShowTranslateConfirm,
|
||||
setThoughtAutoCollapse
|
||||
@ -76,6 +77,7 @@ const SettingsTab: FC<Props> = (props) => {
|
||||
const dispatch = useAppDispatch()
|
||||
|
||||
const {
|
||||
showPrompt,
|
||||
showMessageDivider,
|
||||
messageFont,
|
||||
showInputEstimatedTokens,
|
||||
@ -282,6 +284,11 @@ const SettingsTab: FC<Props> = (props) => {
|
||||
<SettingGroup>
|
||||
<SettingSubtitle style={{ marginTop: 0 }}>{t('settings.messages.title')}</SettingSubtitle>
|
||||
<SettingDivider />
|
||||
<SettingRow>
|
||||
<SettingRowTitleSmall>{t('settings.messages.prompt')}</SettingRowTitleSmall>
|
||||
<Switch size="small" checked={showPrompt} onChange={(checked) => dispatch(setShowPrompt(checked))} />
|
||||
</SettingRow>
|
||||
<SettingDivider />
|
||||
<SettingRow>
|
||||
<SettingRowTitleSmall>{t('settings.messages.divider')}</SettingRowTitleSmall>
|
||||
<Switch
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { SettingHelpText } from '@renderer/pages/settings'
|
||||
@ -67,7 +68,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
|
||||
const rerankSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
@ -176,8 +177,8 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
|
||||
</Form.Item>
|
||||
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
|
||||
{t('models.rerank_model_support_provider', {
|
||||
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
{t('models.rerank_model_not_support_provider', {
|
||||
provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
})}
|
||||
</SettingHelpText>
|
||||
<Form.Item
|
||||
|
||||
@ -3,7 +3,8 @@ import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { useKnowledge } from '@renderer/hooks/useKnowledge'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { SettingHelpText } from '@renderer/pages/settings'
|
||||
@ -68,7 +69,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
|
||||
const rerankSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
@ -157,8 +158,8 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
/>
|
||||
</Form.Item>
|
||||
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
|
||||
{t('models.rerank_model_support_provider', {
|
||||
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
{t('models.rerank_model_not_support_provider', {
|
||||
provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
})}
|
||||
</SettingHelpText>
|
||||
|
||||
|
||||
@ -132,7 +132,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
|
||||
] as ModelType[]
|
||||
|
||||
// 合并现有选择和默认类型
|
||||
const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])]
|
||||
const selectedTypes = model.type ? model.type : defaultTypes
|
||||
|
||||
const showTypeConfirmModal = (type: string) => {
|
||||
window.modal.confirm({
|
||||
@ -165,28 +165,23 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
|
||||
options={[
|
||||
{
|
||||
label: t('models.type.vision'),
|
||||
value: 'vision',
|
||||
disabled: isVisionModel(model) && !selectedTypes.includes('vision')
|
||||
value: 'vision'
|
||||
},
|
||||
{
|
||||
label: t('models.type.websearch'),
|
||||
value: 'web_search',
|
||||
disabled: isWebSearchModel(model) && !selectedTypes.includes('web_search')
|
||||
value: 'web_search'
|
||||
},
|
||||
{
|
||||
label: t('models.type.embedding'),
|
||||
value: 'embedding',
|
||||
disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding')
|
||||
value: 'embedding'
|
||||
},
|
||||
{
|
||||
label: t('models.type.reasoning'),
|
||||
value: 'reasoning',
|
||||
disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning')
|
||||
value: 'reasoning'
|
||||
},
|
||||
{
|
||||
label: t('models.type.function_calling'),
|
||||
value: 'function_calling',
|
||||
disabled: isFunctionCallingModel(model) && !selectedTypes.includes('function_calling')
|
||||
value: 'function_calling'
|
||||
}
|
||||
]}
|
||||
/>
|
||||
|
||||
@ -192,14 +192,11 @@ const WebSearchProviderSetting: FC<Props> = ({ provider: _provider }) => {
|
||||
onChange={(e) => setApiHost(e.target.value)}
|
||||
onBlur={onUpdateApiHost}
|
||||
/>
|
||||
<Button
|
||||
ghost={apiValid}
|
||||
type={apiValid ? 'primary' : 'default'}
|
||||
onClick={checkSearch}
|
||||
disabled={apiChecking}>
|
||||
{apiChecking ? <LoadingOutlined spin /> : apiValid ? <CheckOutlined /> : t('settings.websearch.check')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{hasObjectKey(provider, 'basicAuthUsername') && (
|
||||
<>
|
||||
<SettingDivider style={{ marginTop: 12, marginBottom: 12 }} />
|
||||
<SettingSubtitle style={{ marginTop: 5, marginBottom: 10 }}>
|
||||
{t('settings.provider.basic_auth')}
|
||||
|
||||
@ -402,7 +402,7 @@ export default class OpenAICompatibleProvider extends BaseOpenAiProvider {
|
||||
await this.checkIsCopilot()
|
||||
|
||||
const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
|
||||
if (lastUserMsg) {
|
||||
if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
|
||||
const postsuffix = '/no_think'
|
||||
// qwenThinkMode === true 表示思考模式啓用,此時不應添加 /no_think,如果存在則移除
|
||||
const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
|
||||
|
||||
@ -4,18 +4,32 @@ import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
|
||||
export default abstract class BaseWebSearchProvider {
|
||||
// @ts-ignore this
|
||||
protected provider: WebSearchProvider
|
||||
protected apiHost?: string
|
||||
protected apiKey: string
|
||||
|
||||
constructor(provider: WebSearchProvider) {
|
||||
this.provider = provider
|
||||
this.apiHost = this.getApiHost()
|
||||
this.apiKey = this.getApiKey()
|
||||
}
|
||||
|
||||
abstract search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse>
|
||||
|
||||
public getApiHost() {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
public defaultHeaders() {
|
||||
return {
|
||||
'HTTP-Referer': 'https://cherry-ai.com',
|
||||
'X-Title': 'Cherry Studio'
|
||||
}
|
||||
}
|
||||
|
||||
public getApiKey() {
|
||||
const keys = this.provider.apiKey?.split(',').map((key) => key.trim()) || []
|
||||
const keyName = `web-search-provider:${this.provider.id}:last_used_key`
|
||||
|
||||
@ -0,0 +1,70 @@
|
||||
import { WebSearchState } from '@renderer/store/websearch'
|
||||
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
|
||||
import { BochaSearchParams, BochaSearchResponse } from '@renderer/types/bocha'
|
||||
|
||||
import BaseWebSearchProvider from './BaseWebSearchProvider'
|
||||
|
||||
export default class BochaProvider extends BaseWebSearchProvider {
|
||||
constructor(provider: WebSearchProvider) {
|
||||
super(provider)
|
||||
if (!this.apiKey) {
|
||||
throw new Error('API key is required for Bocha provider')
|
||||
}
|
||||
if (!this.apiHost) {
|
||||
throw new Error('API host is required for Bocha provider')
|
||||
}
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
try {
|
||||
if (!query.trim()) {
|
||||
throw new Error('Search query cannot be empty')
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`
|
||||
}
|
||||
|
||||
const contentLimit = websearch.contentLimit
|
||||
|
||||
const params: BochaSearchParams = {
|
||||
query,
|
||||
count: websearch.maxResults,
|
||||
exclude: websearch.excludeDomains.join(','),
|
||||
freshness: websearch.searchWithTime ? 'oneDay' : 'noLimit',
|
||||
summary: false,
|
||||
page: contentLimit ? Math.ceil(contentLimit / websearch.maxResults) : 1
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.apiHost}/v1/web-search`, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(params),
|
||||
headers: {
|
||||
...this.defaultHeaders(),
|
||||
...headers
|
||||
}
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Bocha search failed: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const resp: BochaSearchResponse = await response.json()
|
||||
if (resp.code !== 200) {
|
||||
throw new Error(`Bocha search failed: ${resp.msg}`)
|
||||
}
|
||||
return {
|
||||
query: resp.data.queryContext.originalQuery,
|
||||
results: resp.data.webPages.value.map((result) => ({
|
||||
title: result.name,
|
||||
content: result.snippet,
|
||||
url: result.url
|
||||
}))
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Bocha search failed:', error)
|
||||
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -12,7 +12,10 @@ export default class ExaProvider extends BaseWebSearchProvider {
|
||||
if (!this.apiKey) {
|
||||
throw new Error('API key is required for Exa provider')
|
||||
}
|
||||
this.exa = new ExaClient({ apiKey: this.apiKey })
|
||||
if (!this.apiHost) {
|
||||
throw new Error('API host is required for Exa provider')
|
||||
}
|
||||
this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
|
||||
@ -10,7 +10,6 @@ import BaseWebSearchProvider from './BaseWebSearchProvider'
|
||||
export default class SearxngProvider extends BaseWebSearchProvider {
|
||||
private searxng: SearxngClient
|
||||
private engines: string[] = []
|
||||
private readonly apiHost: string
|
||||
private readonly basicAuthUsername?: string
|
||||
private readonly basicAuthPassword?: string
|
||||
private isInitialized = false
|
||||
|
||||
@ -12,7 +12,10 @@ export default class TavilyProvider extends BaseWebSearchProvider {
|
||||
if (!this.apiKey) {
|
||||
throw new Error('API key is required for Tavily provider')
|
||||
}
|
||||
this.tvly = new TavilyClient({ apiKey: this.apiKey })
|
||||
if (!this.apiHost) {
|
||||
throw new Error('API host is required for Tavily provider')
|
||||
}
|
||||
this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { WebSearchProvider } from '@renderer/types'
|
||||
|
||||
import BaseWebSearchProvider from './BaseWebSearchProvider'
|
||||
import BochaProvider from './BochaProvider'
|
||||
import DefaultProvider from './DefaultProvider'
|
||||
import ExaProvider from './ExaProvider'
|
||||
import LocalBaiduProvider from './LocalBaiduProvider'
|
||||
@ -14,6 +15,8 @@ export default class WebSearchProviderFactory {
|
||||
switch (provider.id) {
|
||||
case 'tavily':
|
||||
return new TavilyProvider(provider)
|
||||
case 'bocha':
|
||||
return new BochaProvider(provider)
|
||||
case 'searxng':
|
||||
return new SearxngProvider(provider)
|
||||
case 'exa':
|
||||
|
||||
@ -125,6 +125,8 @@ async function fetchExternalTool(
|
||||
return
|
||||
}
|
||||
|
||||
if (extractResults.websearch.question[0] === 'not_needed') return
|
||||
|
||||
// Add check for assistant.model before using it
|
||||
if (!assistant.model) {
|
||||
console.warn('searchTheWeb called without assistant.model')
|
||||
|
||||
@ -106,7 +106,7 @@ class WebSearchService {
|
||||
const webSearchEngine = new WebSearchEngineProvider(provider)
|
||||
|
||||
let formattedQuery = query
|
||||
// 有待商榷,效果一般
|
||||
// FIXME: 有待商榷,效果一般
|
||||
if (websearch.searchWithTime) {
|
||||
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@ const persistedReducer = persistReducer(
|
||||
{
|
||||
key: 'cherry-studio',
|
||||
storage,
|
||||
version: 98,
|
||||
version: 99,
|
||||
blacklist: ['runtime', 'messages', 'messageBlocks'],
|
||||
migrate
|
||||
},
|
||||
|
||||
@ -5,7 +5,7 @@ import { SYSTEM_MODELS } from '@renderer/config/models'
|
||||
import { TRANSLATE_PROMPT } from '@renderer/config/prompts'
|
||||
import db from '@renderer/databases'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { Assistant } from '@renderer/types'
|
||||
import { Assistant, WebSearchProvider } from '@renderer/types'
|
||||
import { getDefaultGroupName, getLeadingEmoji, runAsyncFunction, uuid } from '@renderer/utils'
|
||||
import { isEmpty } from 'lodash'
|
||||
import { createMigrate } from 'redux-persist'
|
||||
@ -64,6 +64,18 @@ function addWebSearchProvider(state: RootState, id: string) {
|
||||
}
|
||||
}
|
||||
|
||||
function updateWebSearchProvider(state: RootState, provider: Partial<WebSearchProvider>) {
|
||||
if (state.websearch && state.websearch.providers) {
|
||||
const index = state.websearch.providers.findIndex((p) => p.id === provider.id)
|
||||
if (index !== -1) {
|
||||
state.websearch.providers[index] = {
|
||||
...state.websearch.providers[index],
|
||||
...provider
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const migrateConfig = {
|
||||
'2': (state: RootState) => {
|
||||
try {
|
||||
@ -1252,6 +1264,38 @@ const migrateConfig = {
|
||||
} catch (error) {
|
||||
return state
|
||||
}
|
||||
},
|
||||
'99': (state: RootState) => {
|
||||
try {
|
||||
state.settings.showPrompt = true
|
||||
|
||||
addWebSearchProvider(state, 'bocha')
|
||||
|
||||
updateWebSearchProvider(state, {
|
||||
id: 'exa',
|
||||
apiHost: 'https://api.exa.ai'
|
||||
})
|
||||
|
||||
updateWebSearchProvider(state, {
|
||||
id: 'tavily',
|
||||
apiHost: 'https://api.tavily.com'
|
||||
})
|
||||
|
||||
// Remove basic auth fields from exa and tavily
|
||||
if (state.websearch?.providers) {
|
||||
state.websearch.providers = state.websearch.providers.map((provider) => {
|
||||
if (provider.id === 'exa' || provider.id === 'tavily') {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { basicAuthUsername, basicAuthPassword, ...rest } = provider
|
||||
return rest
|
||||
}
|
||||
return provider
|
||||
})
|
||||
}
|
||||
return state
|
||||
} catch (error) {
|
||||
return state
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ export interface SettingsState {
|
||||
proxyMode: 'system' | 'custom' | 'none'
|
||||
proxyUrl?: string
|
||||
userName: string
|
||||
showPrompt: boolean
|
||||
showMessageDivider: boolean
|
||||
messageFont: 'system' | 'serif'
|
||||
showInputEstimatedTokens: boolean
|
||||
@ -143,6 +144,7 @@ export const initialState: SettingsState = {
|
||||
proxyMode: 'system',
|
||||
proxyUrl: undefined,
|
||||
userName: '',
|
||||
showPrompt: true,
|
||||
showMessageDivider: true,
|
||||
messageFont: 'system',
|
||||
showInputEstimatedTokens: false,
|
||||
@ -272,6 +274,9 @@ const settingsSlice = createSlice({
|
||||
setUserName: (state, action: PayloadAction<string>) => {
|
||||
state.userName = action.payload
|
||||
},
|
||||
setShowPrompt: (state, action: PayloadAction<boolean>) => {
|
||||
state.showPrompt = action.payload
|
||||
},
|
||||
setShowMessageDivider: (state, action: PayloadAction<boolean>) => {
|
||||
state.showMessageDivider = action.payload
|
||||
},
|
||||
@ -528,6 +533,7 @@ export const {
|
||||
setProxyMode,
|
||||
setProxyUrl,
|
||||
setUserName,
|
||||
setShowPrompt,
|
||||
setShowMessageDivider,
|
||||
setMessageFont,
|
||||
setShowInputEstimatedTokens,
|
||||
|
||||
@ -25,6 +25,8 @@ export interface WebSearchState {
|
||||
/** @deprecated 支持在快捷菜单中自选搜索供应商,所以这个不再适用 */
|
||||
overwrite: boolean
|
||||
contentLimit?: number
|
||||
// 具体供应商的配置
|
||||
providerConfig: Record<string, any>
|
||||
}
|
||||
|
||||
const initialState: WebSearchState = {
|
||||
@ -33,16 +35,26 @@ const initialState: WebSearchState = {
|
||||
{
|
||||
id: 'tavily',
|
||||
name: 'Tavily',
|
||||
apiHost: 'https://api.tavily.com',
|
||||
apiKey: ''
|
||||
},
|
||||
{
|
||||
id: 'searxng',
|
||||
name: 'Searxng',
|
||||
apiHost: ''
|
||||
apiHost: '',
|
||||
basicAuthUsername: '',
|
||||
basicAuthPassword: ''
|
||||
},
|
||||
{
|
||||
id: 'exa',
|
||||
name: 'Exa',
|
||||
apiHost: 'https://api.exa.ai',
|
||||
apiKey: ''
|
||||
},
|
||||
{
|
||||
id: 'bocha',
|
||||
name: 'Bocha',
|
||||
apiHost: 'https://api.bochaai.com',
|
||||
apiKey: ''
|
||||
},
|
||||
{
|
||||
@ -65,7 +77,8 @@ const initialState: WebSearchState = {
|
||||
maxResults: 5,
|
||||
excludeDomains: [],
|
||||
subscribeSources: [],
|
||||
overwrite: false
|
||||
overwrite: false,
|
||||
providerConfig: {}
|
||||
}
|
||||
|
||||
export const defaultWebSearchProviders = initialState.providers
|
||||
@ -139,6 +152,12 @@ const websearchSlice = createSlice({
|
||||
},
|
||||
setContentLimit: (state, action: PayloadAction<number | undefined>) => {
|
||||
state.contentLimit = action.payload
|
||||
},
|
||||
setProviderConfig: (state, action: PayloadAction<Record<string, any>>) => {
|
||||
state.providerConfig = action.payload
|
||||
},
|
||||
updateProviderConfig: (state, action: PayloadAction<Record<string, any>>) => {
|
||||
state.providerConfig = { ...state.providerConfig, ...action.payload }
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -157,7 +176,9 @@ export const {
|
||||
setSubscribeSources,
|
||||
setOverwrite,
|
||||
addWebSearchProvider,
|
||||
setContentLimit
|
||||
setContentLimit,
|
||||
setProviderConfig,
|
||||
updateProviderConfig
|
||||
} = websearchSlice.actions
|
||||
|
||||
export default websearchSlice.reducer
|
||||
|
||||
207
src/renderer/src/types/bocha.ts
Normal file
207
src/renderer/src/types/bocha.ts
Normal file
@ -0,0 +1,207 @@
|
||||
import { z } from 'zod'
|
||||
|
||||
export const freshnessOptions = ['oneDay', 'oneWeek', 'oneMonth', 'oneYear', 'noLimit'] as const
|
||||
|
||||
const isValidDate = (dateStr: string): boolean => {
|
||||
// First check basic format
|
||||
if (!/^\d{4}-\d{2}-\d{2}$/.test(dateStr)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const [year, month, day] = dateStr.split('-').map(Number)
|
||||
|
||||
if (year < 1900 || year > 2100) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check month range
|
||||
if (month < 1 || month > 12) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get last day of the month
|
||||
const lastDay = new Date(year, month, 0).getDate()
|
||||
|
||||
// Check day range
|
||||
if (day < 1 || day > lastDay) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
const isValidDateRange = (dateRangeStr: string): boolean => {
|
||||
// Check if it's a single date
|
||||
if (/^\d{4}-\d{2}-\d{2}$/.test(dateRangeStr)) {
|
||||
return isValidDate(dateRangeStr)
|
||||
}
|
||||
|
||||
// Check if it's a date range
|
||||
if (!/^\d{4}-\d{2}-\d{2}\.\.\d{4}-\d{2}-\d{2}$/.test(dateRangeStr)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const [startDate, endDate] = dateRangeStr.split('..')
|
||||
|
||||
// Validate both dates
|
||||
if (!isValidDate(startDate) || !isValidDate(endDate)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if start date is before or equal to end date
|
||||
const start = new Date(startDate)
|
||||
const end = new Date(endDate)
|
||||
return start <= end
|
||||
}
|
||||
|
||||
const isValidExcludeDomains = (excludeStr: string): boolean => {
|
||||
if (!excludeStr) return true
|
||||
|
||||
// Split by either | or ,
|
||||
const domains = excludeStr
|
||||
.split(/[|,]/)
|
||||
.map((d) => d.trim())
|
||||
.filter(Boolean)
|
||||
|
||||
// Check number of domains
|
||||
if (domains.length > 20) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Domain name regex (supports both root domains and subdomains)
|
||||
const domainRegex = /^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/
|
||||
|
||||
// Check each domain
|
||||
return domains.every((domain) => domainRegex.test(domain))
|
||||
}
|
||||
|
||||
const BochaSearchParamsSchema = z.object({
|
||||
query: z.string(),
|
||||
freshness: z
|
||||
.union([
|
||||
z.enum(freshnessOptions),
|
||||
z
|
||||
.string()
|
||||
.regex(
|
||||
/^(\d{4}-\d{2}-\d{2})(\.\.\d{4}-\d{2}-\d{2})?$/,
|
||||
'Date must be in YYYY-MM-DD or YYYY-MM-DD..YYYY-MM-DD format'
|
||||
)
|
||||
.refine(isValidDateRange, {
|
||||
message: 'Invalid date range - please provide valid dates in YYYY-MM-DD or YYYY-MM-DD..YYYY-MM-DD format'
|
||||
})
|
||||
])
|
||||
.optional()
|
||||
.default('noLimit'),
|
||||
summary: z.boolean().optional().default(false),
|
||||
exclude: z
|
||||
.string()
|
||||
.optional()
|
||||
.refine((val) => !val || isValidExcludeDomains(val), {
|
||||
message:
|
||||
'Invalid exclude format. Please provide valid domain names separated by | or ,. Maximum 20 domains allowed.'
|
||||
}),
|
||||
page: z.number().optional().default(1),
|
||||
count: z.number().optional().default(10)
|
||||
})
|
||||
|
||||
const BochaSearchResponseDataSchema = z.object({
|
||||
type: z.string(),
|
||||
queryContext: z.object({
|
||||
originalQuery: z.string()
|
||||
}),
|
||||
webPages: z.object({
|
||||
webSearchUrl: z.string(),
|
||||
totalEstimatedMatches: z.number(),
|
||||
value: z.array(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
url: z.string(),
|
||||
displayUrl: z.string(),
|
||||
snippet: z.string(),
|
||||
summary: z.string().optional(),
|
||||
siteName: z.string(),
|
||||
siteIcon: z.string(),
|
||||
datePublished: z.string(),
|
||||
dateLastCrawled: z.string(),
|
||||
cachedPageUrl: z.string(),
|
||||
language: z.string(),
|
||||
isFamilyFriendly: z.boolean(),
|
||||
isNavigational: z.boolean()
|
||||
})
|
||||
),
|
||||
someResultsRemoved: z.boolean()
|
||||
}),
|
||||
images: z.object({
|
||||
id: z.string(),
|
||||
readLink: z.string(),
|
||||
webSearchUrl: z.string(),
|
||||
name: z.string(),
|
||||
value: z.array(
|
||||
z.object({
|
||||
webSearchUrl: z.string(),
|
||||
name: z.string(),
|
||||
thumbnailUrl: z.string(),
|
||||
datePublished: z.string(),
|
||||
contentUrl: z.string(),
|
||||
hostPageUrl: z.string(),
|
||||
contentSize: z.string(),
|
||||
encodingFormat: z.string(),
|
||||
hostPageDisplayUrl: z.string(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
thumbnail: z.object({
|
||||
width: z.number(),
|
||||
height: z.number()
|
||||
})
|
||||
})
|
||||
)
|
||||
}),
|
||||
videos: z.object({
|
||||
id: z.string(),
|
||||
readLink: z.string(),
|
||||
webSearchUrl: z.string(),
|
||||
isFamilyFriendly: z.boolean(),
|
||||
scenario: z.string(),
|
||||
name: z.string(),
|
||||
value: z.array(
|
||||
z.object({
|
||||
webSearchUrl: z.string(),
|
||||
name: z.string(),
|
||||
description: z.string(),
|
||||
thumbnailUrl: z.string(),
|
||||
publisher: z.string(),
|
||||
creator: z.string(),
|
||||
contentUrl: z.string(),
|
||||
hostPageUrl: z.string(),
|
||||
encodingFormat: z.string(),
|
||||
hostPageDisplayUrl: z.string(),
|
||||
width: z.number(),
|
||||
height: z.number(),
|
||||
duration: z.number(),
|
||||
motionThumbnailUrl: z.string(),
|
||||
embedHtml: z.string(),
|
||||
allowHttpsEmbed: z.boolean(),
|
||||
viewCount: z.number(),
|
||||
thumbnail: z.object({
|
||||
width: z.number(),
|
||||
height: z.number()
|
||||
}),
|
||||
allowMobileEmbed: z.boolean(),
|
||||
isSuperfresh: z.boolean(),
|
||||
datePublished: z.string()
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
const BochaSearchResponseSchema = z.object({
|
||||
code: z.number(),
|
||||
logId: z.string(),
|
||||
data: BochaSearchResponseDataSchema,
|
||||
msg: z.string().optional()
|
||||
})
|
||||
|
||||
export type BochaSearchParams = z.infer<typeof BochaSearchParamsSchema>
|
||||
export type BochaSearchResponse = z.infer<typeof BochaSearchResponseSchema>
|
||||
export { BochaSearchParamsSchema, BochaSearchResponseSchema }
|
||||
Loading…
Reference in New Issue
Block a user