mirror of https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 05:39:05 +08:00

commit 572ffcc8be: Merge branch 'main' into fix/next-release-bugs

Binary file not shown (removed image, previously 15 KiB).
src/renderer/src/assets/images/apps/n8n.svg (new file, 1 line, 1.4 KiB)
@@ -0,0 +1 @@
+<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>n8n</title><path clip-rule="evenodd" d="M24 8.4c0 1.325-1.102 2.4-2.462 2.4-1.146 0-2.11-.765-2.384-1.8h-3.436c-.602 0-1.115.424-1.214 1.003l-.101.592a2.38 2.38 0 01-.8 1.405c.412.354.704.844.8 1.405l.1.592A1.222 1.222 0 0015.719 15h.975c.273-1.035 1.237-1.8 2.384-1.8 1.36 0 2.461 1.075 2.461 2.4S20.436 18 19.078 18c-1.147 0-2.11-.765-2.384-1.8h-.975c-1.204 0-2.23-.848-2.428-2.005l-.101-.592a1.222 1.222 0 00-1.214-1.003H10.97c-.308.984-1.246 1.7-2.356 1.7-1.11 0-2.048-.716-2.355-1.7H4.817c-.308.984-1.246 1.7-2.355 1.7C1.102 14.3 0 13.225 0 11.9s1.102-2.4 2.462-2.4c1.183 0 2.172.815 2.408 1.9h1.337c.236-1.085 1.225-1.9 2.408-1.9 1.184 0 2.172.815 2.408 1.9h.952c.601 0 1.115-.424 1.213-1.003l.102-.592c.198-1.157 1.225-2.005 2.428-2.005h3.436c.274-1.035 1.238-1.8 2.384-1.8C22.898 6 24 7.075 24 8.4zm-1.23 0c0 .663-.552 1.2-1.232 1.2-.68 0-1.23-.537-1.23-1.2 0-.663.55-1.2 1.23-1.2.68 0 1.231.537 1.231 1.2zM2.461 13.1c.68 0 1.23-.537 1.23-1.2 0-.663-.55-1.2-1.23-1.2-.68 0-1.231.537-1.231 1.2 0 .663.55 1.2 1.23 1.2zm6.153 0c.68 0 1.231-.537 1.231-1.2 0-.663-.55-1.2-1.23-1.2-.68 0-1.231.537-1.231 1.2 0 .663.55 1.2 1.23 1.2zm10.462 3.7c.68 0 1.23-.537 1.23-1.2 0-.663-.55-1.2-1.23-1.2-.68 0-1.23.537-1.23 1.2 0 .663.55 1.2 1.23 1.2z" fill="#EA4B71" fill-rule="evenodd"></path></svg>

Binary file not shown (removed image, previously 1.1 KiB).
@@ -59,9 +59,11 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
   const [pinnedModels, setPinnedModels] = useState<string[]>([])
   const [_focusedItemKey, setFocusedItemKey] = useState<string>('')
   const focusedItemKey = useDeferredValue(_focusedItemKey)
-  const [currentStickyGroup, setCurrentStickyGroup] = useState<FlatListItem | null>(null)
+  const [_stickyGroup, setStickyGroup] = useState<FlatListItem | null>(null)
+  const stickyGroup = useDeferredValue(_stickyGroup)
   const firstGroupRef = useRef<FlatListItem | null>(null)
   const scrollTriggerRef = useRef<ScrollTrigger>('initial')
+  const lastScrollOffsetRef = useRef(0)

   // Currently selected model ID
   const currentModelId = model ? getModelUniqId(model) : ''
@@ -220,6 +222,45 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
     return items
   }, [providers, getFilteredModels, pinnedModels, searchText, t, createModelItem])

+  // Update the sticky group title based on the scroll position
+  const updateStickyGroup = useCallback(
+    (scrollOffset?: number) => {
+      if (listItems.length === 0) {
+        setStickyGroup(null)
+        return
+      }
+
+      // Estimate the index of the first visible item from the scroll position
+      const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffsetRef.current) / ITEM_HEIGHT)
+
+      // Search backwards from that index for the nearest group title
+      for (let i = estimatedIndex - 1; i >= 0; i--) {
+        if (i < listItems.length && listItems[i]?.type === 'group') {
+          setStickyGroup(listItems[i])
+          return
+        }
+      }
+
+      // Fall back to the first group title when none is found
+      setStickyGroup(firstGroupRef.current ?? null)
+    },
+    [listItems]
+  )
+
+  // Update the sticky group whenever listItems changes
+  useEffect(() => {
+    updateStickyGroup()
+  }, [listItems, updateStickyGroup])
+
+  // Handle list scroll events: record lastScrollOffset and update the sticky group
+  const handleScroll = useCallback(
+    ({ scrollOffset }) => {
+      lastScrollOffsetRef.current = scrollOffset
+      updateStickyGroup(scrollOffset)
+    },
+    [updateStickyGroup]
+  )
+
   // Get the selectable model items (group titles filtered out)
   const modelItems = useMemo(() => {
     return listItems.filter((item) => item.type === 'model')
@@ -257,9 +298,6 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
     const alignment = scrollTriggerRef.current === 'keyboard' ? 'auto' : 'center'
     listRef.current?.scrollToItem(index, alignment)

-    console.log('focusedItemKey', focusedItemKey)
-    console.log('scrollToFocusedItem', index, alignment)
-
     // Reset the trigger after scrolling
     scrollTriggerRef.current = 'none'
   }, [focusedItemKey, listItems])
@@ -365,41 +403,19 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
     if (!open) return
     setTimeout(() => inputRef.current?.focus(), 0)
     scrollTriggerRef.current = 'initial'
+    lastScrollOffsetRef.current = 0
   }, [open])

-  // Initialize the sticky group title
-  useEffect(() => {
-    if (firstGroupRef.current) {
-      setCurrentStickyGroup(firstGroupRef.current)
-    }
-  }, [listItems])
-
-  const handleItemsRendered = useCallback(
-    ({ visibleStartIndex }: { visibleStartIndex: number; visibleStopIndex: number }) => {
-      // Search backwards from the start of the visible area for the nearest group title
-      for (let i = visibleStartIndex - 1; i >= 0; i--) {
-        if (listItems[i]?.type === 'group') {
-          setCurrentStickyGroup(listItems[i])
-          return
-        }
-      }
-
-      // Fall back to the first group title when none is found
-      setCurrentStickyGroup(firstGroupRef.current ?? null)
-    },
-    [listItems]
-  )
-
   const RowData = useMemo(
     (): VirtualizedRowData => ({
       listItems,
       focusedItemKey,
       setFocusedItemKey,
-      currentStickyGroup,
+      stickyGroup,
       handleItemClick,
       togglePin
     }),
-    [currentStickyGroup, focusedItemKey, handleItemClick, listItems, togglePin]
+    [stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin]
   )

   const listHeight = useMemo(() => {
@@ -456,7 +472,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
         {listItems.length > 0 ? (
           <ListContainer onMouseMove={() => setIsMouseOver(true)}>
             {/* Sticky group banner; it replaces the first group name */}
-            <StickyGroupBanner>{currentStickyGroup?.name}</StickyGroupBanner>
+            <StickyGroupBanner>{stickyGroup?.name}</StickyGroupBanner>
             <FixedSizeList
               ref={listRef}
               height={listHeight}
@@ -466,7 +482,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
               itemData={RowData}
               itemKey={(index, data) => data.listItems[index].key}
               overscanCount={4}
-              onItemsRendered={handleItemsRendered}
+              onScroll={handleScroll}
               style={{ pointerEvents: isMouseOver ? 'auto' : 'none' }}>
               {VirtualizedRow}
             </FixedSizeList>
@@ -484,7 +500,7 @@ interface VirtualizedRowData {
   listItems: FlatListItem[]
   focusedItemKey: string
   setFocusedItemKey: (key: string) => void
-  currentStickyGroup: FlatListItem | null
+  stickyGroup: FlatListItem | null
   handleItemClick: (item: FlatListItem) => void
   togglePin: (modelId: string) => void
 }
@@ -494,7 +510,7 @@ interface VirtualizedRowData {
  */
 const VirtualizedRow = React.memo(
   ({ data, index, style }: { data: VirtualizedRowData; index: number; style: React.CSSProperties }) => {
-    const { listItems, focusedItemKey, setFocusedItemKey, handleItemClick, togglePin, currentStickyGroup } = data
+    const { listItems, focusedItemKey, setFocusedItemKey, handleItemClick, togglePin, stickyGroup } = data

     const item = listItems[index]

@@ -505,7 +521,7 @@ const VirtualizedRow = React.memo(
     return (
       <div style={style}>
         {item.type === 'group' ? (
-          <GroupItem $isSticky={item.key === currentStickyGroup?.key}>{item.name}</GroupItem>
+          <GroupItem $isSticky={item.key === stickyGroup?.key}>{item.name}</GroupItem>
         ) : (
           <ModelItem
             className={classNames({
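
The hunks above move the sticky group banner off react-window's onItemsRendered and onto an onScroll handler: the component records the last scroll offset, estimates the first visible row from the fixed row height, walks backwards to the nearest group header, and routes the result through useDeferredValue so rapid scrolling does not force a render per frame. A minimal standalone sketch of that lookup, assuming a flat array mixing group headers and models with a fixed row height (FlatListItem and ITEM_HEIGHT mirror names in the diff; the concrete height value and the findStickyGroup helper are illustrative, not from the source):

interface FlatListItem {
  type: 'group' | 'model'
  key: string
  name: string
}

const ITEM_HEIGHT = 36 // fixed row height passed to the FixedSizeList (value assumed)

// Nearest group header at or above the first visible row; falls back to the
// first group in the list when the viewport sits above every header.
function findStickyGroup(items: FlatListItem[], scrollOffset: number): FlatListItem | null {
  if (items.length === 0) return null
  const estimatedIndex = Math.floor(scrollOffset / ITEM_HEIGHT)
  for (let i = Math.min(estimatedIndex, items.length) - 1; i >= 0; i--) {
    if (items[i].type === 'group') return items[i]
  }
  return items.find((item) => item.type === 'group') ?? null
}

Deriving the index arithmetically is what lets the diff drop the onItemsRendered callback: with fixed-height rows, the scroll offset alone determines visibility, so no render-time feedback from the list is needed.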
@@ -1,9 +1,7 @@
-import n8nLogo from '@renderer/assets/images/apps/n8n.ico?url'
-import ParateraLogo from '@renderer/assets/images/apps/paratera.ico?url'
-import ApplicationLogo from '@renderer/assets/images/apps/application.png?url'
 import ThreeMinTopAppLogo from '@renderer/assets/images/apps/3mintop.png?url'
 import AbacusLogo from '@renderer/assets/images/apps/abacus.webp?url'
 import AIStudioLogo from '@renderer/assets/images/apps/aistudio.svg?url'
+import ApplicationLogo from '@renderer/assets/images/apps/application.png?url'
 import BaiduAiAppLogo from '@renderer/assets/images/apps/baidu-ai.png?url'
 import BaiduAiSearchLogo from '@renderer/assets/images/apps/baidu-ai-search.webp?url'
 import BaicuanAppLogo from '@renderer/assets/images/apps/baixiaoying.webp?url'
@@ -29,6 +27,7 @@ import LambdaChatLogo from '@renderer/assets/images/apps/lambdachat.webp?url'
 import LeChatLogo from '@renderer/assets/images/apps/lechat.png?url'
 import MetasoAppLogo from '@renderer/assets/images/apps/metaso.webp?url'
 import MonicaLogo from '@renderer/assets/images/apps/monica.webp?url'
+import n8nLogo from '@renderer/assets/images/apps/n8n.svg?url'
 import NamiAiLogo from '@renderer/assets/images/apps/nm.png?url'
 import NamiAiSearchLogo from '@renderer/assets/images/apps/nm-search.webp?url'
 import NotebookLMAppLogo from '@renderer/assets/images/apps/notebooklm.svg?url'
@@ -62,11 +61,11 @@ const loadCustomMiniApp = async (): Promise<MinAppType[]> => {
   try {
     let content: string
     try {
-      content = await window.api.file.read('customMiniAPP')
+      content = await window.api.file.read('custom-minapps.json')
     } catch (error) {
       // If the file does not exist, create an empty JSON array
       content = '[]'
-      await window.api.file.writeWithId('customMiniAPP', content)
+      await window.api.file.writeWithId('custom-minapps.json', content)
     }

     const customApps = JSON.parse(content)
@@ -451,18 +450,15 @@ const ORIGIN_DEFAULT_MIN_APPS: MinAppType[] = [
       padding: 10
     }
   },
-  {
-    id: 'paratera',
-    name: 'ParateraAI',
-    logo: ParateraLogo,
-    url: 'https://ai.paratera.com/'
-  },
   {
     id: 'n8n',
     name: 'n8n',
     logo: n8nLogo,
     url: 'https://app.n8n.cloud/',
-    bodered: true
+    bodered: true,
+    style: {
+      padding: 5
+    }
   }
 ]

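Alongside the logo swap from n8n.ico to the new n8n.svg, the storage key for user-defined mini apps changes from 'customMiniAPP' to 'custom-minapps.json' in both the read path and the create-on-missing fallback. A sketch of that read-or-initialize pattern, assuming the preload bridge rejects reads of missing files (the declare block restates the bridge surface implied by the calls in the diff; the generic element type stands in for MinAppType, whose fields are not shown here):

declare const window: {
  api: {
    file: {
      read(id: string): Promise<string>
      writeWithId(id: string, content: string): Promise<void>
    }
  }
}

const CUSTOM_MINAPPS_FILE = 'custom-minapps.json'

async function readCustomMiniApps<T = unknown>(): Promise<T[]> {
  let content: string
  try {
    content = await window.api.file.read(CUSTOM_MINAPPS_FILE)
  } catch {
    // File does not exist yet: seed it with an empty JSON array so every
    // later read and write operates on valid JSON.
    content = '[]'
    await window.api.file.writeWithId(CUSTOM_MINAPPS_FILE, content)
  }
  return JSON.parse(content)
}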
@@ -2074,62 +2074,6 @@ export const SYSTEM_MODELS: Record<string, Model[]> = {
       name: 'Qwen2.5 72B Instruct',
       group: 'Qwen'
     }
-  ],
-  paratera: [
-    {
-      id: 'GLM-Z1-Flash-P002',
-      provider: 'paratera',
-      name: 'GLM-Z1-Flash-P002',
-      group: 'GLM'
-    },
-    {
-      id: 'GLM-Z1-AirX-P002',
-      provider: 'paratera',
-      name: 'GLM-Z1-AirX-P002',
-      group: 'GLM'
-    },
-    {
-      id: 'DeepSeek-V3-250324-P001',
-      provider: 'paratera',
-      name: 'DeepSeek-V3-250324-P001',
-      group: 'DeepSeek'
-    },
-    {
-      id: 'DeepSeek-R1',
-      provider: 'paratera',
-      name: 'DeepSeek-R1',
-      group: 'DeepSeek'
-    },
-    {
-      id: 'QwQ-N011-32B',
-      provider: 'paratera',
-      name: 'QwQ-N011-32B',
-      group: 'Qwen'
-    },
-    {
-      id: 'GLM-Embedding-2-P002',
-      provider: 'paratera',
-      name: 'GLM-Embedding-2-P002',
-      group: 'GLM'
-    },
-    {
-      id: 'GLM-Embedding-3-P002',
-      provider: 'paratera',
-      name: 'GLM-Embedding-3-P002',
-      group: 'GLM'
-    },
-    {
-      id: 'Doubao-Embedding-Text-P001',
-      provider: 'paratera',
-      name: 'Doubao-Embedding-Text-P001',
-      group: 'Doubao'
-    },
-    {
-      id: 'Doubao-Embedding-Large-Text-P001',
-      provider: 'paratera',
-      name: 'Doubao-Embedding-Large-Text-P001',
-      group: 'Doubao'
-    }
   ]
 }

@@ -42,7 +42,6 @@ import VoyageAIProviderLogo from '@renderer/assets/images/providers/voyageai.png
 import XirangProviderLogo from '@renderer/assets/images/providers/xirang.png'
 import ZeroOneProviderLogo from '@renderer/assets/images/providers/zero-one.png'
 import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
-import ParateraLogo from '@renderer/assets/images/apps/paratera.ico'

 const PROVIDER_LOGO_MAP = {
   openai: OpenAiProviderLogo,
@@ -89,8 +88,7 @@ const PROVIDER_LOGO_MAP = {
   gpustack: GPUStackProviderLogo,
   alayanew: AlayaNewProviderLogo,
   voyageai: VoyageAIProviderLogo,
-  qiniu: QiniuProviderLogo,
-  paratera: ParateraLogo
+  qiniu: QiniuProviderLogo
 } as const

 export function getProviderLogo(providerId: string) {
@@ -585,16 +583,5 @@ export const PROVIDER_CONFIG = {
       docs: 'https://developer.qiniu.com/aitokenapi',
       models: 'https://developer.qiniu.com/aitokenapi/12883/model-list'
     }
-  },
-  paratera: {
-    api: {
-      url: 'https://llmapi.paratera.com'
-    },
-    websites: {
-      official: 'https://ai.paratera.com/',
-      apiKey: 'https://ai.paratera.com/#/lms/api',
-      docs: 'https://ai.paratera.com/document/llm/quickStart/useApi',
-      models: 'https://ai.paratera.com/#/lms/model'
-    }
   }
 }
@@ -705,6 +705,7 @@
   "rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.",
   "search": "Search models...",
   "stream_output": "Stream output",
+  "enable_tool_use": "Enable Tool Use",
   "type": {
     "embedding": "Embedding",
     "free": "Free",

@@ -705,6 +705,7 @@
   "rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。",
   "search": "モデルを検索...",
   "stream_output": "ストリーム出力",
+  "enable_tool_use": "ツール呼び出し",
   "type": {
     "embedding": "埋め込み",
     "free": "無料",

@@ -705,6 +705,7 @@
   "rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить.",
   "search": "Поиск моделей...",
   "stream_output": "Потоковый вывод",
+  "enable_tool_use": "Вызов инструмента",
   "type": {
     "embedding": "Встраиваемые",
     "free": "Бесплатные",

@@ -705,6 +705,7 @@
   "rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
   "search": "搜索模型...",
   "stream_output": "流式输出",
+  "enable_tool_use": "工具调用",
   "type": {
     "embedding": "嵌入",
     "free": "免费",

@@ -705,6 +705,7 @@
   "rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加",
   "search": "搜尋模型...",
   "stream_output": "串流輸出",
+  "enable_tool_use": "工具調用",
   "type": {
     "embedding": "嵌入",
     "free": "免費",
@@ -40,7 +40,7 @@ const App: FC<Props> = ({ app, onClick, size = 60, isLast }) => {

   const handleAddCustomApp = async (values: any) => {
     try {
-      const content = await window.api.file.read('customMiniAPP')
+      const content = await window.api.file.read('custom-minapps.json')
       const customApps = JSON.parse(content)

       // Check for duplicate ID
@@ -62,7 +62,7 @@ const App: FC<Props> = ({ app, onClick, size = 60, isLast }) => {
         addTime: new Date().toISOString()
       }
       customApps.push(newApp)
-      await window.api.file.writeWithId('customMiniAPP', JSON.stringify(customApps, null, 2))
+      await window.api.file.writeWithId('custom-minapps.json', JSON.stringify(customApps, null, 2))
       message.success(t('settings.miniapps.custom.save_success'))
       setIsModalVisible(false)
       form.resetFields()
@@ -138,10 +138,10 @@ const App: FC<Props> = ({ app, onClick, size = 60, isLast }) => {
       danger: true,
       onClick: async () => {
         try {
-          const content = await window.api.file.read('customMiniAPP')
+          const content = await window.api.file.read('custom-minapps.json')
           const customApps = JSON.parse(content)
           const updatedApps = customApps.filter((customApp: MinAppType) => customApp.id !== app.id)
-          await window.api.file.writeWithId('customMiniAPP', JSON.stringify(updatedApps, null, 2))
+          await window.api.file.writeWithId('custom-minapps.json', JSON.stringify(updatedApps, null, 2))
           message.success(t('settings.miniapps.custom.remove_success'))
           const reloadedApps = [...ORIGIN_DEFAULT_MIN_APPS, ...(await loadCustomMiniApp())]
           updateDefaultMinApps(reloadedApps)
@@ -91,7 +91,7 @@ const CitationsList: React.FC<CitationsListProps> = ({ citations }) => {
       onClose={() => setOpen(false)}
       open={open}
       width={680}
-      styles={{ header: { border: 'none' }, body: { paddingTop: 0, backgroundColor: 'var(--color-background)' } }}
+      styles={{ header: { border: 'none' }, body: { paddingTop: 0 } }}
       destroyOnClose={false}>
       {open &&
         citations.map((citation) => (
@@ -127,12 +127,12 @@ const WebSearchCitation: React.FC<{ citation: Citation }> = ({ citation }) => {
   })

   return (
-    <WebSearchCard onClick={() => handleLinkClick(citation.url)}>
+    <WebSearchCard>
       <WebSearchCardHeader>
         {citation.showFavicon && citation.url && (
           <Favicon hostname={new URL(citation.url).hostname} alt={citation.title || citation.hostname || ''} />
         )}
-        <CitationLink className="text-nowrap">
+        <CitationLink className="text-nowrap" href={citation.url} onClick={(e) => handleLinkClick(citation.url, e)}>
           {citation.title || <span className="hostname">{citation.hostname}</span>}
         </CitationLink>
       </WebSearchCardHeader>
@@ -146,10 +146,12 @@ const WebSearchCitation: React.FC<{ citation: Citation }> = ({ citation }) => {
 }

 const KnowledgeCitation: React.FC<{ citation: Citation }> = ({ citation }) => (
-  <WebSearchCard onClick={() => handleLinkClick(citation.url)}>
+  <WebSearchCard>
     <WebSearchCardHeader>
       {citation.showFavicon && <FileSearch width={16} />}
-      <CitationLink className="text-nowrap">{citation.title}</CitationLink>
+      <CitationLink className="text-nowrap" href={citation.url} onClick={(e) => handleLinkClick(citation.url, e)}>
+        {citation.title}
+      </CitationLink>
     </WebSearchCardHeader>
     <WebSearchCardContent>{citation.content && truncateText(citation.content, 100)}</WebSearchCardContent>
   </WebSearchCard>
@@ -189,11 +191,15 @@ const PreviewIcon = styled.div`
   }
 `

-const CitationLink = styled.div`
+const CitationLink = styled.a`
   font-size: 14px;
   line-height: 1.6;
   color: var(--color-text-1);
   text-decoration: none;
+
+  .hostname {
+    color: var(--color-link);
+  }
 `

 const WebSearchCard = styled.div`
@@ -204,11 +210,6 @@ const WebSearchCard = styled.div`
   border-radius: var(--list-item-border-radius);
   background-color: var(--color-background);
   transition: all 0.3s ease;
-  cursor: pointer;
-
-  &:hover {
-    background-color: var(--color-background-soft);
-  }
 `

 const WebSearchCardHeader = styled.div`
@@ -217,7 +218,6 @@ const WebSearchCardHeader = styled.div`
   align-items: center;
   gap: 8px;
   margin-bottom: 6px;
-  font-weight: 500;
 `

 const WebSearchCardContent = styled.div`
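
The citation cards stop acting as clickable wrappers: cursor and hover styles are dropped from WebSearchCard, and CitationLink becomes a real styled.a carrying an href plus a click handler that now receives the event. The handler body is not part of this diff; a plausible sketch, assuming it suppresses default navigation and hands the URL off to the app (both assumptions):

import type { MouseEvent } from 'react'

// Anchor-based citation link: href restores native affordances (hover URL
// preview, copy link address) while onClick still controls how it opens.
const handleLinkClick = (url: string, e: MouseEvent<HTMLAnchorElement>) => {
  e.preventDefault() // assumed: keep the renderer from navigating in place
  window.open(url)   // assumed: Electron apps typically route this to the system browser
}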
@@ -67,7 +67,7 @@ const MessageTools: FC<Props> = ({ blocks }) => {
   const isDone = status === 'done'
   const hasError = isDone && response?.isError === true
   const result = {
-    params: tool.inputSchema,
+    params: toolResponse.arguments,
     response: toolResponse.response
   }

@@ -70,6 +70,7 @@ const SettingsTab: FC<Props> = (props) => {
   const [maxTokens, setMaxTokens] = useState(assistant?.settings?.maxTokens ?? 0)
   const [fontSizeValue, setFontSizeValue] = useState(fontSize)
   const [streamOutput, setStreamOutput] = useState(assistant?.settings?.streamOutput ?? true)
+  const [enableToolUse, setEnableToolUse] = useState(assistant?.settings?.enableToolUse ?? false)
   const { t } = useTranslation()

   const dispatch = useAppDispatch()
@@ -222,6 +223,18 @@ const SettingsTab: FC<Props> = (props) => {
       />
     </SettingRow>
     <SettingDivider />
+    <SettingRow>
+      <SettingRowTitleSmall>{t('models.enable_tool_use')}</SettingRowTitleSmall>
+      <Switch
+        size="small"
+        checked={enableToolUse}
+        onChange={(checked) => {
+          setEnableToolUse(checked)
+          updateAssistantSettings({ enableToolUse: checked })
+        }}
+      />
+    </SettingRow>
+    <SettingDivider />
     <Row align="middle" justify="space-between" style={{ marginBottom: 10 }}>
       <HStack alignItems="center">
         <Label>{t('chat.settings.max_tokens')}</Label>
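
enableToolUse is wired exactly like streamOutput: component state seeded from assistant.settings with a false default, persisted through updateAssistantSettings. A sketch of the settings shape this implies (only enableToolUse is new; the interface itself and the other optional fields are restated from the surrounding code, not defined in this diff):

interface AssistantSettings {
  maxTokens?: number
  streamOutput?: boolean
  enableToolUse?: boolean // new flag; treated as false when unset
}

// Reading with a default, matching the useState initializers above.
function getEnableToolUse(assistant?: { settings?: AssistantSettings }): boolean {
  return assistant?.settings?.enableToolUse ?? false
}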
@@ -1,4 +1,4 @@
-import { InfoCircleFilled, PlusOutlined, RedoOutlined } from '@ant-design/icons'
+import { PlusOutlined, RedoOutlined } from '@ant-design/icons'
 import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg'
 import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
 import { HStack } from '@renderer/components/Layout'
@@ -20,6 +20,7 @@ import type { PaintingAction, PaintingsState } from '@renderer/types'
 import { getErrorMessage, uuid } from '@renderer/utils'
 import { Avatar, Button, Input, InputNumber, Radio, Segmented, Select, Slider, Switch, Tooltip, Upload } from 'antd'
 import TextArea from 'antd/es/input/TextArea'
+import { Info } from 'lucide-react'
 import type { FC } from 'react'
 import { useEffect, useMemo, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
@@ -72,6 +73,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
     { label: t('paintings.mode.remix'), value: 'remix' },
     { label: t('paintings.mode.upscale'), value: 'upscale' }
   ]

   const getNewPainting = () => {
     return {
       ...DEFAULT_PAINTING,
@@ -278,14 +280,6 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
     }

     removePainting(mode, paintingToDelete)

-    if (filteredPaintings.length === 1) {
-      const defaultPainting = {
-        ...DEFAULT_PAINTING,
-        id: uuid()
-      }
-      setPainting(defaultPainting)
-    }
   }

   const translate = async () => {
@@ -334,6 +328,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
       navigate('../' + providerId, { replace: true })
     }
   }

   // Handle mode switching
   const handleModeChange = (value: string) => {
     setMode(value as keyof PaintingsState)
@@ -494,8 +489,9 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {

   useEffect(() => {
     if (filteredPaintings.length === 0) {
-      addPainting(mode, getNewPainting())
-      setPainting(DEFAULT_PAINTING)
+      const newPainting = getNewPainting()
+      addPainting(mode, newPainting)
+      setPainting(newPainting)
     }
   }, [filteredPaintings, mode, addPainting, painting])

@@ -674,11 +670,17 @@ const ToolbarMenu = styled.div`
   gap: 6px;
 `

-const InfoIcon = styled(InfoCircleFilled)`
+const InfoIcon = styled(Info)`
   margin-left: 5px;
   cursor: help;
-  color: #8d94a6;
-  font-size: 12px;
+  color: var(--color-text-2);
+  opacity: 0.6;
+  width: 14px;
+  height: 16px;
+
+  &:hover {
+    opacity: 1;
+  }
 `

 const SliderContainer = styled.div`
@@ -1,4 +1,4 @@
-import { PlusOutlined, QuestionCircleOutlined, RedoOutlined } from '@ant-design/icons'
+import { PlusOutlined, RedoOutlined } from '@ant-design/icons'
 import ImageSize1_1 from '@renderer/assets/images/paintings/image-size-1-1.svg'
 import ImageSize1_2 from '@renderer/assets/images/paintings/image-size-1-2.svg'
 import ImageSize3_2 from '@renderer/assets/images/paintings/image-size-3-2.svg'
@@ -26,6 +26,7 @@ import type { FileType, Painting } from '@renderer/types'
 import { getErrorMessage, uuid } from '@renderer/utils'
 import { Button, Input, InputNumber, Radio, Select, Slider, Switch, Tooltip } from 'antd'
 import TextArea from 'antd/es/input/TextArea'
+import { Info } from 'lucide-react'
 import type { FC } from 'react'
 import { useEffect, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
@@ -90,7 +91,7 @@ const DEFAULT_PAINTING: Painting = {
 const PaintingsPage: FC<{ Options: string[] }> = ({ Options }) => {
   const { t } = useTranslation()
   const { paintings, addPainting, removePainting, updatePainting } = usePaintings()
-  const [painting, setPainting] = useState<Painting>(DEFAULT_PAINTING)
+  const [painting, setPainting] = useState<Painting>(paintings[0] || DEFAULT_PAINTING)
   const { theme } = useTheme()
   const providers = useAllProviders()
   const providerOptions = Options.map((option) => {
@@ -260,10 +261,6 @@ const PaintingsPage: FC<{ Options: string[] }> = ({ Options }) => {
     }

     removePainting('paintings', paintingToDelete)
-
-    if (paintings.length === 1) {
-      setPainting(getNewPainting())
-    }
   }

   const onSelectPainting = (newPainting: Painting) => {
@@ -326,8 +323,11 @@ const PaintingsPage: FC<{ Options: string[] }> = ({ Options }) => {

   useEffect(() => {
     if (paintings.length === 0) {
-      addPainting('paintings', getNewPainting())
+      const newPainting = getNewPainting()
+      addPainting('paintings', newPainting)
+      setPainting(newPainting)
     }

     return () => {
       if (spaceClickTimer.current) {
         clearTimeout(spaceClickTimer.current)
@@ -602,11 +602,13 @@ const RadioButton = styled(Radio.Button)`
   align-items: center;
 `

-const InfoIcon = styled(QuestionCircleOutlined)`
+const InfoIcon = styled(Info)`
   margin-left: 5px;
   cursor: help;
   color: var(--color-text-2);
   opacity: 0.6;
+  width: 16px;
+  height: 16px;

   &:hover {
     opacity: 1;
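
Both paintings pages shared the same empty-list bug: the effect added one painting object to the store but selected a different object (DEFAULT_PAINTING, or nothing at all), so the UI edited an orphaned reference. The fix creates the painting once and reuses that single reference for both calls, as in this condensed restatement of the corrected effect (dependency list simplified for the sketch):

useEffect(() => {
  if (paintings.length === 0) {
    const newPainting = getNewPainting()  // create exactly one object
    addPainting('paintings', newPainting) // store it...
    setPainting(newPainting)              // ...and select the same reference
  }
}, [paintings, addPainting])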
@@ -24,6 +24,7 @@ const AssistantModelSettings: FC<Props> = ({ assistant, updateAssistant, updateA
   const [enableMaxTokens, setEnableMaxTokens] = useState(assistant?.settings?.enableMaxTokens ?? false)
   const [maxTokens, setMaxTokens] = useState(assistant?.settings?.maxTokens ?? 0)
   const [streamOutput, setStreamOutput] = useState(assistant?.settings?.streamOutput ?? true)
+  const [enableToolUse, setEnableToolUse] = useState(assistant?.settings?.enableToolUse ?? false)
   const [defaultModel, setDefaultModel] = useState(assistant?.defaultModel)
   const [topP, setTopP] = useState(assistant?.settings?.topP ?? 1)
   const [customParameters, setCustomParameters] = useState<AssistantSettingCustomParameters[]>(
@@ -377,6 +378,18 @@ const AssistantModelSettings: FC<Props> = ({ assistant, updateAssistant, updateA
       />
     </SettingRow>
     <Divider style={{ margin: '10px 0' }} />
+    <SettingRow style={{ minHeight: 30 }}>
+      <Label>{t('models.enable_tool_use')}</Label>
+      <Switch
+        size="small"
+        checked={enableToolUse}
+        onChange={(checked) => {
+          setEnableToolUse(checked)
+          updateAssistantSettings({ enableToolUse: checked })
+        }}
+      />
+    </SettingRow>
+    <Divider style={{ margin: '10px 0' }} />
     <SettingRow style={{ minHeight: 30 }}>
       <Label>{t('models.custom_parameters')}</Label>
       <Button icon={<PlusOutlined />} onClick={onAddCustomParameter}>
@@ -1,10 +1,5 @@
 import { UndoOutlined } from '@ant-design/icons' // import the reset icon
-import {
-  DEFAULT_MIN_APPS,
-  loadCustomMiniApp,
-  ORIGIN_DEFAULT_MIN_APPS,
-  updateDefaultMinApps
-} from '@renderer/config/minapps'
+import { DEFAULT_MIN_APPS } from '@renderer/config/minapps'
 import { useTheme } from '@renderer/context/ThemeProvider'
 import { useMinapps } from '@renderer/hooks/useMinapps'
 import { useSettings } from '@renderer/hooks/useSettings'
@@ -14,7 +9,7 @@ import {
   setMinappsOpenLinkExternal,
   setShowOpenedMinappsInSidebar
 } from '@renderer/store/settings'
-import { Button, Input, message, Slider, Switch, Tooltip } from 'antd'
+import { Button, message, Slider, Switch, Tooltip } from 'antd'
 import { FC, useCallback, useEffect, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import styled from 'styled-components'
@@ -36,92 +31,6 @@ const MiniAppSettings: FC = () => {
   const [disabledMiniApps, setDisabledMiniApps] = useState(disabled || [])
   const [messageApi, contextHolder] = message.useMessage()
   const debounceTimerRef = useRef<NodeJS.Timeout | null>(null)
-  const [customMiniAppContent, setCustomMiniAppContent] = useState('[]')
-
-  // Load the custom mini app config
-  useEffect(() => {
-    const loadCustomMiniApp = async () => {
-      try {
-        const content = await window.api.file.read('customMiniAPP')
-        let validContent = '[]'
-        try {
-          const parsed = JSON.parse(content)
-          validContent = JSON.stringify(parsed)
-        } catch (e) {
-          console.error('Invalid JSON format in custom mini app config:', e)
-        }
-        setCustomMiniAppContent(validContent)
-      } catch (error) {
-        console.error('Failed to load custom mini app config:', error)
-        setCustomMiniAppContent('[]')
-      }
-    }
-    loadCustomMiniApp()
-  }, [])
-
-  // Save the custom mini app config
-  const handleSaveCustomMiniApp = useCallback(async () => {
-    try {
-      // Validate the JSON format
-      if (customMiniAppContent === '') {
-        setCustomMiniAppContent('[]')
-      }
-      const parsedContent = JSON.parse(customMiniAppContent)
-      // Make sure it is an array
-      if (!Array.isArray(parsedContent)) {
-        throw new Error('Content must be an array')
-      }
-
-      // Check for duplicate IDs among the custom apps
-      const customIds = new Set<string>()
-      const duplicateIds = new Set<string>()
-      parsedContent.forEach((app: any) => {
-        if (app.id) {
-          if (customIds.has(app.id)) {
-            duplicateIds.add(app.id)
-          }
-          customIds.add(app.id)
-        }
-      })
-
-      // Check for IDs that conflict with the default apps
-      const defaultIds = new Set(ORIGIN_DEFAULT_MIN_APPS.map((app) => app.id))
-      const conflictingIds = new Set<string>()
-      customIds.forEach((id) => {
-        if (defaultIds.has(id)) {
-          conflictingIds.add(id)
-        }
-      })
-
-      // If there are duplicate IDs, show an error message
-      if (duplicateIds.size > 0 || conflictingIds.size > 0) {
-        let errorMessage = ''
-        if (duplicateIds.size > 0) {
-          errorMessage += t('settings.miniapps.custom.duplicate_ids', { ids: Array.from(duplicateIds).join(', ') })
-        }
-        if (conflictingIds.size > 0) {
-          console.log('conflictingIds', Array.from(conflictingIds))
-          if (errorMessage) errorMessage += '\n'
-          errorMessage += t('settings.miniapps.custom.conflicting_ids', { ids: Array.from(conflictingIds).join(', ') })
-        }
-        messageApi.error(errorMessage)
-        return
-      }
-
-      // Save the file
-      await window.api.file.writeWithId('customMiniAPP', customMiniAppContent)
-      messageApi.success(t('settings.miniapps.custom.save_success'))
-      // Reload the app list
-      console.log('Reloading mini app list...')
-      const reloadedApps = [...ORIGIN_DEFAULT_MIN_APPS, ...(await loadCustomMiniApp())]
-      updateDefaultMinApps(reloadedApps)
-      console.log('Reloaded mini app list:', reloadedApps)
-      updateMinapps(reloadedApps)
-    } catch (error) {
-      messageApi.error(t('settings.miniapps.custom.save_error'))
-      console.error('Failed to save custom mini app config:', error)
-    }
-  }, [customMiniAppContent, messageApi, t, updateMinapps])
-
   const handleResetMinApps = useCallback(() => {
     setVisibleMiniApps(DEFAULT_MIN_APPS)
@@ -235,30 +144,6 @@ const MiniAppSettings: FC = () => {
             onChange={(checked) => dispatch(setShowOpenedMinappsInSidebar(checked))}
           />
         </SettingRow>
-        <SettingDivider />
-        <SettingRow>
-          <SettingLabelGroup>
-            <SettingRowTitle>{t('settings.miniapps.custom.edit_title')}</SettingRowTitle>
-            <SettingDescription>{t('settings.miniapps.custom.edit_description')}</SettingDescription>
-          </SettingLabelGroup>
-        </SettingRow>
-        <CustomEditorContainer>
-          <Input.TextArea
-            value={customMiniAppContent}
-            onChange={(e) => setCustomMiniAppContent(e.target.value)}
-            placeholder={t('settings.miniapps.custom.placeholder')}
-            style={{
-              minHeight: 200,
-              fontFamily: 'monospace',
-              backgroundColor: 'var(--color-bg-2)',
-              color: 'var(--color-text)',
-              borderColor: 'var(--color-border)'
-            }}
-          />
-          <Button type="primary" onClick={handleSaveCustomMiniApp} style={{ marginTop: 8 }}>
-            {t('settings.miniapps.custom.save')}
-          </Button>
-        </CustomEditorContainer>
       </SettingGroup>
     </SettingContainer>
   )
@@ -1,6 +1,6 @@
 import { isOpenAILLMModel } from '@renderer/config/models'
 import { getDefaultModel } from '@renderer/services/AssistantService'
-import { Assistant, Model, Provider, Suggestion } from '@renderer/types'
+import { Assistant, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types'
 import { Message } from '@renderer/types/newMessage'
 import OpenAI from 'openai'

@@ -18,6 +18,7 @@ import OpenAIProvider from './OpenAIProvider'
 export default class AihubmixProvider extends BaseProvider {
   private providers: Map<string, BaseProvider> = new Map()
   private defaultProvider: BaseProvider
+  private currentProvider: BaseProvider

   constructor(provider: Provider) {
     super(provider)
@@ -30,6 +31,7 @@ export default class AihubmixProvider extends BaseProvider {

     // Set the default provider
     this.defaultProvider = this.providers.get('default')!
+    this.currentProvider = this.defaultProvider
   }

   /**
@@ -70,7 +72,8 @@ export default class AihubmixProvider extends BaseProvider {

   public async completions(params: CompletionsParams): Promise<void> {
     const model = params.assistant.model
-    return this.getProvider(model!).completions(params)
+    this.currentProvider = this.getProvider(model!)
+    return this.currentProvider.completions(params)
   }

   public async translate(
@@ -100,4 +103,12 @@ export default class AihubmixProvider extends BaseProvider {
   public async getEmbeddingDimensions(model: Model): Promise<number> {
     return this.getProvider(model).getEmbeddingDimensions(model)
   }
+
+  public convertMcpTools<T>(mcpTools: MCPTool[]) {
+    return this.currentProvider.convertMcpTools(mcpTools) as T[]
+  }
+
+  public mcpToolCallResponseToMessage(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) {
+    return this.currentProvider.mcpToolCallResponseToMessage(mcpToolResponse, resp, model)
+  }
 }
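
AihubmixProvider routes each request to a model-specific sub-provider, and the new currentProvider field remembers that choice so the MCP hooks added at the bottom (convertMcpTools, mcpToolCallResponseToMessage) delegate to the same provider that handled the completion. A stripped-down sketch of the pattern (class and method names echo the diff; the routing key is illustrative):

abstract class BaseProvider {
  abstract completions(params: unknown): Promise<void>
  abstract convertMcpTools(mcpTools: unknown[]): unknown[]
}

class RoutingProvider {
  private providers = new Map<string, BaseProvider>()
  private currentProvider: BaseProvider

  constructor(defaultProvider: BaseProvider) {
    this.providers.set('default', defaultProvider)
    this.currentProvider = defaultProvider // safe fallback before any completion runs
  }

  private getProvider(modelId: string): BaseProvider {
    return this.providers.get(modelId) ?? this.providers.get('default')!
  }

  async completions(modelId: string, params: unknown): Promise<void> {
    // Remember the routed provider so later tool-format conversions agree
    // with the provider actually generating the response.
    this.currentProvider = this.getProvider(modelId)
    return this.currentProvider.completions(params)
  }

  convertMcpTools<T>(mcpTools: unknown[]): T[] {
    return this.currentProvider.convertMcpTools(mcpTools) as T[]
  }
}

One caveat of this design: currentProvider is shared mutable state, so two concurrent completions against different sub-providers could interleave; the diff appears to accept that trade-off in exchange for a small API surface.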
|
|||||||
@ -1,15 +1,19 @@
|
|||||||
import Anthropic from '@anthropic-ai/sdk'
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
import {
|
import {
|
||||||
|
Base64ImageSource,
|
||||||
|
ImageBlockParam,
|
||||||
MessageCreateParamsNonStreaming,
|
MessageCreateParamsNonStreaming,
|
||||||
MessageParam,
|
MessageParam,
|
||||||
TextBlockParam,
|
TextBlockParam,
|
||||||
|
ToolResultBlockParam,
|
||||||
ToolUnion,
|
ToolUnion,
|
||||||
|
ToolUseBlock,
|
||||||
WebSearchResultBlock,
|
WebSearchResultBlock,
|
||||||
WebSearchTool20250305,
|
WebSearchTool20250305,
|
||||||
WebSearchToolResultError
|
WebSearchToolResultError
|
||||||
} from '@anthropic-ai/sdk/resources'
|
} from '@anthropic-ai/sdk/resources'
|
||||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
import { isReasoningModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
|
import { isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
@ -23,16 +27,24 @@ import {
|
|||||||
Assistant,
|
Assistant,
|
||||||
EFFORT_RATIO,
|
EFFORT_RATIO,
|
||||||
FileTypes,
|
FileTypes,
|
||||||
|
MCPCallToolResponse,
|
||||||
|
MCPTool,
|
||||||
MCPToolResponse,
|
MCPToolResponse,
|
||||||
Model,
|
Model,
|
||||||
Provider,
|
Provider,
|
||||||
Suggestion,
|
Suggestion,
|
||||||
|
ToolCallResponse,
|
||||||
WebSearchSource
|
WebSearchSource
|
||||||
} from '@renderer/types'
|
} from '@renderer/types'
|
||||||
import { ChunkType } from '@renderer/types/chunk'
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
import type { Message } from '@renderer/types/newMessage'
|
import type { Message } from '@renderer/types/newMessage'
|
||||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||||
import { mcpToolCallResponseToAnthropicMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
|
import {
|
||||||
|
anthropicToolUseToMcpTool,
|
||||||
|
mcpToolCallResponseToAnthropicMessage,
|
||||||
|
mcpToolsToAnthropicTools,
|
||||||
|
parseAndCallTools
|
||||||
|
} from '@renderer/utils/mcp-tools'
|
||||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||||
import { first, flatten, sum, takeRight } from 'lodash'
|
import { first, flatten, sum, takeRight } from 'lodash'
|
||||||
@ -199,7 +211,7 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
|
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
|
||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
const userMessagesParams: MessageParam[] = []
|
const userMessagesParams: MessageParam[] = []
|
||||||
|
|
||||||
@ -215,10 +227,16 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
|
|
||||||
const userMessages = flatten(userMessagesParams)
|
const userMessages = flatten(userMessagesParams)
|
||||||
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||||
// const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined
|
|
||||||
|
|
||||||
let systemPrompt = assistant.prompt
|
let systemPrompt = assistant.prompt
|
||||||
if (mcpTools && mcpTools.length > 0) {
|
|
||||||
|
const { tools } = this.setupToolsConfig<ToolUnion>({
|
||||||
|
model,
|
||||||
|
mcpTools,
|
||||||
|
enableToolUse
|
||||||
|
})
|
||||||
|
|
||||||
|
if (this.useSystemPromptForTools && mcpTools && mcpTools.length) {
|
||||||
systemPrompt = buildSystemPrompt(systemPrompt, mcpTools)
|
systemPrompt = buildSystemPrompt(systemPrompt, mcpTools)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,8 +250,6 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
|
|
||||||
const isEnabledBuiltinWebSearch = assistant.enableWebSearch
|
const isEnabledBuiltinWebSearch = assistant.enableWebSearch
|
||||||
|
|
||||||
const tools: ToolUnion[] = []
|
|
||||||
|
|
||||||
if (isEnabledBuiltinWebSearch) {
|
if (isEnabledBuiltinWebSearch) {
|
||||||
const webSearchTool = await this.getWebSearchParams(model)
|
const webSearchTool = await this.getWebSearchParams(model)
|
||||||
if (webSearchTool) {
|
if (webSearchTool) {
|
||||||
@ -244,7 +260,6 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
const body: MessageCreateParamsNonStreaming = {
|
const body: MessageCreateParamsNonStreaming = {
|
||||||
model: model.id,
|
model: model.id,
|
||||||
messages: userMessages,
|
messages: userMessages,
|
||||||
// tools: isEmpty(tools) ? undefined : tools,
|
|
||||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||||
temperature: this.getTemperature(assistant, model),
|
temperature: this.getTemperature(assistant, model),
|
||||||
top_p: this.getTopP(assistant, model),
|
top_p: this.getTopP(assistant, model),
|
||||||
@@ -303,7 +318,7 @@ export default class AnthropicProvider extends BaseProvider {
     const processStream = (body: MessageCreateParamsNonStreaming, idx: number) => {
       return new Promise<void>((resolve, reject) => {
         // Wait for the stream returned by the API
-        onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
+        const toolCalls: ToolUseBlock[] = []
         let hasThinkingContent = false
         this.sdk.messages
           .stream({ ...body, stream: true }, { signal, timeout: 5 * 60 * 1000 })
@@ -380,30 +395,70 @@ export default class AnthropicProvider extends BaseProvider {
             })
             thinking_content += thinking
           })
+          .on('contentBlock', (content) => {
+            if (content.type === 'tool_use') {
+              toolCalls.push(content)
+            }
+          })
           .on('finalMessage', async (message) => {
+            const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
+            // tool call
+            if (toolCalls.length > 0) {
+              const mcpToolResponses = toolCalls
+                .map((toolCall) => {
+                  const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
+                  if (!mcpTool) {
+                    return undefined
+                  }
+                  return {
+                    id: toolCall.id,
+                    toolCallId: toolCall.id,
+                    tool: mcpTool,
+                    arguments: toolCall.input as Record<string, unknown>,
+                    status: 'pending'
+                  } as ToolCallResponse
+                })
+                .filter((t) => typeof t !== 'undefined')
+              toolResults.push(
+                ...(await parseAndCallTools(
+                  mcpToolResponses,
+                  toolResponses,
+                  onChunk,
+                  this.mcpToolCallResponseToMessage,
+                  model,
+                  mcpTools
+                ))
+              )
+            }
+
+            // tool use
             const content = message.content[0]
             if (content && content.type === 'text') {
               onChunk({ type: ChunkType.TEXT_COMPLETE, text: content.text })
-              const toolResults = await parseAndCallTools(
-                content.text,
-                toolResponses,
-                onChunk,
-                idx,
-                mcpToolCallResponseToAnthropicMessage,
-                mcpTools,
-                isVisionModel(model)
+              toolResults.push(
+                ...(await parseAndCallTools(
+                  content.text,
+                  toolResponses,
+                  onChunk,
+                  this.mcpToolCallResponseToMessage,
+                  model,
+                  mcpTools
+                ))
               )
-              if (toolResults.length > 0) {
-                userMessages.push({
-                  role: message.role,
-                  content: message.content
-                })
+            }

-                toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
-                const newBody = body
-                newBody.messages = userMessages
-                await processStream(newBody, idx + 1)
-              }
+            userMessages.push({
+              role: message.role,
+              content: message.content
+            })
+
+            if (toolResults.length > 0) {
+              toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
+              const newBody = body
+              newBody.messages = userMessages
+
+              onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
+              await processStream(newBody, idx + 1)
+            }

             const time_completion_millsec = new Date().getTime() - start_time_millsec
@@ -434,7 +489,7 @@ export default class AnthropicProvider extends BaseProvider {
           })
       })
     }
+    onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
     await processStream(body, 0).finally(cleanup)
   }

@@ -683,4 +738,47 @@ export default class AnthropicProvider extends BaseProvider {
   public async getEmbeddingDimensions(): Promise<number> {
     return 0
   }
+
+  public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
+    return mcpToolsToAnthropicTools(mcpTools) as T[]
+  }
+
+  public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
+    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
+      return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
+    } else if ('toolCallId' in mcpToolResponse) {
+      return {
+        role: 'user',
+        content: [
+          {
+            type: 'tool_result',
+            tool_use_id: mcpToolResponse.toolCallId!,
+            content: resp.content
+              .map((item) => {
+                if (item.type === 'text') {
+                  return {
+                    type: 'text',
+                    text: item.text || ''
+                  } satisfies TextBlockParam
+                }
+                if (item.type === 'image') {
+                  return {
+                    type: 'image',
+                    source: {
+                      data: item.data || '',
+                      media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
+                      type: 'base64'
+                    }
+                  } satisfies ImageBlockParam
+                }
+                return
+              })
+              .filter((n) => typeof n !== 'undefined'),
+            is_error: resp.isError
+          } satisfies ToolResultBlockParam
+        ]
+      }
+    }
+    return
+  }
 }
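Note: the AnthropicProvider hunks above move tool handling from prompt-parsed tool use to native tool_use blocks: blocks are collected from the stream via .on('contentBlock'), executed, appended as tool_result messages, and the request is re-issued until the model stops calling tools. A minimal TypeScript sketch of that loop, assuming the official @anthropic-ai/sdk; callMcpTool is a hypothetical stand-in for the repo's parseAndCallTools/MCP dispatch, and the model id is illustrative:

import Anthropic from '@anthropic-ai/sdk'

// Hypothetical executor standing in for MCP tool dispatch.
async function callMcpTool(block: Anthropic.Messages.ToolUseBlock): Promise<string> {
  return JSON.stringify({ ok: true, tool: block.name, input: block.input })
}

// Sketch of the recursive tool loop: run a request, execute any tool_use
// blocks, feed results back as tool_result user messages, and recurse.
async function runToolLoop(
  client: Anthropic,
  messages: Anthropic.Messages.MessageParam[],
  tools: Anthropic.Messages.Tool[]
): Promise<string> {
  const res = await client.messages.create({ model: 'claude-3-5-sonnet-latest', max_tokens: 1024, messages, tools })
  const toolUses = res.content.filter((b): b is Anthropic.Messages.ToolUseBlock => b.type === 'tool_use')
  if (toolUses.length === 0) {
    const text = res.content.find((b) => b.type === 'text')
    return text?.type === 'text' ? text.text : ''
  }
  // Echo the assistant turn, then one tool_result per call, as the diff does.
  messages.push({ role: res.role, content: res.content })
  for (const block of toolUses) {
    messages.push({
      role: 'user',
      content: [{ type: 'tool_result', tool_use_id: block.id, content: await callMcpTool(block) }]
    })
  }
  return runToolLoop(client, messages, tools)
}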
@@ -1,9 +1,13 @@
+import { isFunctionCallingModel } from '@renderer/config/models'
 import { REFERENCE_PROMPT } from '@renderer/config/prompts'
 import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
 import type {
   Assistant,
   GenerateImageParams,
   KnowledgeReference,
+  MCPCallToolResponse,
+  MCPTool,
+  MCPToolResponse,
   Model,
   Provider,
   Suggestion,
@@ -22,10 +26,15 @@ import type OpenAI from 'openai'
 import type { CompletionsParams } from '.'

 export default abstract class BaseProvider {
+  // Threshold for determining whether to use system prompt for tools
+  private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128
+
   protected provider: Provider
   protected host: string
   protected apiKey: string

+  protected useSystemPromptForTools: boolean = true
+
   constructor(provider: Provider) {
     this.provider = provider
     this.host = this.getBaseURL()
@@ -47,6 +56,12 @@ export default abstract class BaseProvider {
   abstract generateImage(params: GenerateImageParams): Promise<string[]>
   abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
   abstract getEmbeddingDimensions(model: Model): Promise<number>
+  public abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
+  public abstract mcpToolCallResponseToMessage(
+    mcpToolResponse: MCPToolResponse,
+    resp: MCPCallToolResponse,
+    model: Model
+  ): any

   public getBaseURL(): string {
     const host = this.provider.apiHost
@@ -229,4 +244,31 @@
       cleanup
     }
   }
+
+  // Setup tools configuration based on provided parameters
+  protected setupToolsConfig<T>(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
+    tools: T[]
+  } {
+    const { mcpTools, model, enableToolUse } = params
+    let tools: T[] = []
+
+    // If there are no tools, return an empty array
+    if (!mcpTools?.length) {
+      return { tools }
+    }
+
+    // If the number of tools exceeds the threshold, use the system prompt
+    if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) {
+      this.useSystemPromptForTools = true
+      return { tools }
+    }
+
+    // If the model supports function calling and tool usage is enabled
+    if (isFunctionCallingModel(model) && enableToolUse) {
+      tools = this.convertMcpTools<T>(mcpTools)
+      this.useSystemPromptForTools = false
+    }
+
+    return { tools }
+  }
 }
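Note: setupToolsConfig (added above) is the single routing point the provider subclasses now share. A standalone TypeScript sketch of the decision it encodes, assuming the 128-tool threshold and the default useSystemPromptForTools = true stay as in the diff; routeTools is an illustrative re-statement, not part of the commit:

// Re-statement of the routing logic for illustration only; the real
// implementation is the BaseProvider.setupToolsConfig method above.
type ToolRoute = { useSystemPrompt: boolean; nativeTools: number }

function routeTools(toolCount: number, modelSupportsFunctionCalling: boolean, enableToolUse: boolean): ToolRoute {
  const SYSTEM_PROMPT_THRESHOLD = 128 // mirrors BaseProvider.SYSTEM_PROMPT_THRESHOLD
  if (toolCount === 0) return { useSystemPrompt: true, nativeTools: 0 } // nothing to convert
  if (toolCount > SYSTEM_PROMPT_THRESHOLD) return { useSystemPrompt: true, nativeTools: 0 } // too many for native tool calls
  if (modelSupportsFunctionCalling && enableToolUse) return { useSystemPrompt: false, nativeTools: toolCount }
  return { useSystemPrompt: true, nativeTools: 0 } // fall back to prompt-described tools
}

// routeTools(12, true, true)  -> { useSystemPrompt: false, nativeTools: 12 }
// routeTools(200, true, true) -> { useSystemPrompt: true,  nativeTools: 0 }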
@@ -1,6 +1,7 @@
 import {
   Content,
   File,
+  FunctionCall,
   GenerateContentConfig,
   GenerateContentResponse,
   GoogleGenAI,
@@ -11,8 +12,9 @@ import {
   PartUnion,
   SafetySetting,
   ThinkingConfig,
-  ToolListUnion
+  Tool
 } from '@google/genai'
+import { nanoid } from '@reduxjs/toolkit'
 import {
   findTokenLimit,
   isGeminiReasoningModel,
@@ -35,17 +37,25 @@
   EFFORT_RATIO,
   FileType,
   FileTypes,
+  MCPCallToolResponse,
+  MCPTool,
   MCPToolResponse,
   Model,
   Provider,
   Suggestion,
+  ToolCallResponse,
   Usage,
   WebSearchSource
 } from '@renderer/types'
 import { BlockCompleteChunk, Chunk, ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
 import type { Message, Response } from '@renderer/types/newMessage'
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
-import { mcpToolCallResponseToGeminiMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
+import {
+  geminiFunctionCallToMcpTool,
+  mcpToolCallResponseToGeminiMessage,
+  mcpToolsToGeminiTools,
+  parseAndCallTools
+} from '@renderer/utils/mcp-tools'
 import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
 import { buildSystemPrompt } from '@renderer/utils/prompt'
 import { MB } from '@shared/config/constant'
@@ -263,7 +273,7 @@ export default class GeminiProvider extends BaseProvider {
   }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
-    const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
+    const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)

     const userMessages = filterUserRoleStartMessages(
       filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
@@ -280,12 +290,16 @@

     let systemInstruction = assistant.prompt

-    if (mcpTools && mcpTools.length > 0) {
+    const { tools } = this.setupToolsConfig<Tool>({
+      mcpTools,
+      model,
+      enableToolUse
+    })
+
+    if (this.useSystemPromptForTools) {
       systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
     }

-    // const tools = mcpToolsToGeminiTools(mcpTools)
-    const tools: ToolListUnion = []
     const toolResponses: MCPToolResponse[] = []

     if (assistant.enableWebSearch && isWebSearchModel(model)) {
@@ -351,6 +365,224 @@

     const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true)

+    const processToolResults = async (toolResults: Awaited<ReturnType<typeof parseAndCallTools>>, idx: number) => {
+      if (toolResults.length === 0) return
+      const newChat = this.sdk.chats.create({
+        model: model.id,
+        config: generateContentConfig,
+        history: history as Content[]
+      })
+
+      const newStream = await newChat.sendMessageStream({
+        message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
+        config: {
+          ...generateContentConfig,
+          abortSignal: abortController.signal
+        }
+      })
+      await processStream(newStream, idx + 1)
+    }
+
+    const processToolCalls = async (toolCalls: FunctionCall[]) => {
+      const mcpToolResponses: ToolCallResponse[] = toolCalls
+        .map((toolCall) => {
+          const mcpTool = geminiFunctionCallToMcpTool(mcpTools, toolCall)
+          if (!mcpTool) return undefined
+
+          const parsedArgs = (() => {
+            try {
+              return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
+            } catch {
+              return toolCall.args
+            }
+          })()
+
+          return {
+            id: toolCall.id || nanoid(),
+            toolCallId: toolCall.id,
+            tool: mcpTool,
+            arguments: parsedArgs,
+            status: 'pending'
+          } as ToolCallResponse
+        })
+        .filter((t): t is ToolCallResponse => typeof t !== 'undefined')
+
+      return await parseAndCallTools(
+        mcpToolResponses,
+        toolResponses,
+        onChunk,
+        this.mcpToolCallResponseToMessage,
+        model,
+        mcpTools
+      )
+    }
+
+    const processToolUses = async (content: string) => {
+      return await parseAndCallTools(
+        content,
+        toolResponses,
+        onChunk,
+        this.mcpToolCallResponseToMessage,
+        model,
+        mcpTools
+      )
+    }
+
+    const processStream = async (
+      stream: AsyncGenerator<GenerateContentResponse> | GenerateContentResponse,
+      idx: number
+    ) => {
+      history.push(messageContents)
+
+      let functionCalls: FunctionCall[] = []
+
+      if (stream instanceof GenerateContentResponse) {
+        let content = ''
+        const time_completion_millsec = new Date().getTime() - start_time_millsec
+
+        const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
+        if (stream.text?.length) {
+          toolResults.push(...(await processToolUses(stream.text)))
+        }
+        stream.candidates?.forEach((candidate) => {
+          if (candidate.content) {
+            history.push(candidate.content)
+
+            candidate.content.parts?.forEach((part) => {
+              if (part.functionCall) {
+                functionCalls.push(part.functionCall)
+              }
+              if (part.text) {
+                content += part.text
+                onChunk({ type: ChunkType.TEXT_DELTA, text: part.text })
+              }
+            })
+          }
+        })
+        if (content.length) {
+          onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
+        }
+        if (functionCalls.length) {
+          toolResults.push(...(await processToolCalls(functionCalls)))
+        }
+        if (stream.text?.length) {
+          toolResults.push(...(await processToolUses(stream.text)))
+        }
+        if (toolResults.length) {
+          await processToolResults(toolResults, idx)
+        }
+        onChunk({
+          type: ChunkType.BLOCK_COMPLETE,
+          response: {
+            text: stream.text,
+            usage: {
+              prompt_tokens: stream.usageMetadata?.promptTokenCount || 0,
+              thoughts_tokens: stream.usageMetadata?.thoughtsTokenCount || 0,
+              completion_tokens: stream.usageMetadata?.candidatesTokenCount || 0,
+              total_tokens: stream.usageMetadata?.totalTokenCount || 0
+            },
+            metrics: {
+              completion_tokens: stream.usageMetadata?.candidatesTokenCount,
+              time_completion_millsec,
+              time_first_token_millsec: 0
+            },
+            webSearch: {
+              results: stream.candidates?.[0]?.groundingMetadata,
+              source: 'gemini'
+            }
+          } as Response
+        } as BlockCompleteChunk)
+      } else {
+        let content = ''
+        let final_time_completion_millsec = 0
+        let lastUsage: Usage | undefined = undefined
+        for await (const chunk of stream) {
+          if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+
+          // --- Calculate Metrics ---
+          if (time_first_token_millsec == 0 && chunk.text !== undefined) {
+            // Update based on text arrival
+            time_first_token_millsec = new Date().getTime() - start_time_millsec
+          }
+
+          // 1. Text Content
+          if (chunk.text !== undefined) {
+            content += chunk.text
+            onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
+          }
+
+          // 2. Usage Data
+          if (chunk.usageMetadata) {
+            lastUsage = {
+              prompt_tokens: chunk.usageMetadata.promptTokenCount || 0,
+              completion_tokens: chunk.usageMetadata.candidatesTokenCount || 0,
+              total_tokens: chunk.usageMetadata.totalTokenCount || 0
+            }
+            final_time_completion_millsec = new Date().getTime() - start_time_millsec
+          }

+          // 4. Image Generation
+          const generateImage = this.processGeminiImageResponse(chunk, onChunk)
+          if (generateImage?.images?.length) {
+            onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
+          }
+
+          if (chunk.candidates?.[0]?.finishReason) {
+            if (chunk.text) {
+              onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
+            }
+            if (chunk.candidates?.[0]?.groundingMetadata) {
+              // 3. Grounding/Search Metadata
+              const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata
+              onChunk({
+                type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
+                llm_web_search: {
+                  results: groundingMetadata,
+                  source: WebSearchSource.GEMINI
+                }
+              } as LLMWebSearchCompleteChunk)
+            }
+            if (chunk.functionCalls) {
+              chunk.candidates?.forEach((candidate) => {
+                if (candidate.content) {
+                  history.push(candidate.content)
+                }
+              })
+              functionCalls = functionCalls.concat(chunk.functionCalls)
+            }
+
+            onChunk({
+              type: ChunkType.BLOCK_COMPLETE,
+              response: {
+                metrics: {
+                  completion_tokens: lastUsage?.completion_tokens,
+                  time_completion_millsec: final_time_completion_millsec,
+                  time_first_token_millsec
+                },
+                usage: lastUsage
+              }
+            })
+          }
+
+          // --- End Incremental onChunk calls ---
+
+          // Call processToolUses AFTER potentially processing text content in this chunk
+          // This assumes tools might be specified within the text stream
+          // Note: parseAndCallTools inside should handle its own onChunk for tool responses
+          let toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
+          if (functionCalls.length) {
+            toolResults = await processToolCalls(functionCalls)
+          }
+          if (content.length) {
+            toolResults = toolResults.concat(await processToolUses(content))
+          }
+          if (toolResults.length) {
+            await processToolResults(toolResults, idx)
+          }
+        }
+      }
+    }
+
     if (!streamOutput) {
       const response = await chat.sendMessage({
         message: messageContents as PartUnion,
@@ -359,32 +591,10 @@
           abortSignal: abortController.signal
         }
       })
-      const time_completion_millsec = new Date().getTime() - start_time_millsec
-      onChunk({
-        type: ChunkType.BLOCK_COMPLETE,
-        response: {
-          text: response.text,
-          usage: {
-            prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
-            thoughts_tokens: response.usageMetadata?.thoughtsTokenCount || 0,
-            completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
-            total_tokens: response.usageMetadata?.totalTokenCount || 0
-          },
-          metrics: {
-            completion_tokens: response.usageMetadata?.candidatesTokenCount,
-            time_completion_millsec,
-            time_first_token_millsec: 0
-          },
-          webSearch: {
-            results: response.candidates?.[0]?.groundingMetadata,
-            source: 'gemini'
-          }
-        } as Response
-      } as BlockCompleteChunk)
-      return
+      onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
+      return await processStream(response, 0).then(cleanup)
     }

-    // Wait for the stream returned by the API
     onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
     const userMessagesStream = await chat.sendMessageStream({
       message: messageContents as PartUnion,
@@ -394,105 +604,6 @@
       }
     })

-    const processToolUses = async (content: string, idx: number) => {
-      const toolResults = await parseAndCallTools(
-        content,
-        toolResponses,
-        onChunk,
-        idx,
-        mcpToolCallResponseToGeminiMessage,
-        mcpTools,
-        isVisionModel(model)
-      )
-      if (toolResults && toolResults.length > 0) {
-        history.push(messageContents)
-        const newChat = this.sdk.chats.create({
-          model: model.id,
-          config: generateContentConfig,
-          history: history as Content[]
-        })
-        const newStream = await newChat.sendMessageStream({
-          message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
-          config: {
-            ...generateContentConfig,
-            abortSignal: abortController.signal
-          }
-        })
-        await processStream(newStream, idx + 1)
-      }
-    }
-
-    const processStream = async (stream: AsyncGenerator<GenerateContentResponse>, idx: number) => {
-      let content = ''
-      let final_time_completion_millsec = 0
-      let lastUsage: Usage | undefined = undefined
-      for await (const chunk of stream) {
-        if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
-
-        // --- Calculate Metrics ---
-        if (time_first_token_millsec == 0 && chunk.text !== undefined) {
-          // Update based on text arrival
-          time_first_token_millsec = new Date().getTime() - start_time_millsec
-        }
-
-        // 1. Text Content
-        if (chunk.text !== undefined) {
-          content += chunk.text
-          onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
-        }
-
-        // 2. Usage Data
-        if (chunk.usageMetadata) {
-          lastUsage = {
-            prompt_tokens: chunk.usageMetadata.promptTokenCount || 0,
-            completion_tokens: chunk.usageMetadata.candidatesTokenCount || 0,
-            total_tokens: chunk.usageMetadata.totalTokenCount || 0
-          }
-          final_time_completion_millsec = new Date().getTime() - start_time_millsec
-        }
-
-        // 4. Image Generation
-        const generateImage = this.processGeminiImageResponse(chunk, onChunk)
-        if (generateImage?.images?.length) {
-          onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
-        }
-
-        if (chunk.candidates?.[0]?.finishReason) {
-          if (chunk.text) {
-            onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
-          }
-          if (chunk.candidates?.[0]?.groundingMetadata) {
-            // 3. Grounding/Search Metadata
-            const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata
-            onChunk({
-              type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
-              llm_web_search: {
-                results: groundingMetadata,
-                source: WebSearchSource.GEMINI
-              }
-            } as LLMWebSearchCompleteChunk)
-          }
-          onChunk({
-            type: ChunkType.BLOCK_COMPLETE,
-            response: {
-              metrics: {
-                completion_tokens: lastUsage?.completion_tokens,
-                time_completion_millsec: final_time_completion_millsec,
-                time_first_token_millsec
-              },
-              usage: lastUsage
-            }
-          })
-        }
-        // --- End Incremental onChunk calls ---
-
-        // Call processToolUses AFTER potentially processing text content in this chunk
-        // This assumes tools might be specified within the text stream
-        // Note: parseAndCallTools inside should handle its own onChunk for tool responses
-        await processToolUses(content, idx)
-      }
-    }
-
     await processStream(userMessagesStream, 0).finally(cleanup)

     const final_time_completion_millsec = new Date().getTime() - start_time_millsec
@@ -841,4 +952,32 @@
   public generateImageByChat(): Promise<void> {
     throw new Error('Method not implemented.')
   }
+
+  public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
+    return mcpToolsToGeminiTools(mcpTools) as T[]
+  }
+
+  public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
+    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
+      return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
+    } else if ('toolCallId' in mcpToolResponse) {
+      const toolCallOut = {
+        role: 'user',
+        parts: [
+          {
+            functionResponse: {
+              id: mcpToolResponse.toolCallId,
+              name: mcpToolResponse.tool.id,
+              response: {
+                output: !resp.isError ? resp.content : undefined,
+                error: resp.isError ? resp.content : undefined
+              }
+            }
+          }
+        ]
+      } satisfies Content
+      return toolCallOut
+    }
+    return
+  }
 }
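Note: for Gemini, a native function call is answered with a functionResponse part rather than a tool_result block. A small TypeScript sketch of the Content shape the new mcpToolCallResponseToMessage builds; toFunctionResponseContent is an illustrative helper, not part of the commit:

import { Content } from '@google/genai'

// Illustrative helper mirroring the functionResponse shape added above:
// success goes to response.output, failures to response.error.
function toFunctionResponseContent(callId: string, toolName: string, result: unknown, isError: boolean): Content {
  return {
    role: 'user',
    parts: [
      {
        functionResponse: {
          id: callId,
          name: toolName,
          response: {
            output: !isError ? result : undefined,
            error: isError ? result : undefined
          }
        }
      }
    ]
  }
}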
@@ -31,10 +31,13 @@
   Assistant,
   EFFORT_RATIO,
   FileTypes,
+  MCPCallToolResponse,
+  MCPTool,
   MCPToolResponse,
   Model,
   Provider,
   Suggestion,
+  ToolCallResponse,
   Usage,
   WebSearchSource
 } from '@renderer/types'
@@ -48,7 +51,12 @@
   convertLinksToOpenRouter,
   convertLinksToZhipu
 } from '@renderer/utils/linkConverter'
-import { mcpToolCallResponseToOpenAICompatibleMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
+import {
+  mcpToolCallResponseToOpenAICompatibleMessage,
+  mcpToolsToOpenAIChatTools,
+  openAIToolsToMcpTool,
+  parseAndCallTools
+} from '@renderer/utils/mcp-tools'
 import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
 import { buildSystemPrompt } from '@renderer/utils/prompt'
 import { asyncGeneratorToReadableStream, readableStreamAsyncIterable } from '@renderer/utils/stream'
@@ -57,18 +65,22 @@ import OpenAI, { AzureOpenAI } from 'openai'
 import {
   ChatCompletionContentPart,
   ChatCompletionCreateParamsNonStreaming,
-  ChatCompletionMessageParam
+  ChatCompletionMessageParam,
+  ChatCompletionMessageToolCall,
+  ChatCompletionTool,
+  ChatCompletionToolMessageParam
 } from 'openai/resources'

 import { CompletionsParams } from '.'
-import OpenAIProvider from './OpenAIProvider'
+import { BaseOpenAiProvider } from './OpenAIProvider'

 // 1. Define the union type
 export type OpenAIStreamChunk =
   | { type: 'reasoning' | 'text-delta'; textDelta: string }
+  | { type: 'tool-calls'; delta: any }
   | { type: 'finish'; finishReason: any; usage: any; delta: any; chunk: any }

-export default class OpenAICompatibleProvider extends OpenAIProvider {
+export default class OpenAICompatibleProvider extends BaseOpenAiProvider {
   constructor(provider: Provider) {
     super(provider)

@@ -313,6 +325,24 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
     return {}
   }

+  public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
+    return mcpToolsToOpenAIChatTools(mcpTools) as T[]
+  }
+
+  public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
+    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
+      return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
+    } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
+      const toolCallOut: ChatCompletionToolMessageParam = {
+        role: 'tool',
+        tool_call_id: mcpToolResponse.toolCallId,
+        content: JSON.stringify(resp.content)
+      }
+      return toolCallOut
+    }
+    return
+  }
+
   /**
    * Generate completions for the assistant
    * @param messages - The messages
@@ -330,7 +360,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel

-    const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
+    const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
     const isEnabledBultinWebSearch = assistant.enableWebSearch
     messages = addImageFileToContents(messages)
     const enableReasoning =
@@ -344,7 +374,9 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
         content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
       }
     }
-    if (mcpTools && mcpTools.length > 0) {
+    const { tools } = this.setupToolsConfig<ChatCompletionTool>({ mcpTools, model, enableToolUse })
+
+    if (this.useSystemPromptForTools) {
       systemMessage.content = buildSystemPrompt(systemMessage.content || '', mcpTools)
     }

@@ -379,53 +411,86 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {

     const toolResponses: MCPToolResponse[] = []

-    const processToolUses = async (content: string, idx: number) => {
-      const toolResults = await parseAndCallTools(
+    const processToolResults = async (toolResults: Awaited<ReturnType<typeof parseAndCallTools>>, idx: number) => {
+      if (toolResults.length === 0) return
+
+      toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam))
+
+      console.debug('[tool] reqMessages before processing', model.id, reqMessages)
+      reqMessages = processReqMessages(model, reqMessages)
+      console.debug('[tool] reqMessages', model.id, reqMessages)
+
+      onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
+      const newStream = await this.sdk.chat.completions
+        // @ts-ignore key is not typed
+        .create(
+          {
+            model: model.id,
+            messages: reqMessages,
+            temperature: this.getTemperature(assistant, model),
+            top_p: this.getTopP(assistant, model),
+            max_tokens: maxTokens,
+            keep_alive: this.keepAliveTime,
+            stream: isSupportStreamOutput(),
+            tools: !isEmpty(tools) ? tools : undefined,
+            ...getOpenAIWebSearchParams(assistant, model),
+            ...this.getReasoningEffort(assistant, model),
+            ...this.getProviderSpecificParameters(assistant, model),
+            ...this.getCustomParameters(assistant)
+          },
+          {
+            signal
+          }
+        )
+      await processStream(newStream, idx + 1)
+    }
+
+    const processToolCalls = async (mcpTools, toolCalls: ChatCompletionMessageToolCall[]) => {
+      const mcpToolResponses = toolCalls
+        .map((toolCall) => {
+          const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as ChatCompletionMessageToolCall)
+          if (!mcpTool) return undefined
+
+          const parsedArgs = (() => {
+            try {
+              return JSON.parse(toolCall.function.arguments)
+            } catch {
+              return toolCall.function.arguments
+            }
+          })()
+
+          return {
+            id: toolCall.id,
+            toolCallId: toolCall.id,
+            tool: mcpTool,
+            arguments: parsedArgs,
+            status: 'pending'
+          } as ToolCallResponse
+        })
+        .filter((t): t is ToolCallResponse => typeof t !== 'undefined')
+      return await parseAndCallTools(
+        mcpToolResponses,
+        toolResponses,
+        onChunk,
+        this.mcpToolCallResponseToMessage,
+        model,
+        mcpTools
+      )
+    }
+
+    const processToolUses = async (content: string) => {
+      return await parseAndCallTools(
         content,
         toolResponses,
         onChunk,
-        idx,
-        mcpToolCallResponseToOpenAICompatibleMessage,
-        mcpTools,
-        isVisionModel(model)
+        this.mcpToolCallResponseToMessage,
+        model,
+        mcpTools
       )
-
-      if (toolResults.length > 0) {
-        reqMessages.push({
-          role: 'assistant',
-          content: content
-        } as ChatCompletionMessageParam)
-        toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam))
-
-        reqMessages = processReqMessages(model, reqMessages)
-        const newStream = await this.sdk.chat.completions
-          // @ts-ignore key is not typed
-          .create(
-            {
-              model: model.id,
-              messages: reqMessages,
-              temperature: this.getTemperature(assistant, model),
-              top_p: this.getTopP(assistant, model),
-              max_tokens: maxTokens,
-              keep_alive: this.keepAliveTime,
-              stream: isSupportStreamOutput(),
-              // tools: tools,
-              service_tier: this.getServiceTier(model),
-              ...getOpenAIWebSearchParams(assistant, model),
-              ...this.getReasoningEffort(assistant, model),
-              ...this.getProviderSpecificParameters(assistant, model),
-              ...this.getCustomParameters(assistant)
-            },
-            {
-              signal,
-              timeout: this.getTimeout(model)
-            }
-          )
-        await processStream(newStream, idx + 1)
-      }
     }

     const processStream = async (stream: any, idx: number) => {
+      const toolCalls: ChatCompletionMessageToolCall[] = []
       // Handle non-streaming case (already returns early, no change needed here)
       if (!isSupportStreamOutput()) {
         const time_completion_millsec = new Date().getTime() - start_time_millsec
@@ -439,10 +504,59 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
       // Create a synthetic usage object if stream.usage is undefined
       const finalUsage = stream.usage
       // Separate onChunk calls for text and usage/metrics
-      if (stream.choices[0].message?.content) {
-        onChunk({ type: ChunkType.TEXT_COMPLETE, text: stream.choices[0].message.content })
+      let content = ''
+      stream.choices.forEach((choice) => {
+        // reasoning
+        if (choice.message.reasoning) {
+          onChunk({ type: ChunkType.THINKING_DELTA, text: choice.message.reasoning })
+          onChunk({
+            type: ChunkType.THINKING_COMPLETE,
+            text: choice.message.reasoning,
+            thinking_millsec: time_completion_millsec
+          })
+        }
+        // text
+        if (choice.message.content) {
+          content += choice.message.content
+          onChunk({ type: ChunkType.TEXT_DELTA, text: choice.message.content })
+        }
+        // tool call
+        if (choice.message.tool_calls && choice.message.tool_calls.length) {
+          choice.message.tool_calls.forEach((t) => toolCalls.push(t))
+        }
+
+        reqMessages.push({
+          role: choice.message.role,
+          content: choice.message.content,
+          tool_calls: toolCalls.length
+            ? toolCalls.map((toolCall) => ({
+                id: toolCall.id,
+                function: {
+                  ...toolCall.function,
+                  arguments:
+                    typeof toolCall.function.arguments === 'string'
+                      ? toolCall.function.arguments
+                      : JSON.stringify(toolCall.function.arguments)
+                },
+                type: 'function'
+              }))
+            : undefined
+        })
+      })
+
+      if (content.length) {
+        onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
       }
+
+      const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
+      if (toolCalls.length) {
+        toolResults.push(...(await processToolCalls(mcpTools, toolCalls)))
+      }
+      if (stream.choices[0].message?.content) {
+        toolResults.push(...(await processToolUses(stream.choices[0].message?.content)))
+      }
+      await processToolResults(toolResults, idx)
+
       // Always send usage and metrics data
       onChunk({ type: ChunkType.BLOCK_COMPLETE, response: { usage: finalUsage, metrics: finalMetrics } })
       return
@@ -486,6 +600,9 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
         if (delta?.content) {
           yield { type: 'text-delta', textDelta: delta.content }
         }
+        if (delta?.tool_calls) {
+          yield { type: 'tool-calls', delta: delta }
+        }

         const finishReason = chunk.choices[0]?.finish_reason
         if (!isEmpty(finishReason)) {
@@ -563,6 +680,25 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
             onChunk({ type: ChunkType.TEXT_DELTA, text: textDelta })
             break
           }
+          case 'tool-calls': {
+            chunk.delta.tool_calls.forEach((toolCall) => {
+              const { id, index, type, function: fun } = toolCall
+              if (id && type === 'function' && fun) {
+                const { name, arguments: args } = fun
+                toolCalls.push({
+                  id,
+                  function: {
+                    name: name || '',
+                    arguments: args || ''
+                  },
+                  type: 'function'
+                })
+              } else if (fun?.arguments) {
+                toolCalls[index].function.arguments += fun.arguments
+              }
+            })
+            break
+          }
           case 'finish': {
             const finishReason = chunk.finishReason
             const usage = chunk.usage
@@ -624,7 +760,33 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
               } as LLMWebSearchCompleteChunk)
             }
           }
-          await processToolUses(content, idx)
+          reqMessages.push({
+            role: 'assistant',
+            content: content,
+            tool_calls: toolCalls.length
+              ? toolCalls.map((toolCall) => ({
+                  id: toolCall.id,
+                  function: {
+                    ...toolCall.function,
+                    arguments:
+                      typeof toolCall.function.arguments === 'string'
+                        ? toolCall.function.arguments
+                        : JSON.stringify(toolCall.function.arguments)
+                  },
+                  type: 'function'
+                }))
+              : undefined
+          })
+          let toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
+          if (toolCalls.length) {
+            toolResults = await processToolCalls(mcpTools, toolCalls)
+          }
+          if (content.length) {
+            toolResults = toolResults.concat(await processToolUses(content))
+          }
+          if (toolResults.length) {
+            await processToolResults(toolResults, idx)
+          }
           onChunk({
             type: ChunkType.BLOCK_COMPLETE,
             response: {
@@ -657,7 +819,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
             max_tokens: maxTokens,
             keep_alive: this.keepAliveTime,
             stream: isSupportStreamOutput(),
-            // tools: tools,
+            tools: !isEmpty(tools) ? tools : undefined,
             service_tier: this.getServiceTier(model),
             ...getOpenAIWebSearchParams(assistant, model),
             ...this.getReasoningEffort(assistant, model),
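Note: with streaming enabled, OpenAI-compatible APIs deliver a tool call as a series of deltas: the first carries id, type, and function.name, while later deltas append argument fragments addressed by index. That is what the new 'tool-calls' case accumulates. A self-contained TypeScript sketch using the openai SDK's published chunk types; accumulateToolCalls is an illustrative helper, not part of the commit:

import type { ChatCompletionChunk, ChatCompletionMessageToolCall } from 'openai/resources'

// Accumulate streamed tool-call deltas into complete tool calls, as the
// 'tool-calls' case above does: start a new entry when id/function arrive,
// otherwise append argument fragments to the call at the delta's index.
function accumulateToolCalls(acc: ChatCompletionMessageToolCall[], delta: ChatCompletionChunk.Choice.Delta): void {
  for (const toolCall of delta.tool_calls ?? []) {
    const { id, index, type, function: fun } = toolCall
    if (id && type === 'function' && fun) {
      acc.push({ id, type: 'function', function: { name: fun.name || '', arguments: fun.arguments || '' } })
    } else if (fun?.arguments) {
      acc[index].function.arguments += fun.arguments
    }
  }
}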
@@ -21,10 +21,13 @@ import {
   Assistant,
   FileTypes,
   GenerateImageParams,
+  MCPCallToolResponse,
+  MCPTool,
   MCPToolResponse,
   Model,
   Provider,
   Suggestion,
+  ToolCallResponse,
   Usage,
   WebSearchSource
 } from '@renderer/types'
@@ -33,7 +36,12 @@ import { Message } from '@renderer/types/newMessage'
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
 import { addImageFileToContents } from '@renderer/utils/formats'
 import { convertLinks } from '@renderer/utils/linkConverter'
-import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
+import {
+  mcpToolCallResponseToOpenAIMessage,
+  mcpToolsToOpenAIResponseTools,
+  openAIToolsToMcpTool,
+  parseAndCallTools
+} from '@renderer/utils/mcp-tools'
 import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
 import { buildSystemPrompt } from '@renderer/utils/prompt'
 import { isEmpty, takeRight } from 'lodash'
@@ -45,7 +53,7 @@ import { FileLike, toFile } from 'openai/uploads'
 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'

-export default class OpenAIProvider extends BaseProvider {
+export abstract class BaseOpenAiProvider extends BaseProvider {
   protected sdk: OpenAI

   constructor(provider: Provider) {
@@ -61,6 +69,14 @@ export default class OpenAIProvider extends BaseProvider {
     })
   }

+  abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
+
+  abstract mcpToolCallResponseToMessage: (
+    mcpToolResponse: MCPToolResponse,
+    resp: MCPCallToolResponse,
+    model: Model
+  ) => OpenAI.Responses.ResponseInputItem | ChatCompletionMessageParam | undefined
+
   /**
    * Extract the file content from the message
    * @param message - The message
@@ -91,16 +107,23 @@ export default class OpenAIProvider extends BaseProvider {
     return ''
   }

-  private async getReponseMessageParam(message: Message, model: Model): Promise<OpenAI.Responses.EasyInputMessage> {
+  private async getReponseMessageParam(message: Message, model: Model): Promise<OpenAI.Responses.ResponseInputItem> {
     const isVision = isVisionModel(model)
     const content = await this.getMessageContent(message)
     const fileBlocks = findFileBlocks(message)
     const imageBlocks = findImageBlocks(message)

     if (fileBlocks.length === 0 && imageBlocks.length === 0) {
-      return {
-        role: message.role === 'system' ? 'user' : message.role,
-        content: content ? [{ type: 'input_text', text: content }] : []
+      if (message.role === 'assistant') {
+        return {
+          role: 'assistant',
+          content: content
+        }
+      } else {
+        return {
+          role: message.role === 'system' ? 'user' : message.role,
+          content: content ? [{ type: 'input_text', text: content }] : []
+        } as OpenAI.Responses.EasyInputMessage
       }
     }

@@ -285,10 +308,8 @@ export default class OpenAIProvider extends BaseProvider {
     }
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
-    const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
+    const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)

     const isEnabledBuiltinWebSearch = assistant.enableWebSearch
-    onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
     // Fall back to OpenAI-compatible mode
     if (isOpenAIWebSearch(model)) {
       const systemMessage = { role: 'system', content: assistant.prompt || '' }
@@ -387,7 +408,7 @@ export default class OpenAIProvider extends BaseProvider {
       })
       return
     }
-    const tools: OpenAI.Responses.Tool[] = []
+    let tools: OpenAI.Responses.Tool[] = []
     const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
       type: 'web_search_preview'
     }
@@ -411,7 +432,15 @@ export default class OpenAIProvider extends BaseProvider {
       systemMessage.role = 'developer'
     }

-    if (mcpTools && mcpTools.length > 0) {
+    const { tools: extraTools } = this.setupToolsConfig<OpenAI.Responses.Tool>({
+      mcpTools,
+      model,
+      enableToolUse
+    })
+
+    tools = tools.concat(extraTools)
+
+    if (this.useSystemPromptForTools) {
       systemMessageInput.text = buildSystemPrompt(systemMessageInput.text || '', mcpTools)
     }
     systemMessageContent.push(systemMessageInput)
@@ -421,7 +450,7 @@ export default class OpenAIProvider extends BaseProvider {
     )

     onFilterMessages(_messages)
-    const userMessage: OpenAI.Responses.EasyInputMessage[] = []
+    const userMessage: OpenAI.Responses.ResponseInputItem[] = []
     for (const message of _messages) {
       userMessage.push(await this.getReponseMessageParam(message, model))
     }
@@ -434,7 +463,7 @@ export default class OpenAIProvider extends BaseProvider {
     const { signal } = abortController

     // Do not send systemMessage when its content is empty
-    let reqMessages: OpenAI.Responses.EasyInputMessage[]
+    let reqMessages: OpenAI.Responses.ResponseInput
     if (!systemMessage.content) {
       reqMessages = [...userMessage]
     } else {
@@ -443,48 +472,84 @@ export default class OpenAIProvider extends BaseProvider {

     const toolResponses: MCPToolResponse[] = []

-    const processToolUses = async (content: string, idx: number) => {
-      const toolResults = await parseAndCallTools(
+    const processToolResults = async (toolResults: Awaited<ReturnType<typeof parseAndCallTools>>, idx: number) => {
+      if (toolResults.length === 0) return
+
+      toolResults.forEach((ts) => reqMessages.push(ts as OpenAI.Responses.EasyInputMessage))
+
+      onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
+      const stream = await this.sdk.responses.create(
+        {
+          model: model.id,
+          input: reqMessages,
+          temperature: this.getTemperature(assistant, model),
+          top_p: this.getTopP(assistant, model),
+          max_output_tokens: maxTokens,
+          stream: streamOutput,
+          tools: !isEmpty(tools) ? tools : undefined,
+          service_tier: this.getServiceTier(model),
+          ...this.getResponseReasoningEffort(assistant, model),
+          ...this.getCustomParameters(assistant)
+        },
+        {
+          signal,
+          timeout: this.getTimeout(model)
+        }
+      )
+      await processStream(stream, idx + 1)
+    }
+
+    const processToolCalls = async (mcpTools, toolCalls: OpenAI.Responses.ResponseFunctionToolCall[]) => {
+      const mcpToolResponses = toolCalls
+        .map((toolCall) => {
+          const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as OpenAI.Responses.ResponseFunctionToolCall)
+          if (!mcpTool) return undefined
+
+          const parsedArgs = (() => {
+            try {
+              return JSON.parse(toolCall.arguments)
+            } catch {
+              return toolCall.arguments
+            }
+          })()
+
+          return {
+            id: toolCall.call_id,
+            toolCallId: toolCall.call_id,
+            tool: mcpTool,
+            arguments: parsedArgs,
+            status: 'pending'
+          } as ToolCallResponse
+        })
+        .filter((t): t is ToolCallResponse => typeof t !== 'undefined')
+
+      return await parseAndCallTools<OpenAI.Responses.ResponseInputItem | ChatCompletionMessageParam>(
+        mcpToolResponses,
+        toolResponses,
+        onChunk,
+        this.mcpToolCallResponseToMessage,
+        model,
+        mcpTools
+      )
+    }
+
+    const processToolUses = async (content: string) => {
+      return await parseAndCallTools(
         content,
         toolResponses,
         onChunk,
-        idx,
-        mcpToolCallResponseToOpenAIMessage,
-        mcpTools,
-        isVisionModel(model)
+        this.mcpToolCallResponseToMessage,
+        model,
+        mcpTools
       )
-
-      if (toolResults.length > 0) {
-        reqMessages.push({
-          role: 'assistant',
-          content: content
-        })
-        toolResults.forEach((ts) => reqMessages.push(ts as OpenAI.Responses.EasyInputMessage))
-        const newStream = await this.sdk.responses.create(
-          {
-            model: model.id,
-            input: reqMessages,
-            temperature: this.getTemperature(assistant, model),
-            top_p: this.getTopP(assistant, model),
-            max_output_tokens: maxTokens,
-            stream: true,
-            service_tier: this.getServiceTier(model),
-            ...this.getResponseReasoningEffort(assistant, model),
-            ...this.getCustomParameters(assistant)
-          },
-          {
-            signal,
-            timeout: this.getTimeout(model)
-          }
-        )
-        await processStream(newStream, idx + 1)
-      }
     }

     const processStream = async (
       stream: Stream<OpenAI.Responses.ResponseStreamEvent> | OpenAI.Responses.Response,
       idx: number
     ) => {
+      const toolCalls: OpenAI.Responses.ResponseFunctionToolCall[] = []
+
       if (!streamOutput) {
         const nonStream = stream as OpenAI.Responses.Response
         const time_completion_millsec = new Date().getTime() - start_time_millsec
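The hunk above splits tool handling into two converging paths: `processToolCalls` for native `function_call` items and `processToolUses` for `<tool_use>` blocks parsed out of plain text, with `processToolResults` feeding the outputs back and re-entering the model once. A sketch of the convergence point, using only names from this diff (the surrounding provider scope is assumed):

```ts
// Sketch, not a verbatim excerpt: assumes content, toolCalls, mcpTools, idx
// and the three helpers defined in the hunk above are in scope.
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (toolCalls.length) {
  // Path 1: native function calling, structured function_call items from the model.
  toolResults.push(...(await processToolCalls(mcpTools, toolCalls)))
}
if (content) {
  // Path 2: prompt-based fallback, <tool_use> blocks parsed out of assistant text.
  toolResults.push(...(await processToolUses(content)))
}
// Push every tool output into reqMessages and re-enter the model exactly once.
await processToolResults(toolResults, idx)
```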
@@ -502,11 +567,15 @@ export default class OpenAIProvider extends BaseProvider {
           prompt_tokens: nonStream.usage?.input_tokens || 0,
           total_tokens
         }
+        let content = ''
+
         for (const output of nonStream.output) {
           switch (output.type) {
             case 'message':
               if (output.content[0].type === 'output_text') {
+                onChunk({ type: ChunkType.TEXT_DELTA, text: output.content[0].text })
                 onChunk({ type: ChunkType.TEXT_COMPLETE, text: output.content[0].text })
+                content += output.content[0].text
                 if (output.content[0].annotations && output.content[0].annotations.length > 0) {
                   onChunk({
                     type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
@@ -525,8 +594,32 @@ export default class OpenAIProvider extends BaseProvider {
                     thinking_millsec: new Date().getTime() - start_time_millsec
                   })
                   break
+                case 'function_call':
+                  toolCalls.push(output)
             }
           }

+        if (content) {
+          reqMessages.push({
+            role: 'assistant',
+            content: content
+          })
+        }
+        if (toolCalls.length) {
+          toolCalls.forEach((toolCall) => {
+            reqMessages.push(toolCall)
+          })
+        }
+
+        const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
+        if (toolCalls.length) {
+          toolResults.push(...(await processToolCalls(mcpTools, toolCalls)))
+        }
+        if (content.length) {
+          toolResults.push(...(await processToolUses(content)))
+        }
+        await processToolResults(toolResults, idx)
+
         onChunk({
           type: ChunkType.BLOCK_COMPLETE,
           response: {
@@ -537,6 +630,9 @@ export default class OpenAIProvider extends BaseProvider {
         return
       }
       let content = ''
+
+      const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
+
       let lastUsage: Usage | undefined = undefined
       let final_time_completion_millsec_delta = 0
      for await (const chunk of stream as Stream<OpenAI.Responses.ResponseStreamEvent>) {
@@ -547,6 +643,12 @@ export default class OpenAIProvider extends BaseProvider {
           case 'response.created':
             time_first_token_millsec = new Date().getTime()
             break
+          case 'response.output_item.added':
+            if (chunk.item.type === 'function_call') {
+              outputItems.push(chunk.item)
+            }
+            break
+
           case 'response.reasoning_summary_text.delta':
             onChunk({
               type: ChunkType.THINKING_DELTA,
@@ -579,6 +681,21 @@ export default class OpenAIProvider extends BaseProvider {
               text: content
             })
             break
+          case 'response.function_call_arguments.done': {
+            const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
+              (item) => item.id === chunk.item_id
+            )
+            if (outputItem) {
+              if (outputItem.type === 'function_call') {
+                toolCalls.push({
+                  ...outputItem,
+                  arguments: chunk.arguments
+                })
+              }
+            }
+
+            break
+          }
           case 'response.content_part.done':
             if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
               onChunk({
@@ -615,9 +732,31 @@ export default class OpenAIProvider extends BaseProvider {
             })
             break
         }
+
+        // --- End of Incremental onChunk calls ---
+      } // End of for await loop
+      if (content) {
+        reqMessages.push({
+          role: 'assistant',
+          content: content
+        })
+      }
+      if (toolCalls.length) {
+        toolCalls.forEach((toolCall) => {
+          reqMessages.push(toolCall)
+        })
       }

-      await processToolUses(content, idx)
+      // Call processToolUses AFTER the loop finishes processing the main stream content
+      // Note: parseAndCallTools inside processToolUses should handle its own onChunk for tool responses
+      const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
+      if (toolCalls.length) {
+        toolResults.push(...(await processToolCalls(mcpTools, toolCalls)))
+      }
+      if (content) {
+        toolResults.push(...(await processToolUses(content)))
+      }
+      await processToolResults(toolResults, idx)

       onChunk({
         type: ChunkType.BLOCK_COMPLETE,
@@ -632,6 +771,7 @@ export default class OpenAIProvider extends BaseProvider {
       })
     }

+    onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
     const stream = await this.sdk.responses.create(
       {
         model: model.id,
@@ -1081,3 +1221,31 @@ export default class OpenAIProvider extends BaseProvider {
     return data.data[0].embedding.length
   }
 }
+
+export default class OpenAIProvider extends BaseOpenAiProvider {
+  constructor(provider: Provider) {
+    super(provider)
+  }
+
+  public convertMcpTools<T>(mcpTools: MCPTool[]) {
+    return mcpToolsToOpenAIResponseTools(mcpTools) as T[]
+  }
+
+  public mcpToolCallResponseToMessage = (
+    mcpToolResponse: MCPToolResponse,
+    resp: MCPCallToolResponse,
+    model: Model
+  ): OpenAI.Responses.ResponseInputItem | undefined => {
+    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
+      return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
+    } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
+      const toolCallOut: OpenAI.Responses.ResponseInputItem = {
+        type: 'function_call_output',
+        call_id: mcpToolResponse.toolCallId,
+        output: JSON.stringify(resp.content)
+      }
+      return toolCallOut
+    }
+    return
+  }
+}
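The new `mcpToolCallResponseToMessage` picks the reply shape structurally: prompt-parsed responses (`toolUseId`) become a user message, while native calls (`toolCallId`) become a `function_call_output` item. A minimal sketch of the latter round trip, with hypothetical values:

```ts
import OpenAI from 'openai'

// Sketch: replying to a native function call via the Responses API.
// call_id must echo the model's function_call item; output is stringified JSON.
const functionCallOutput: OpenAI.Responses.ResponseInputItem = {
  type: 'function_call_output',
  call_id: 'call_abc123', // hypothetical id echoed from the model's function_call
  output: JSON.stringify([{ type: 'text', text: '42' }]) // hypothetical MCP result content
}
// The provider then appends this to reqMessages and re-enters the model.
```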
@@ -107,6 +107,7 @@ export const getAssistantSettings = (assistant: Assistant): AssistantSettings =>
     enableMaxTokens: assistant?.settings?.enableMaxTokens ?? false,
     maxTokens: getAssistantMaxTokens(),
     streamOutput: assistant?.settings?.streamOutput ?? true,
+    enableToolUse: assistant?.settings?.enableToolUse ?? false,
     hideMessages: assistant?.settings?.hideMessages ?? false,
     defaultModel: assistant?.defaultModel ?? undefined,
     customParameters: assistant?.settings?.customParameters ?? []
@@ -46,7 +46,7 @@ const persistedReducer = persistReducer(
   {
     key: 'cherry-studio',
     storage,
-    version: 99,
+    version: 98,
     blacklist: ['runtime', 'messages', 'messageBlocks'],
     migrate
   },
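Rolling `version` back from 99 to 98 pairs with deleting the `'99'` migration further down in this diff; redux-persist's stored version should track the highest surviving migration key. A sketch of that invariant, assuming the usual `createMigrate` wiring (the app's actual storage adapter and migrate helper may differ):

```ts
import { createMigrate } from 'redux-persist'
import storage from 'redux-persist/lib/storage' // assumption: the app may use another adapter

// Sketch: redux-persist only runs migrations with keys <= version, so version
// must not point past the last migration that still exists.
const migrations = {
  // ...earlier migrations elided...
  '98': (state: any) => state // last surviving step once '99' (paratera) is removed
}
const persistConfig = {
  key: 'cherry-studio',
  storage,
  version: 98, // rolled back from 99 together with the '99' migration
  migrate: createMigrate(migrations)
}
```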
@@ -476,16 +476,6 @@ export const INITIAL_PROVIDERS: Provider[] = [
     models: SYSTEM_MODELS.voyageai,
     isSystem: true,
     enabled: false
-  },
-  {
-    id: 'paratera',
-    name: 'Paratera AI',
-    type: 'openai-compatible',
-    apiKey: '',
-    apiHost: 'https://llmapi.paratera.com',
-    models: SYSTEM_MODELS.paratera,
-    isSystem: true,
-    enabled: false
   }
 ]

@@ -1248,15 +1248,6 @@ const migrateConfig = {
         provider.type = 'openai-compatible'
       }
     })
-    return state
-  } catch (error) {
-    return state
-  }
-},
-'99': (state: RootState) => {
-  try {
-    addProvider(state, 'paratera')
-
     return state
   } catch (error) {
     return state
@@ -427,7 +427,17 @@ const fetchAndProcessAssistantResponseImpl = async (
         }
       },
       onToolCallInProgress: (toolResponse: MCPToolResponse) => {
-        if (toolResponse.status === 'invoking') {
+        if (lastBlockType === MessageBlockType.UNKNOWN && lastBlockId) {
+          lastBlockType = MessageBlockType.TOOL
+          const changes = {
+            type: MessageBlockType.TOOL,
+            status: MessageBlockStatus.PROCESSING,
+            metadata: { rawMcpToolResponse: toolResponse }
+          }
+          dispatch(updateOneBlock({ id: lastBlockId, changes }))
+          saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
+          toolCallIdToBlockIdMap.set(toolResponse.id, lastBlockId)
+        } else if (toolResponse.status === 'invoking') {
           const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
             toolName: toolResponse.tool.name,
             status: MessageBlockStatus.PROCESSING,
@@ -55,6 +55,7 @@ export type AssistantSettings = {
   maxTokens: number | undefined
   enableMaxTokens: boolean
   streamOutput: boolean
+  enableToolUse: boolean
   hideMessages: boolean
   defaultModel?: Model
   customParameters?: AssistantSettingCustomParameters[]
@@ -570,13 +571,25 @@ export interface MCPConfig {
   servers: MCPServer[]
 }

-export interface MCPToolResponse {
-  id: string // tool call id, it should be unique
-  tool: MCPTool // tool info
+interface BaseToolResponse {
+  id: string // unique id
+  tool: MCPTool
+  arguments: Record<string, unknown> | undefined
   status: string // 'invoking' | 'done'
   response?: any
 }

+export interface ToolUseResponse extends BaseToolResponse {
+  toolUseId: string
+}
+
+export interface ToolCallResponse extends BaseToolResponse {
+  // gemini tool call id might be undefined
+  toolCallId?: string
+}
+
+export type MCPToolResponse = ToolUseResponse | ToolCallResponse
+
 export interface MCPToolResultContent {
   type: 'text' | 'image' | 'audio' | 'resource'
   text?: string
@@ -586,6 +599,7 @@ export interface MCPToolResultContent {
     uri?: string
     text?: string
     mimeType?: string
+    blob?: string
   }
 }

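The new `MCPToolResponse` union carries no discriminant field, so call sites narrow it structurally, exactly as the provider hunks above do with `'toolUseId' in ...`. A small self-contained sketch (hypothetical values):

```ts
import type { MCPToolResponse } from '@renderer/types'

// Sketch: structural narrowing of the MCPToolResponse union.
// 'toolUseId' marks prompt-parsed <tool_use> blocks; 'toolCallId' marks native calls.
function describe(resp: MCPToolResponse): string {
  if ('toolUseId' in resp) {
    return `prompt tool use ${resp.toolUseId}` // narrowed to ToolUseResponse
  }
  // narrowed to ToolCallResponse; the id may be absent (e.g. Gemini)
  return `native tool call ${resp.toolCallId ?? '(missing id)'}`
}
```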
@@ -1,18 +1,31 @@
-import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
-import { MessageParam } from '@anthropic-ai/sdk/resources'
-import { Content, FunctionCall, Part } from '@google/genai'
+import {
+  ContentBlockParam,
+  MessageParam,
+  ToolResultBlockParam,
+  ToolUnion,
+  ToolUseBlock
+} from '@anthropic-ai/sdk/resources'
+import { Content, FunctionCall, Part, Tool, Type as GeminiSchemaType } from '@google/genai'
+import { isVisionModel } from '@renderer/config/models'
 import store from '@renderer/store'
 import { addMCPServer } from '@renderer/store/mcp'
-import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
+import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse, Model, ToolUseResponse } from '@renderer/types'
 import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk'
 import { ChunkType } from '@renderer/types/chunk'
+import { isArray, isObject, pull, transform } from 'lodash'
 import { nanoid } from 'nanoid'
 import OpenAI from 'openai'
-import { ChatCompletionContentPart, ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources'
+import {
+  ChatCompletionContentPart,
+  ChatCompletionMessageParam,
+  ChatCompletionMessageToolCall,
+  ChatCompletionTool
+} from 'openai/resources'

 import { CompletionsParams } from '../providers/AiProvider'

 const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
+const EXTRA_SCHEMA_KEYS = ['schema', 'headers']

 // const ensureValidSchema = (obj: Record<string, any>) => {
 //   // Filter out unsupported keys for Gemini
@@ -153,77 +166,116 @@ const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
 //   return processedProperties
 // }

-// export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
-//   return mcpTools.map((tool) => ({
-//     type: 'function',
-//     name: tool.name,
-//     function: {
-//       name: tool.id,
-//       description: tool.description,
-//       parameters: {
-//         type: 'object',
-//         properties: filterPropertieAttributes(tool)
-//       }
-//     }
-//   }))
-// }
-
-export function openAIToolsToMcpTool(
-  mcpTools: MCPTool[] | undefined,
-  llmTool: ChatCompletionMessageToolCall
-): MCPTool | undefined {
-  if (!mcpTools) {
-    return undefined
+export function filterProperties(
+  properties: Record<string, any> | string | number | boolean | Array<Record<string, any> | string | number | boolean>,
+  supportedKeys: string[]
+) {
+  // If it is an array, recursively process each element
+  if (isArray(properties)) {
+    return properties.map((item) => filterProperties(item, supportedKeys))
   }

-  const tool = mcpTools.find(
-    (mcptool) => mcptool.id === llmTool.function.name || mcptool.name === llmTool.function.name
-  )
-
-  if (!tool) {
-    console.warn('No MCP Tool found for tool call:', llmTool)
-    return undefined
+  // If it is an object, recursively process each property
+  if (isObject(properties)) {
+    return transform(
+      properties,
+      (result, value, key) => {
+        if (key === 'properties') {
+          result[key] = transform(value, (acc, v, k) => {
+            acc[k] = filterProperties(v, supportedKeys)
+          })

+          result['additionalProperties'] = false
+          result['required'] = pull(Object.keys(value), ...EXTRA_SCHEMA_KEYS)
+        } else if (key === 'oneOf') {
+          // openai only supports anyOf
+          result['anyOf'] = filterProperties(value, supportedKeys)
+        } else if (supportedKeys.includes(key)) {
+          result[key] = filterProperties(value, supportedKeys)
+          if (key === 'type' && value === 'object') {
+            result['additionalProperties'] = false
+          }
+        }
+      },
+      {}
+    )
   }

-  console.log(
-    `[MCP] OpenAI Tool to MCP Tool: ${tool.serverName} ${tool.name}`,
-    tool,
-    'args',
-    llmTool.function.arguments
-  )
-  // use this to parse the arguments and avoid parsing errors
-  let args: any = {}
-  try {
-    args = JSON.parse(llmTool.function.arguments)
-  } catch (e) {
-    console.error('Error parsing arguments', e)
-  }
-
-  return {
-    id: tool.id,
-    serverId: tool.serverId,
-    serverName: tool.serverName,
-    name: tool.name,
-    description: tool.description,
-    inputSchema: args
-  }
+  // Return other types directly (e.g., string, number, etc.)
+  return properties
 }

-export async function callMCPTool(tool: MCPTool): Promise<MCPCallToolResponse> {
-  console.log(`[MCP] Calling Tool: ${tool.serverName} ${tool.name}`, tool)
+export function mcpToolsToOpenAIResponseTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
+  const schemaKeys = ['type', 'description', 'items', 'enum', 'additionalProperties', 'anyof']
+  return mcpTools.map(
+    (tool) =>
+      ({
+        type: 'function',
+        name: tool.id,
+        parameters: {
+          type: 'object',
+          properties: filterProperties(tool.inputSchema, schemaKeys).properties,
+          required: pull(Object.keys(tool.inputSchema.properties), ...EXTRA_SCHEMA_KEYS),
+          additionalProperties: false
+        },
+        strict: true
+      }) satisfies OpenAI.Responses.Tool
+  )
+}
+
+export function mcpToolsToOpenAIChatTools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
+  return mcpTools.map(
+    (tool) =>
+      ({
+        type: 'function',
+        function: {
+          name: tool.id,
+          description: tool.description,
+          parameters: {
+            type: 'object',
+            properties: tool.inputSchema.properties,
+            required: tool.inputSchema.required
+          }
+        }
+      }) as ChatCompletionTool
+  )
+}
+
+export function openAIToolsToMcpTool(
+  mcpTools: MCPTool[],
+  toolCall: OpenAI.Responses.ResponseFunctionToolCall | ChatCompletionMessageToolCall
+): MCPTool | undefined {
+  const tool = mcpTools.find((mcpTool) => {
+    if ('name' in toolCall) {
+      return mcpTool.id === toolCall.name || mcpTool.name === toolCall.name
+    } else {
+      return mcpTool.id === toolCall.function.name || mcpTool.name === toolCall.function.name
+    }
+  })
+
+  if (!tool) {
+    console.warn('No MCP Tool found for tool call:', toolCall)
+    return undefined
+  }
+
+  return tool
+}
+
+export async function callMCPTool(toolResponse: MCPToolResponse): Promise<MCPCallToolResponse> {
+  console.log(`[MCP] Calling Tool: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, toolResponse.tool)
   try {
-    const server = getMcpServerByTool(tool)
+    const server = getMcpServerByTool(toolResponse.tool)

     if (!server) {
-      throw new Error(`Server not found: ${tool.serverName}`)
+      throw new Error(`Server not found: ${toolResponse.tool.serverName}`)
     }

     const resp = await window.api.mcp.callTool({
       server,
-      name: tool.name,
-      args: tool.inputSchema
+      name: toolResponse.tool.name,
+      args: toolResponse.arguments
     })
-    if (tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
+    if (toolResponse.tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) {
       if (resp.data) {
         const mcpServer: MCPServer = {
           id: `f${nanoid()}`,
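`filterProperties` walks a JSON schema and keeps only whitelisted keywords, forcing the closed-object shape that OpenAI strict function calling requires. A worked sketch on a made-up schema (the import path is an assumption; the helper is the one added above):

```ts
import { filterProperties } from '@renderer/utils/mcp-tools' // path assumed

// Sketch: effect of filterProperties on a hypothetical MCP input schema.
const schema = {
  type: 'object',
  properties: {
    city: { type: 'string', examples: ['Berlin'] }, // 'examples' is not whitelisted
    days: { type: 'number' }
  }
}
const openAiKeys = ['type', 'description', 'items', 'enum', 'additionalProperties', 'anyof']
const filtered = filterProperties(schema, openAiKeys)
// filtered === {
//   type: 'object',
//   additionalProperties: false, // forced for object schemas
//   properties: { city: { type: 'string' }, days: { type: 'number' } },
//   required: ['city', 'days'] // every property key, minus EXTRA_SCHEMA_KEYS
// }
// Unsupported keywords are stripped, objects are closed, and required is explicit.
```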
@@ -241,16 +293,16 @@ export async function callMCPTool(tool: MCPTool): Promise<MCPCallToolResponse> {
       }
     }

-    console.log(`[MCP] Tool called: ${tool.serverName} ${tool.name}`, resp)
+    console.log(`[MCP] Tool called: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, resp)
     return resp
   } catch (e) {
-    console.error(`[MCP] Error calling Tool: ${tool.serverName} ${tool.name}`, e)
+    console.error(`[MCP] Error calling Tool: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, e)
     return Promise.resolve({
       isError: true,
       content: [
         {
           type: 'text',
-          text: `Error calling tool ${tool.name}: ${e instanceof Error ? e.stack || e.message || 'No error details available' : JSON.stringify(e)}`
+          text: `Error calling tool ${toolResponse.tool.name}: ${e instanceof Error ? e.stack || e.message || 'No error details available' : JSON.stringify(e)}`
         }
       ]
     })
@@ -262,7 +314,7 @@ export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array<ToolUnion>
     const t: ToolUnion = {
       name: tool.id,
       description: tool.description,
-      // @ts-ignore no check
+      // @ts-ignore ignore type as it it unknow
       input_schema: tool.inputSchema
     }
     return t
@@ -275,53 +327,68 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
   if (!tool) {
     return undefined
   }
-  // @ts-ignore ignore type as it it unknow
-  tool.inputSchema = toolUse.input
   return tool
 }

-// export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
-//   if (!mcpTools || mcpTools.length === 0) {
-//     // No tools available
-//     return []
-//   }
-//   const functions: FunctionDeclaration[] = []
-//   for (const tool of mcpTools) {
-//     const properties = filterPropertieAttributes(tool, true)
-//     const functionDeclaration: FunctionDeclaration = {
-//       name: tool.id,
-//       description: tool.description,
-//       parameters: {
-//         type: SchemaType.OBJECT,
-//         properties:
-//           Object.keys(properties).length > 0
-//             ? Object.fromEntries(
-//                 Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
-//               )
-//             : { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
-//       } as FunctionDeclarationSchema
-//     }
-//     functions.push(functionDeclaration)
-//   }
-//   const tool: geminiTool = {
-//     functionDeclarations: functions
-//   }
-//   return [tool]
-// }
+/**
+ * @param mcpTools
+ * @returns
+ */
+export function mcpToolsToGeminiTools(mcpTools: MCPTool[]): Tool[] {
+  /**
+   * @typedef {import('@google/genai').Schema} Schema
+   */
+  const schemaKeys = [
+    'example',
+    'pattern',
+    'default',
+    'maxLength',
+    'minLength',
+    'minProperties',
+    'maxProperties',
+    'anyOf',
+    'description',
+    'enum',
+    'format',
+    'items',
+    'maxItems',
+    'maximum',
+    'minItems',
+    'minimum',
+    'nullable',
+    'properties',
+    'propertyOrdering',
+    'required',
+    'title',
+    'type'
+  ]
+  return [
+    {
+      functionDeclarations: mcpTools?.map((tool) => {
+        return {
+          name: tool.id,
+          description: tool.description,
+          parameters: {
+            type: GeminiSchemaType.OBJECT,
+            properties: filterProperties(tool.inputSchema, schemaKeys).properties,
+            required: tool.inputSchema.required
+          }
+        }
+      })
+    }
+  ]
+}

 export function geminiFunctionCallToMcpTool(
   mcpTools: MCPTool[] | undefined,
-  fcall: FunctionCall | undefined
+  toolCall: FunctionCall | undefined
 ): MCPTool | undefined {
-  if (!fcall) return undefined
+  if (!toolCall) return undefined
   if (!mcpTools) return undefined
-  const tool = mcpTools.find((tool) => tool.id === fcall.name)
+  const tool = mcpTools.find((tool) => tool.id === toolCall.name)
   if (!tool) {
     return undefined
   }
-  // @ts-ignore schema is not a valid property
-  tool.inputSchema = fcall.args
   return tool
 }

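`mcpToolsToGeminiTools` reuses the same `filterProperties` helper with a much wider whitelist, since Gemini's `Schema` type accepts more JSON-Schema keywords than OpenAI strict mode but still rejects extras such as `schema` and `headers`. A usage sketch (import path and the input tool are assumptions):

```ts
import { mcpToolsToGeminiTools } from '@renderer/utils/mcp-tools' // path assumed
import type { MCPTool } from '@renderer/types'

declare const weatherMcpTool: MCPTool // hypothetical tool from an MCP server

const [geminiTool] = mcpToolsToGeminiTools([weatherMcpTool])
// geminiTool.functionDeclarations[0].parameters now contains only keywords
// Gemini understands (type, properties, required, enum, format, ...).
```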
@@ -368,13 +435,13 @@ export function getMcpServerByTool(tool: MCPTool) {
   return servers.find((s) => s.id === tool.serverId)
 }

-export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolResponse[] {
+export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseResponse[] {
   if (!content || !mcpTools || mcpTools.length === 0) {
     return []
   }
   const toolUsePattern =
     /<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
-  const tools: MCPToolResponse[] = []
+  const tools: ToolUseResponse[] = []
   let match
   let idx = 0
   // Find all tool use blocks
@@ -401,10 +468,9 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolRespo
     // Add to tools array
     tools.push({
       id: `${toolName}-${idx++}`, // Unique ID for each tool use
-      tool: {
-        ...mcpTool,
-        inputSchema: parsedArgs
-      },
+      toolUseId: mcpTool.id,
+      tool: mcpTool,
+      arguments: parsedArgs,
       status: 'pending'
     })

@@ -414,36 +480,69 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolRespo
   return tools
 }

-export async function parseAndCallTools(
-  content: string,
-  toolResponses: MCPToolResponse[],
+export async function parseAndCallTools<R>(
+  tools: MCPToolResponse[],
+  allToolResponses: MCPToolResponse[],
   onChunk: CompletionsParams['onChunk'],
-  idx: number,
-  convertToMessage: (
-    toolCallId: string,
-    resp: MCPCallToolResponse,
-    isVisionModel: boolean
-  ) => ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage,
-  mcpTools?: MCPTool[],
-  isVisionModel: boolean = false
-): Promise<(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage)[]> {
-  const toolResults: (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage)[] = []
-  // process tool use
-  const tools = parseToolUse(content, mcpTools || [])
-  if (!tools || tools.length === 0) {
+  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
+  model: Model,
+  mcpTools?: MCPTool[]
+): Promise<
+  (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
+>
+
+export async function parseAndCallTools<R>(
+  content: string,
+  allToolResponses: MCPToolResponse[],
+  onChunk: CompletionsParams['onChunk'],
+  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
+  model: Model,
+  mcpTools?: MCPTool[]
+): Promise<
+  (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
+>
+
+export async function parseAndCallTools<R>(
+  content: string | MCPToolResponse[],
+  allToolResponses: MCPToolResponse[],
+  onChunk: CompletionsParams['onChunk'],
+  convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
+  model: Model,
+  mcpTools?: MCPTool[]
+): Promise<R[]> {
+  const toolResults: R[] = []
+  let curToolResponses: MCPToolResponse[] = []
+  if (Array.isArray(content)) {
+    curToolResponses = content
+  } else {
+    // process tool use
+    curToolResponses = parseToolUse(content, mcpTools || [])
+  }
+  if (!curToolResponses || curToolResponses.length === 0) {
     return toolResults
   }
-  for (let i = 0; i < tools.length; i++) {
-    const tool = tools[i]
-    upsertMCPToolResponse(toolResponses, { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'invoking' }, onChunk)
+  for (let i = 0; i < curToolResponses.length; i++) {
+    const toolResponse = curToolResponses[i]
+    upsertMCPToolResponse(
+      allToolResponses,
+      {
+        ...toolResponse,
+        status: 'invoking'
+      },
+      onChunk
+    )
   }

-  const toolPromises = tools.map(async (tool, i) => {
+  const toolPromises = curToolResponses.map(async (toolResponse) => {
     const images: string[] = []
-    const toolCallResponse = await callMCPTool(tool.tool)
+    const toolCallResponse = await callMCPTool(toolResponse)
     upsertMCPToolResponse(
-      toolResponses,
-      { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'done', response: toolCallResponse },
+      allToolResponses,
+      {
+        ...toolResponse,
+        status: 'done',
+        response: toolCallResponse
+      },
       onChunk
     )

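With the overloads above, one implementation serves both tool paths: a `string` first argument gets `<tool_use>` blocks parsed out of it, while an `MCPToolResponse[]` is used as-is. A call-site sketch (the converter and context variables stand in for the provider's own wiring):

```ts
// Sketch: the two call shapes the overloads allow. assistantText,
// toolCallResponses, toolResponses, onChunk, convertToMessage, model and
// mcpTools are stand-ins for the surrounding provider scope.
const fromText = await parseAndCallTools(
  assistantText, // string: <tool_use> blocks are parsed internally
  toolResponses, onChunk, convertToMessage, model, mcpTools
)
const fromCalls = await parseAndCallTools(
  toolCallResponses, // MCPToolResponse[]: pre-built from native function_call items
  toolResponses, onChunk, convertToMessage, model, mcpTools
)
```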
@@ -466,15 +565,15 @@ export async function parseAndCallTools(
       })
     }

-    return convertToMessage(tool.tool.id, toolCallResponse, isVisionModel)
+    return convertToMessage(toolResponse, toolCallResponse, model)
   })

-  toolResults.push(...(await Promise.all(toolPromises)))
+  toolResults.push(...(await Promise.all(toolPromises)).filter((t) => typeof t !== 'undefined'))
   return toolResults
 }

 export function mcpToolCallResponseToOpenAICompatibleMessage(
-  toolCallId: string,
+  mcpToolResponse: MCPToolResponse,
   resp: MCPCallToolResponse,
   isVisionModel: boolean = false
 ): ChatCompletionMessageParam {
@@ -488,7 +587,7 @@ export function mcpToolCallResponseToOpenAICompatibleMessage(
     const content: ChatCompletionContentPart[] = [
       {
         type: 'text',
-        text: `Here is the result of tool call ${toolCallId}:`
+        text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
       }
     ]

@@ -541,7 +640,7 @@ export function mcpToolCallResponseToOpenAICompatibleMessage(
 }

 export function mcpToolCallResponseToOpenAIMessage(
-  toolCallId: string,
+  mcpToolResponse: MCPToolResponse,
   resp: MCPCallToolResponse,
   isVisionModel: boolean = false
 ): OpenAI.Responses.EasyInputMessage {
@@ -555,7 +654,7 @@ export function mcpToolCallResponseToOpenAIMessage(
     const content: OpenAI.Responses.ResponseInputContent[] = [
       {
         type: 'input_text',
-        text: `Here is the result of tool call ${toolCallId}:`
+        text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
       }
     ]

@@ -597,9 +696,9 @@
 }

 export function mcpToolCallResponseToAnthropicMessage(
-  toolCallId: string,
+  mcpToolResponse: MCPToolResponse,
   resp: MCPCallToolResponse,
-  isVisionModel: boolean = false
+  model: Model
 ): MessageParam {
   const message = {
     role: 'user'
@@ -610,10 +709,10 @@ export function mcpToolCallResponseToAnthropicMessage(
   const content: ContentBlockParam[] = [
     {
       type: 'text',
-      text: `Here is the result of tool call ${toolCallId}:`
+      text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
     }
   ]
-  if (isVisionModel) {
+  if (isVisionModel(model)) {
     for (const item of resp.content) {
       switch (item.type) {
         case 'text':
@@ -665,7 +764,7 @@
 }

 export function mcpToolCallResponseToGeminiMessage(
-  toolCallId: string,
+  mcpToolResponse: MCPToolResponse,
   resp: MCPCallToolResponse,
   isVisionModel: boolean = false
 ): Content {
@@ -682,7 +781,7 @@ export function mcpToolCallResponseToGeminiMessage(
   } else {
     const parts: Part[] = [
       {
-        text: `Here is the result of tool call ${toolCallId}:`
+        text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:`
       }
     ]
     if (isVisionModel) {
@@ -147,7 +147,7 @@ ${availableTools}
 </tools>`
 }

-export const buildSystemPrompt = (userSystemPrompt: string, tools: MCPTool[]): string => {
+export const buildSystemPrompt = (userSystemPrompt: string, tools?: MCPTool[]): string => {
   if (tools && tools.length > 0) {
     return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt)
       .replace('{{ TOOL_USE_EXAMPLES }}', ToolUseExamples)
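Making `tools` optional lets callers on the native function-calling path skip prompt injection entirely. A sketch of both call shapes (the no-tools branch presumably returns the user prompt untouched, but that branch sits outside this hunk):

```ts
// Sketch: both call shapes type-check once tools is optional.
const withPromptTools = buildSystemPrompt(userPrompt, mcpTools) // injects <tools> and examples
const plain = buildSystemPrompt(userPrompt) // no tools: prompt passes through unmodified (assumed)
```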