mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-30 15:59:09 +08:00
* feat: Add health check to check all the models at one time * fix: add model avatars to the health-check list * style: Use segmented instead of switch * fix: remove redundant timing reports * refactor: Extract small functions * refactor: use more hooks to make the main component clearer * fix: mask API keys with asterisks * refactor: split health check popup and model list - rename ModelHealthCheckPopup to HealthCheckPopup - add HealthCheckModelList - add maskApiKey to utils * refactor: compute latency in checkApi * fix: remove unused i18n keys * refactor: use checkModel instead of checkApi for better semantics * fix: update comments * refactor: extract health checking functions to services * refactor: extract model list * refactor: render statuses on the existing model list * fix: reset button style on completion * fix: disable model card while checking - remove unused i18n keys - better window message * refactor: show provider name in messages * refactor: change default values * refactor: fully migrate model list from ProviderSetting to ModelList
317 lines
8.3 KiB
TypeScript
317 lines
8.3 KiB
TypeScript
import { getOpenAIWebSearchParams } from '@renderer/config/models'
|
|
import i18n from '@renderer/i18n'
|
|
import store from '@renderer/store'
|
|
import { setGenerating } from '@renderer/store/runtime'
|
|
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
|
import { addAbortController } from '@renderer/utils/abortController'
|
|
import { formatMessageError } from '@renderer/utils/error'
|
|
import { findLast, isEmpty } from 'lodash'
|
|
|
|
import AiProvider from '../providers/AiProvider'
|
|
import {
|
|
getAssistantProvider,
|
|
getDefaultModel,
|
|
getProviderByModel,
|
|
getTopNamingModel,
|
|
getTranslateModel
|
|
} from './AssistantService'
|
|
import { EVENT_NAMES, EventEmitter } from './EventService'
|
|
import { filterMessages, filterUsefulMessages } from './MessagesService'
|
|
import { estimateMessagesUsage } from './TokenService'
|
|
import WebSearchService from './WebSearchService'
|
|
|
|
/**
 * Runs a (streaming) chat completion for `message`, mutating it in place and keeping the caller
 * informed through `onResponse`.
 *
 * Flow:
 * 1. Clears the paused flag in `window.keyv`, sets the store's generating flag and echoes the
 *    initial message to `onResponse`.
 * 2. Registers an abort handler (keyed by `message.askId ?? message.id`) that marks the message
 *    'paused', emits RECEIVE_MESSAGE and clears the generating flag.
 * 3. If web search is enabled globally and for the assistant, and `getOpenAIWebSearchParams`
 *    returns nothing for the model, searches with the content of the last user message. The
 *    result is stored in `message.metadata.webSearch` and in `window.keyv` under
 *    `web-search-<lastUserMessageId>`.
 * 4. Calls `AI.completions` with every listed MCP tool. Text and reasoning chunks are appended;
 *    usage/metrics, grounding metadata, MCP tool responses and first-seen citations are copied
 *    onto the message.
 * 5. When the provider reported no completion tokens, estimates usage locally.
 *
 * Any error thrown along the way is caught: the message gets status 'error' and a formatted
 * `error`. Finally the status is forced to 'paused' if the paused flag is set, RECEIVE_MESSAGE is
 * emitted, `onResponse` is called, and the generating flag is cleared.
 *
 * @returns The same (mutated) `message` object.
 */
