cherry-studio/src/renderer/src/services/ModelService.ts
one 1978cfc356
feat: Add health check to check all the models at one time (#2613)
* feat: Add health check to check all the models at one time

* fix: add model avatars to the health-check list

* style: Use segmented instead of switch

* fix: remove redundant timing reports

* refactor: Extract small functions

* refactor: use more hooks to make the main component clearer

* fix: mask API keys with asterisks

* refactor: split health check popup and model list

- rename ModelHealthCheckPopup to HealthCheckPopup
- add HealthCheckModelList
- add maskApiKey to utils

* refactor: compute latency in checkApi

* fix: remove unused i18n keys

* refactor: use checkModel instead of checkApi for better semantics

* fix: update comments

* refactor: extract health checking functions to services

* refactor: extract model list

* refactor: render statuses on the existing model list

* fix: reset button style on completion

* fix: disable model card while checking

- remove unused i18n keys
- better window message

* refactor: show provider name in messages

* refactor: change default values

* refactor: fully migrate model list from ProviderSetting to ModelList
2025-03-08 22:24:56 +08:00

90 lines
2.4 KiB
TypeScript

import { isEmbeddingModel } from '@renderer/config/models'
import AiProvider from '@renderer/providers/AiProvider'
import store from '@renderer/store'
import { Model, Provider } from '@renderer/types'
import { t } from 'i18next'
import { pick } from 'lodash'
import { checkApiProvider } from './ApiService'
/**
 * Builds a string key for a model from its `id` and `provider` fields.
 *
 * @param m - The model to derive the key from; may be omitted.
 * @returns The JSON serialization of the model's `id` and `provider` fields,
 * or an empty string when `m` is missing or its `id` is falsy.
 */
export const getModelUniqId = (m?: Model) => {
  // No model, or a model without a usable id, has no key.
  if (!m?.id) {
    return ''
  }
  const identity = pick(m, ['id', 'provider'])
  return JSON.stringify(identity)
}
/**
 * Looks up a model by id among the models of every enabled provider in the store.
 *
 * Matching is done on `id` only; the `provider` field of `m` is not compared.
 *
 * @param m - The model whose `id` is searched for; may be omitted.
 * @returns The first model of an enabled provider whose `id` equals `m?.id`,
 * or `undefined` when no such model exists.
 */
export const hasModel = (m?: Model) => {
  const { providers } = store.getState().llm
  const enabledModels = providers
    .filter((provider) => provider.enabled)
    .flatMap((provider) => provider.models)
  const wantedId = m?.id
  return enabledModels.find((candidate) => candidate.id === wantedId)
}
/**
 * Produces a display name for a model, suffixed with the name of its provider
 * when that provider is found in the store.
 *
 * @param model - The model to name; may be omitted.
 * @returns `"<model name> | <provider name>"` when the provider is found,
 * otherwise just the model name. The model name falls back from `name` to
 * `id` to an empty string.
 */
export function getModelName(model?: Model) {
  const owner = store.getState().llm.providers.find((p) => p.id === model?.provider)
  const modelName = model?.name || model?.id || ''

  // Without a matching provider there is nothing to append.
  if (!owner) {
    return modelName
  }

  // System providers are labelled through i18n; other providers use their own name.
  let providerName
  if (owner.isSystem) {
    providerName = t(`provider.${owner.id}`)
  } else {
    providerName = owner.name
  }
  return `${modelName} | ${providerName}`
}
/**
 * Converts an arbitrary thrown value into an `Error`.
 *
 * `Error` instances are returned unchanged. Strings become the message of a
 * new `Error`. Objects carrying a string `message` property reuse that
 * message. Anything else is serialized so the message stays informative.
 */
function toCheckError(reason: unknown): Error {
  if (reason instanceof Error) {
    return reason
  }
  if (typeof reason === 'string') {
    return new Error(reason)
  }

  const message = (reason as { message?: unknown } | null | undefined)?.message
  if (typeof message === 'string') {
    return new Error(message)
  }

  let description: string
  try {
    // JSON.stringify yields undefined for values such as `undefined` or functions.
    description = JSON.stringify(reason) ?? String(reason)
  } catch {
    // Circular structures and BigInt values cannot be serialized as JSON.
    description = String(reason)
  }
  return new Error(description)
}

/**
 * Generic function to perform model checks.
 *
 * Abstracts provider validation and error handling, allowing different types
 * of check logic to be plugged in.
 *
 * @param provider - Provider passed to `checkApiProvider` and used to build the `AiProvider`.
 * @param model - Model handed to `checkFn`.
 * @param checkFn - Performs the actual check; only the time spent awaiting it is measured.
 * @param processResult - Maps the raw result of `checkFn` to a `{ valid, error }` pair.
 * @returns The provider validation outcome when validation fails (no latency).
 * Otherwise the processed result plus `latency` as the difference of two
 * `performance.now()` readings. If `checkFn` or `processResult` throws, the
 * result is `{ valid: false, error }` with the thrown value normalized to an `Error`.
 */
async function performModelCheck<T>(
  provider: Provider,
  model: Model,
  checkFn: (ai: AiProvider, model: Model) => Promise<T>,
  processResult: (result: T) => { valid: boolean; error: Error | null }
): Promise<{ valid: boolean; error: Error | null; latency?: number }> {
  const validation = checkApiProvider(provider)
  if (!validation.valid) {
    return {
      valid: validation.valid,
      error: validation.error
    }
  }

  const AI = new AiProvider(provider)

  try {
    const startTime = performance.now()
    const result = await checkFn(AI, model)
    const latency = performance.now() - startTime

    return {
      ...processResult(result),
      latency
    }
  } catch (error: unknown) {
    // Thrown values are not guaranteed to be Error instances; normalize them
    // so the declared `Error | null` contract holds for callers.
    return {
      valid: false,
      error: toCheckError(error)
    }
  }
}
/**
 * Unified model check function.
 *
 * Selects the check method from the model type: models for which
 * `isEmbeddingModel` returns true are checked through
 * `getEmbeddingDimensions` and count as valid when the reported dimensions
 * are greater than zero; all other models are checked through `check`.
 *
 * @param provider - Provider the model is checked against.
 * @param model - Model to check.
 * @returns The result of `performModelCheck` for the selected check method.
 */
export async function checkModel(provider: Provider, model: Model) {
  if (!isEmbeddingModel(model)) {
    return performModelCheck(
      provider,
      model,
      (ai, target) => ai.check(target),
      // Normalize a missing error to null.
      ({ valid, error }) => ({ valid, error: error || null })
    )
  }

  return performModelCheck(
    provider,
    model,
    (ai, target) => ai.getEmbeddingDimensions(target),
    (dimensions) => ({ valid: dimensions > 0, error: null })
  )
}