mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-20 23:22:05 +08:00
fix: agent supported model filter (#10788)
* Revert "fix: make anthropic model provided by cherryin visible to agent (#10695)"
This reverts commit 7b3b73d390.
* fix: agent supported model filter
This commit is contained in:
parent
ab3083f943
commit
131444ac52
@ -2,7 +2,12 @@ import { isEmpty } from 'lodash'
|
|||||||
|
|
||||||
import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
||||||
import { loggerService } from '../../services/LoggerService'
|
import { loggerService } from '../../services/LoggerService'
|
||||||
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
|
import {
|
||||||
|
getAvailableProviders,
|
||||||
|
getProviderAnthropicModelChecker,
|
||||||
|
listAllAvailableModels,
|
||||||
|
transformModelToOpenAI
|
||||||
|
} from '../utils'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ModelsService')
|
const logger = loggerService.withContext('ModelsService')
|
||||||
|
|
||||||
@ -10,10 +15,6 @@ const logger = loggerService.withContext('ModelsService')
|
|||||||
|
|
||||||
export type ModelsFilter = ApiModelsFilter
|
export type ModelsFilter = ApiModelsFilter
|
||||||
|
|
||||||
const isAnthropicProvider = (provider: { type: string; anthropicApiHost?: string }) => {
|
|
||||||
return provider.type === 'anthropic' || !isEmpty(provider.anthropicApiHost?.trim())
|
|
||||||
}
|
|
||||||
|
|
||||||
export class ModelsService {
|
export class ModelsService {
|
||||||
async getModels(filter: ModelsFilter): Promise<ApiModelsResponse> {
|
async getModels(filter: ModelsFilter): Promise<ApiModelsResponse> {
|
||||||
try {
|
try {
|
||||||
@ -22,7 +23,7 @@ export class ModelsService {
|
|||||||
let providers = await getAvailableProviders()
|
let providers = await getAvailableProviders()
|
||||||
|
|
||||||
if (filter.providerType === 'anthropic') {
|
if (filter.providerType === 'anthropic') {
|
||||||
providers = providers.filter(isAnthropicProvider)
|
providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
|
||||||
}
|
}
|
||||||
|
|
||||||
const models = await listAllAvailableModels(providers)
|
const models = await listAllAvailableModels(providers)
|
||||||
@ -31,23 +32,19 @@ export class ModelsService {
|
|||||||
|
|
||||||
for (const model of models) {
|
for (const model of models) {
|
||||||
const provider = providers.find((p) => p.id === model.provider)
|
const provider = providers.find((p) => p.id === model.provider)
|
||||||
logger.debug(`Processing model ${model.id} from provider ${model.provider}`, {
|
logger.debug(`Processing model ${model.id}`)
|
||||||
isAnthropicModel: provider?.isAnthropicModel
|
if (!provider) {
|
||||||
})
|
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
|
||||||
if (
|
|
||||||
!provider ||
|
|
||||||
(filter.providerType === 'anthropic' && provider.isAnthropicModel && !provider.isAnthropicModel(model))
|
|
||||||
) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Special case: For "aihubmix", it should be covered by above condition, but just in case
|
|
||||||
if (provider.id === 'aihubmix' && filter.providerType === 'anthropic' && !model.id.includes('claude')) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filter.supportAnthropic && model.endpoint_type !== 'anthropic' && !isAnthropicProvider(provider)) {
|
if (filter.providerType === 'anthropic') {
|
||||||
|
const checker = getProviderAnthropicModelChecker(provider.id)
|
||||||
|
if (!checker(model)) {
|
||||||
|
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const openAIModel = transformModelToOpenAI(model, provider)
|
const openAIModel = transformModelToOpenAI(model, provider)
|
||||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import { CacheService } from '@main/services/CacheService'
|
import { CacheService } from '@main/services/CacheService'
|
||||||
import { loggerService } from '@main/services/LoggerService'
|
import { loggerService } from '@main/services/LoggerService'
|
||||||
import { reduxService } from '@main/services/ReduxService'
|
import { reduxService } from '@main/services/ReduxService'
|
||||||
import { ApiModel, EndpointType, Model, Provider } from '@types'
|
import { ApiModel, Model, Provider } from '@types'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerUtils')
|
const logger = loggerService.withContext('ApiServerUtils')
|
||||||
|
|
||||||
@ -114,7 +114,6 @@ export async function validateModelId(model: string): Promise<{
|
|||||||
error?: ModelValidationError
|
error?: ModelValidationError
|
||||||
provider?: Provider
|
provider?: Provider
|
||||||
modelId?: string
|
modelId?: string
|
||||||
modelEndpointType?: EndpointType
|
|
||||||
}> {
|
}> {
|
||||||
try {
|
try {
|
||||||
if (!model || typeof model !== 'string') {
|
if (!model || typeof model !== 'string') {
|
||||||
@ -167,8 +166,7 @@ export async function validateModelId(model: string): Promise<{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if model exists in provider
|
// Check if model exists in provider
|
||||||
const modelInProvider = provider.models?.find((m) => m.id === modelId)
|
const modelExists = provider.models?.some((m) => m.id === modelId)
|
||||||
const modelExists = !!modelInProvider
|
|
||||||
if (!modelExists) {
|
if (!modelExists) {
|
||||||
const availableModels = provider.models?.map((m) => m.id).join(', ') || 'none'
|
const availableModels = provider.models?.map((m) => m.id).join(', ') || 'none'
|
||||||
return {
|
return {
|
||||||
@ -181,13 +179,10 @@ export async function validateModelId(model: string): Promise<{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelEndpointType = modelInProvider?.endpoint_type
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
valid: true,
|
valid: true,
|
||||||
provider,
|
provider,
|
||||||
modelId,
|
modelId
|
||||||
modelEndpointType
|
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error validating model ID', { error, model })
|
logger.error('Error validating model ID', { error, model })
|
||||||
@ -284,3 +279,16 @@ export function validateProvider(provider: Provider): boolean {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model) => boolean) => {
|
||||||
|
switch (providerId) {
|
||||||
|
case 'cherryin':
|
||||||
|
case 'new-api':
|
||||||
|
return (m: Model) => m.endpoint_type === 'anthropic'
|
||||||
|
case 'aihubmix':
|
||||||
|
return (m: Model) => m.id.includes('claude')
|
||||||
|
default:
|
||||||
|
// allow all models when checker not configured
|
||||||
|
return () => true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import { config as apiConfigService } from '@main/apiServer/config'
|
|||||||
import { validateModelId } from '@main/apiServer/utils'
|
import { validateModelId } from '@main/apiServer/utils'
|
||||||
import getLoginShellEnvironment from '@main/utils/shell-env'
|
import getLoginShellEnvironment from '@main/utils/shell-env'
|
||||||
import { app } from 'electron'
|
import { app } from 'electron'
|
||||||
import { isEmpty } from 'lodash'
|
|
||||||
|
|
||||||
import { GetAgentSessionResponse } from '../..'
|
import { GetAgentSessionResponse } from '../..'
|
||||||
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
||||||
@ -61,20 +60,11 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
})
|
})
|
||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
|
if (
|
||||||
const validateModelInfo: (m: typeof modelInfo) => boolean = (m) => {
|
(modelInfo.provider?.type !== 'anthropic' &&
|
||||||
const { provider, modelEndpointType } = m
|
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
|
||||||
if (!provider) return false
|
modelInfo.provider.apiKey === ''
|
||||||
if (isEmpty(provider.apiKey?.trim())) return false
|
) {
|
||||||
|
|
||||||
const isAnthropicType = provider.type === 'anthropic'
|
|
||||||
const isAnthropicEndpoint = modelEndpointType === 'anthropic'
|
|
||||||
const hasValidApiHost = !isEmpty(provider.anthropicApiHost?.trim())
|
|
||||||
|
|
||||||
return !(!isAnthropicType && !isAnthropicEndpoint && !hasValidApiHost)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!modelInfo.provider || !validateModelInfo(modelInfo)) {
|
|
||||||
logger.error('Anthropic provider configuration is missing', {
|
logger.error('Anthropic provider configuration is missing', {
|
||||||
modelInfo
|
modelInfo
|
||||||
})
|
})
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import {
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { Selection } from '@react-types/shared'
|
import type { Selection } from '@react-types/shared'
|
||||||
import ClaudeIcon from '@renderer/assets/images/models/claude.png'
|
import ClaudeIcon from '@renderer/assets/images/models/claude.png'
|
||||||
import { getModelLogo } from '@renderer/config/models'
|
import { agentModelFilter, getModelLogo } from '@renderer/config/models'
|
||||||
import { permissionModeCards } from '@renderer/constants/permissionModes'
|
import { permissionModeCards } from '@renderer/constants/permissionModes'
|
||||||
import { useAgents } from '@renderer/hooks/agents/useAgents'
|
import { useAgents } from '@renderer/hooks/agents/useAgents'
|
||||||
import { useApiModels } from '@renderer/hooks/agents/useModels'
|
import { useApiModels } from '@renderer/hooks/agents/useModels'
|
||||||
@ -100,7 +100,7 @@ export const AgentModal: React.FC<Props> = ({ agent, trigger, isOpen: _isOpen, o
|
|||||||
const { addAgent } = useAgents()
|
const { addAgent } = useAgents()
|
||||||
const { updateAgent } = useUpdateAgent()
|
const { updateAgent } = useUpdateAgent()
|
||||||
// hard-coded. We only support anthropic for now.
|
// hard-coded. We only support anthropic for now.
|
||||||
const { models } = useApiModels({ supportAnthropic: true })
|
const { models } = useApiModels({ providerType: 'anthropic' })
|
||||||
const isEditing = (agent?: AgentWithTools) => agent !== undefined
|
const isEditing = (agent?: AgentWithTools) => agent !== undefined
|
||||||
|
|
||||||
const [form, setForm] = useState<BaseAgentForm>(() => buildAgentForm(agent))
|
const [form, setForm] = useState<BaseAgentForm>(() => buildAgentForm(agent))
|
||||||
@ -245,7 +245,16 @@ export const AgentModal: React.FC<Props> = ({ agent, trigger, isOpen: _isOpen, o
|
|||||||
|
|
||||||
const modelOptions = useMemo(() => {
|
const modelOptions = useMemo(() => {
|
||||||
// mocked data. not final version
|
// mocked data. not final version
|
||||||
return (models ?? []).map((model) => ({
|
return (models ?? [])
|
||||||
|
.filter((m) =>
|
||||||
|
agentModelFilter({
|
||||||
|
id: m.id,
|
||||||
|
provider: m.provider || '',
|
||||||
|
name: m.name,
|
||||||
|
group: ''
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.map((model) => ({
|
||||||
type: 'model',
|
type: 'model',
|
||||||
key: model.id,
|
key: model.id,
|
||||||
label: model.name,
|
label: model.name,
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
|
||||||
import { Model } from '@renderer/types'
|
import { Model } from '@renderer/types'
|
||||||
import { getLowerBaseModelName } from '@renderer/utils'
|
import { getLowerBaseModelName } from '@renderer/utils'
|
||||||
import OpenAI from 'openai'
|
import OpenAI from 'openai'
|
||||||
@ -5,7 +6,7 @@ import OpenAI from 'openai'
|
|||||||
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '../prompts'
|
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '../prompts'
|
||||||
import { getWebSearchTools } from '../tools'
|
import { getWebSearchTools } from '../tools'
|
||||||
import { isOpenAIReasoningModel } from './reasoning'
|
import { isOpenAIReasoningModel } from './reasoning'
|
||||||
import { isGenerateImageModel, isVisionModel } from './vision'
|
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
|
||||||
import { isOpenAIWebSearchChatCompletionOnlyModel } from './websearch'
|
import { isOpenAIWebSearchChatCompletionOnlyModel } from './websearch'
|
||||||
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
|
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
|
||||||
|
|
||||||
@ -246,3 +247,7 @@ export const isOpenAIOpenWeightModel = (model: Model) => {
|
|||||||
|
|
||||||
// zhipu 视觉推理模型用这组 special token 标记推理结果
|
// zhipu 视觉推理模型用这组 special token 标记推理结果
|
||||||
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
|
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
|
||||||
|
|
||||||
|
export const agentModelFilter = (model: Model): boolean => {
|
||||||
|
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
|
||||||
|
}
|
||||||
|
|||||||
@ -58,7 +58,6 @@ import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
|
|||||||
import {
|
import {
|
||||||
AtLeast,
|
AtLeast,
|
||||||
isSystemProvider,
|
isSystemProvider,
|
||||||
Model,
|
|
||||||
OpenAIServiceTiers,
|
OpenAIServiceTiers,
|
||||||
Provider,
|
Provider,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
@ -88,6 +87,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
|
|||||||
type: 'openai',
|
type: 'openai',
|
||||||
apiKey: '',
|
apiKey: '',
|
||||||
apiHost: 'https://open.cherryin.net',
|
apiHost: 'https://open.cherryin.net',
|
||||||
|
anthropicApiHost: 'https://open.cherryin.net',
|
||||||
models: [],
|
models: [],
|
||||||
isSystem: true,
|
isSystem: true,
|
||||||
enabled: true
|
enabled: true
|
||||||
@ -109,7 +109,6 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
|
|||||||
apiKey: '',
|
apiKey: '',
|
||||||
apiHost: 'https://aihubmix.com',
|
apiHost: 'https://aihubmix.com',
|
||||||
anthropicApiHost: 'https://aihubmix.com/anthropic',
|
anthropicApiHost: 'https://aihubmix.com/anthropic',
|
||||||
isAnthropicModel: (m: Model) => m.id.includes('claude'),
|
|
||||||
models: SYSTEM_MODELS.aihubmix,
|
models: SYSTEM_MODELS.aihubmix,
|
||||||
isSystem: true,
|
isSystem: true,
|
||||||
enabled: false
|
enabled: false
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
import { Button } from '@heroui/react'
|
import { Button } from '@heroui/react'
|
||||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||||
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||||
import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models'
|
import { agentModelFilter } from '@renderer/config/models'
|
||||||
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
||||||
import { getProviderNameById } from '@renderer/services/ProviderService'
|
import { getProviderNameById } from '@renderer/services/ProviderService'
|
||||||
import { AgentBaseWithId, ApiModel, isAgentEntity, Model } from '@renderer/types'
|
import { AgentBaseWithId, ApiModel, isAgentEntity } from '@renderer/types'
|
||||||
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
||||||
import { apiModelAdapter } from '@renderer/utils/model'
|
import { apiModelAdapter } from '@renderer/utils/model'
|
||||||
import { ChevronsUpDown } from 'lucide-react'
|
import { ChevronsUpDown } from 'lucide-react'
|
||||||
@ -22,12 +22,11 @@ const SelectAgentBaseModelButton: FC<Props> = ({ agentBase: agent, onSelect, isD
|
|||||||
const model = useApiModel({ id: agent?.model })
|
const model = useApiModel({ id: agent?.model })
|
||||||
|
|
||||||
const apiFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined
|
const apiFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined
|
||||||
const modelFilter = (model: Model) => !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
|
|
||||||
|
|
||||||
if (!agent) return null
|
if (!agent) return null
|
||||||
|
|
||||||
const onSelectModel = async () => {
|
const onSelectModel = async () => {
|
||||||
const selectedModel = await SelectApiModelPopup.show({ model, apiFilter: apiFilter, modelFilter })
|
const selectedModel = await SelectApiModelPopup.show({ model, apiFilter: apiFilter, modelFilter: agentModelFilter })
|
||||||
if (selectedModel && selectedModel.id !== agent.model) {
|
if (selectedModel && selectedModel.id !== agent.model) {
|
||||||
onSelect(selectedModel)
|
onSelect(selectedModel)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -65,7 +65,7 @@ const persistedReducer = persistReducer(
|
|||||||
{
|
{
|
||||||
key: 'cherry-studio',
|
key: 'cherry-studio',
|
||||||
storage,
|
storage,
|
||||||
version: 162,
|
version: 163,
|
||||||
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'],
|
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'],
|
||||||
migrate
|
migrate
|
||||||
},
|
},
|
||||||
|
|||||||
@ -2671,6 +2671,11 @@ const migrateConfig = {
|
|||||||
'163': (state: RootState) => {
|
'163': (state: RootState) => {
|
||||||
try {
|
try {
|
||||||
addOcrProvider(state, BUILTIN_OCR_PROVIDERS_MAP.ovocr)
|
addOcrProvider(state, BUILTIN_OCR_PROVIDERS_MAP.ovocr)
|
||||||
|
state.llm.providers.forEach((provider) => {
|
||||||
|
if (provider.id === 'cherryin') {
|
||||||
|
provider.anthropicApiHost = 'https://open.cherryin.net'
|
||||||
|
}
|
||||||
|
})
|
||||||
return state
|
return state
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('migrate 163 error', error as Error)
|
logger.error('migrate 163 error', error as Error)
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import { ProviderTypeSchema } from './provider'
|
|||||||
// Request schema for /v1/models
|
// Request schema for /v1/models
|
||||||
export const ApiModelsFilterSchema = z.object({
|
export const ApiModelsFilterSchema = z.object({
|
||||||
providerType: ProviderTypeSchema.optional(),
|
providerType: ProviderTypeSchema.optional(),
|
||||||
supportAnthropic: z.coerce.boolean().optional(),
|
|
||||||
offset: z.coerce.number().min(0).default(0).optional(),
|
offset: z.coerce.number().min(0).default(0).optional(),
|
||||||
limit: z.coerce.number().min(1).default(20).optional()
|
limit: z.coerce.number().min(1).default(20).optional()
|
||||||
})
|
})
|
||||||
|
|||||||
@ -18,7 +18,7 @@ export const getModelFilterByAgentType = (type: AgentType): ApiModelsFilter => {
|
|||||||
switch (type) {
|
switch (type) {
|
||||||
case 'claude-code':
|
case 'claude-code':
|
||||||
return {
|
return {
|
||||||
supportAnthropic: true
|
providerType: 'anthropic'
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user