export async function fetchChatCompletion({
  message,
  messages,
  assistant,
  onResponse
}: {
  message: Message
  messages: Message[]
  assistant: Assistant
  onResponse: (message: Message) => void
}) {
  window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, false)

  const provider = getAssistantProvider(assistant)
  const webSearchProvider = WebSearchService.getWebSearchProvider()
  const AI = new AiProvider(provider)

  store.dispatch(setGenerating(true))

  onResponse({ ...message })

  // Abort handler: marks the message as paused and notifies listeners / the caller.
  const pauseFn = (message: Message) => {
    message.status = 'paused'
    EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
    store.dispatch(setGenerating(false))
    onResponse({ ...message, status: 'paused' })
  }

  addAbortController(message.askId ?? message.id, pauseFn.bind(null, message))

  try {
    // Messages as filtered by the provider (via onFilterMessages); used for usage estimation.
    let _messages: Message[] = []
    let isFirstChunk = true

    // Search web
    if (WebSearchService.isWebSearchEnabled() && assistant.enableWebSearch && assistant.model) {
      const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)

      // Only run our own search when the model gets no provider-side web search params.
      if (isEmpty(webSearchParams)) {
        const lastMessage = findLast(messages, (m) => m.role === 'user')
        const hasKnowledgeBase = !isEmpty(lastMessage?.knowledgeBaseIds)

        if (lastMessage) {
          if (hasKnowledgeBase) {
            window.message.info({
              content: i18n.t('message.ignore.knowledge.base'),
              key: 'knowledge-base-no-match-info'
            })
          }

          onResponse({ ...message, status: 'searching' })

          const webSearch = await WebSearchService.search(webSearchProvider, lastMessage.content)

          message.metadata = {
            ...message.metadata,
            webSearch: webSearch
          }

          window.keyv.set(`web-search-${lastMessage.id}`, webSearch)
        }
      }
    }

    const allMCPTools = await window.api.mcp.listTools()

    await AI.completions({
      messages: filterUsefulMessages(messages),
      assistant,
      onFilterMessages: (messages) => (_messages = messages),
      onChunk: ({ text, reasoning_content, usage, metrics, search, citations, mcpToolResponse }) => {
        // Default each operand separately: `a + b || ''` would bind as `(a + b) || ''` and
        // turn an undefined chunk into the literal text "undefined".
        message.content = (message.content ?? '') + (text ?? '')
        message.usage = usage
        message.metrics = metrics

        if (reasoning_content) {
          message.reasoning_content = (message.reasoning_content || '') + reasoning_content
        }

        if (search) {
          message.metadata = { ...message.metadata, groundingMetadata: search }
        }

        if (mcpToolResponse) {
          message.metadata = { ...message.metadata, mcpTools: mcpToolResponse }
        }

        // Handle citations from Perplexity API: keep the first citations seen, ignore later ones.
        if (isFirstChunk && citations) {
          message.metadata = {
            ...message.metadata,
            citations
          }
          isFirstChunk = false
        }

        onResponse({ ...message, status: 'pending' })
      },
      mcpTools: allMCPTools
    })

    message.status = 'success'

    // Provider did not report completion tokens: estimate them locally.
    if (!message.usage?.completion_tokens) {
      message.usage = await estimateMessagesUsage({
        assistant,
        messages: [..._messages, message]
      })

      // Set metrics.completion_tokens
      if (message.metrics && message.usage?.completion_tokens && !message.metrics.completion_tokens) {
        message.metrics.completion_tokens = message.usage.completion_tokens
      }
    }
  } catch (error: any) {
    console.error('error', error)
    message.status = 'error'
    message.error = formatMessageError(error)
  }

  // Update message status
  message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : message.status

  // Emit chat completion event
  EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
  onResponse(message)

  // Reset generating state
  store.dispatch(setGenerating(false))

  return message
}
|
|
|
|
/** Arguments accepted by `fetchTranslate`. */
interface FetchTranslateProps {
  /** Assistant passed through to the provider's `translate` call. */
  assistant: Assistant
  /** Message passed to the provider's `translate` call. */
  message: Message
  /**
   * Optional callback forwarded to the provider's `translate` call.
   * NOTE(review): presumably receives partial translated text while streaming — confirm.
   */
  onResponse?: (text: string) => void
}
|
|
|
|
/**
 * Translates `message` with the configured translate model.
 *
 * @returns The provider's translation, or `''` when no translate model is configured, when its
 *   provider is not usable (see `hasApiKey`), or when the request fails (the failure is logged).
 */
export async function fetchTranslate({ message, assistant, onResponse }: FetchTranslateProps) {
  const model = getTranslateModel()

  if (!model) {
    return ''
  }

  const provider = getProviderByModel(model)

  if (!hasApiKey(provider)) {
    return ''
  }

  const AI = new AiProvider(provider)

  try {
    return await AI.translate(message, assistant, onResponse)
  } catch (error) {
    // Best-effort: callers treat '' as "no translation", but keep the cause visible for debugging.
    console.error('[fetchTranslate] translation failed', error)
    return ''
  }
}
|
|
|
|
/**
 * Asks a model to summarize `messages`.
 *
 * The model is chosen in this order: the topic-naming model, then the assistant's model, then the
 * default model.
 *
 * @returns The provider's summary, or `null` when the provider is not usable (see `hasApiKey`) or
 *   the request fails (the failure is logged).
 */
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
  const model = getTopNamingModel() || assistant.model || getDefaultModel()
  const provider = getProviderByModel(model)

  if (!hasApiKey(provider)) {
    return null
  }

  const AI = new AiProvider(provider)

  try {
    return await AI.summaries(filterMessages(messages), assistant)
  } catch (error) {
    // Best-effort: callers treat null as "no summary", but keep the cause visible for debugging.
    console.error('[fetchMessagesSummary] summary failed', error)
    return null
  }
}
|
|
|
|
/**
 * Generates text from `prompt` and `content` using the default model.
 *
 * @returns The generated text, or `''` when the provider is not usable (see `hasApiKey`) or the
 *   request fails (the failure is logged).
 */
export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
  const model = getDefaultModel()
  const provider = getProviderByModel(model)

  if (!hasApiKey(provider)) {
    return ''
  }

  const AI = new AiProvider(provider)

  try {
    return await AI.generateText({ prompt, content })
  } catch (error) {
    // Best-effort: callers treat '' as "nothing generated", but keep the cause visible for debugging.
    console.error('[fetchGenerate] generation failed', error)
    return ''
  }
}
|
|
|
|
/**
 * Fetches follow-up suggestions for a conversation.
 *
 * Suggestions are only requested when the assistant's model is owned by 'graphrag' and its id does
 * not end with 'global'. In every other case an empty list is returned.
 *
 * @returns The provider's suggestions, or `[]` when not applicable or when the request fails (the
 *   failure is logged).
 */
export async function fetchSuggestions({
  messages,
  assistant
}: {
  messages: Message[]
  assistant: Assistant
}): Promise<Suggestion[]> {
  const model = assistant.model

  if (!model || model.owned_by !== 'graphrag' || model.id.endsWith('global')) {
    return []
  }

  const provider = getAssistantProvider(assistant)
  const AI = new AiProvider(provider)

  try {
    return await AI.suggestions(filterMessages(messages), assistant)
  } catch (error) {
    // Best-effort: callers treat [] as "no suggestions", but keep the cause visible for debugging.
    console.error('[fetchSuggestions] suggestions failed', error)
    return []
  }
}
|
|
|
|
/**
 * Validates a provider's basic settings without making a request. The checks run in order:
 * 1. An API key is present (skipped for 'ollama' and 'lmstudio').
 * 2. An API host is set.
 * 3. The model list is non-empty.
 *
 * The first failing check shows an error toast (key 'api-check') and returns the matching Error.
 *
 * @returns `{ valid: true, error: null }` when every check passes, otherwise
 *   `{ valid: false, error }` for the first failure.
 */
export function checkApiProvider(provider: Provider): {
  valid: boolean
  error: Error | null
} {
  const toastKey = 'api-check'
  const toastStyle = { marginTop: '3vh' }

  // Shows the failure to the user and builds the matching invalid result.
  const reject = (text: string) => {
    window.message.error({ content: text, key: toastKey, style: toastStyle })
    return { valid: false, error: new Error(text) }
  }

  const isLocalProvider = provider.id === 'ollama' || provider.id === 'lmstudio'

  if (!isLocalProvider && !provider.apiKey) {
    return reject(i18n.t('message.error.enter.api.key'))
  }

  if (!provider.apiHost) {
    return reject(i18n.t('message.error.enter.api.host'))
  }

  if (isEmpty(provider.models)) {
    return reject(i18n.t('message.error.enter.model'))
  }

  return { valid: true, error: null }
}
|
|
|
|
/**
 * Checks whether `model` can be used through `provider`.
 *
 * Runs the local `checkApiProvider` validation first. Only if it passes does it ask the provider
 * itself to check the model.
 *
 * @returns `{ valid, error }` from the local validation failure, or from `AiProvider.check`.
 */
export async function checkApi(provider: Provider, model: Model) {
  const precheck = checkApiProvider(provider)

  if (precheck.valid) {
    const outcome = await new AiProvider(provider).check(model)
    return { valid: outcome.valid, error: outcome.error }
  }

  return { valid: precheck.valid, error: precheck.error }
}
|
|
|
|
/**
 * Reports whether `provider` is usable for a request.
 *
 * A falsy provider is never usable. 'ollama' and 'lmstudio' are always usable (no key required).
 * Any other provider needs a non-empty `apiKey`.
 */
function hasApiKey(provider: Provider) {
  if (provider) {
    const keyless = provider.id === 'ollama' || provider.id === 'lmstudio'
    return keyless || !isEmpty(provider.apiKey)
  }
  return false
}
|
|
|
|
/**
 * Lists the models available from `provider`.
 *
 * @returns The provider's model list, or `[]` when the request fails (the failure is logged).
 */
export async function fetchModels(provider: Provider) {
  const AI = new AiProvider(provider)

  try {
    return await AI.models()
  } catch (error) {
    // Best-effort: callers treat [] as "no models", but keep the cause visible for debugging.
    console.error('[fetchModels] listing models failed', error)
    return []
  }
}
|