mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-25 03:10:08 +08:00
feat: Improve message filtering across providers
- Add new `filterUserRoleStartMessages` function in MessagesService - Update Anthropic, Gemini, and OpenAI providers to use new message filtering - Refactor message handling to ensure user messages start the conversation - Remove redundant message filtering logic from individual providers
This commit is contained in:
parent
452b0df476
commit
0b9e9c4c13
@ -6,7 +6,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
|
||||
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { removeSpecialCharacters } from '@renderer/utils'
|
||||
import { first, flatten, sum, takeRight } from 'lodash'
|
||||
@ -124,7 +124,7 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessagesParams: MessageParam[] = []
|
||||
const _messages = filterContextMessages(takeRight(messages, contextCount + 2))
|
||||
const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||
|
||||
onFilterMessages(_messages)
|
||||
|
||||
@ -134,10 +134,6 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
|
||||
const userMessages = flatten(userMessagesParams)
|
||||
|
||||
if (first(userMessages)?.role === 'assistant') {
|
||||
userMessages.shift()
|
||||
}
|
||||
|
||||
const body: MessageCreateParamsNonStreaming = {
|
||||
model: model.id,
|
||||
messages: userMessages,
|
||||
@ -145,12 +141,9 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
system: assistant.prompt,
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
if (isReasoningModel(model)) {
|
||||
// @ts-ignore thinking
|
||||
body.thinking = this.getReasoningEffort(assistant, model)
|
||||
thinking: this.getReasoningEffort(assistant, model),
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
let time_first_token_millsec = 0
|
||||
|
||||
@ -15,11 +15,11 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
|
||||
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { removeSpecialCharacters } from '@renderer/utils'
|
||||
import axios from 'axios'
|
||||
import { first, isEmpty, takeRight } from 'lodash'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CompletionsParams } from '.'
|
||||
@ -146,13 +146,9 @@ export default class GeminiProvider extends BaseProvider {
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessages = filterContextMessages(takeRight(messages, contextCount + 2))
|
||||
const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||
onFilterMessages(userMessages)
|
||||
|
||||
if (first(userMessages)?.role === 'assistant') {
|
||||
userMessages.shift()
|
||||
}
|
||||
|
||||
const userLastMessage = userMessages.pop()
|
||||
|
||||
const history: Content[] = []
|
||||
|
||||
@ -10,7 +10,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
|
||||
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { removeSpecialCharacters } from '@renderer/utils'
|
||||
import { takeRight } from 'lodash'
|
||||
@ -213,18 +213,6 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
return model.id.startsWith('o1')
|
||||
}
|
||||
|
||||
private isForceUserMessageStart(model: Model) {
|
||||
if (model.id === 'deepseek-reasoner') {
|
||||
return true
|
||||
}
|
||||
|
||||
if (model.provider === 'xirang') {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
@ -241,15 +229,9 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
|
||||
const userMessages: ChatCompletionMessageParam[] = []
|
||||
|
||||
const _messages = filterContextMessages(takeRight(messages, contextCount + 1))
|
||||
const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
|
||||
onFilterMessages(_messages)
|
||||
|
||||
if (this.isForceUserMessageStart(model)) {
|
||||
if (_messages[0]?.role !== 'user') {
|
||||
userMessages.push({ role: 'user', content: '' })
|
||||
}
|
||||
}
|
||||
|
||||
for (const message of _messages) {
|
||||
userMessages.push(await this.getMessageParam(message, model))
|
||||
}
|
||||
|
||||
@ -29,6 +29,44 @@ export function filterContextMessages(messages: Message[]): Message[] {
|
||||
return messages.slice(clearIndex + 1)
|
||||
}
|
||||
|
||||
export function filterUserRoleStartMessages(messages: Message[]): Message[] {
|
||||
const firstUserMessageIndex = messages.findIndex((message) => message.role === 'user')
|
||||
|
||||
if (firstUserMessageIndex === -1) {
|
||||
return messages
|
||||
}
|
||||
|
||||
return messages.slice(firstUserMessageIndex)
|
||||
}
|
||||
|
||||
export function filterUsefulMessages(messages: Message[]): Message[] {
|
||||
const _messages = messages
|
||||
const groupedMessages = getGroupedMessages(messages)
|
||||
|
||||
Object.entries(groupedMessages).forEach(([key, messages]) => {
|
||||
if (key.startsWith('assistant')) {
|
||||
const usefulMessage = messages.find((m) => m.useful === true)
|
||||
if (usefulMessage) {
|
||||
messages.forEach((m) => {
|
||||
if (m.id !== usefulMessage.id) {
|
||||
remove(_messages, (o) => o.id === m.id)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
messages?.slice(0, -1).forEach((m) => {
|
||||
remove(_messages, (o) => o.id === m.id)
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') {
|
||||
_messages.pop()
|
||||
}
|
||||
|
||||
return _messages
|
||||
}
|
||||
|
||||
export function getContextCount(assistant: Assistant, messages: Message[]) {
|
||||
const contextCount = assistant?.settings?.contextCount ?? DEFAULT_CONTEXTCOUNT
|
||||
const _messages = takeRight(messages, contextCount)
|
||||
@ -111,34 +149,6 @@ export function getAssistantMessage({ assistant, topic }: { assistant: Assistant
|
||||
}
|
||||
}
|
||||
|
||||
export function filterUsefulMessages(messages: Message[]): Message[] {
|
||||
const _messages = messages
|
||||
const groupedMessages = getGroupedMessages(messages)
|
||||
|
||||
Object.entries(groupedMessages).forEach(([key, messages]) => {
|
||||
if (key.startsWith('assistant')) {
|
||||
const usefulMessage = messages.find((m) => m.useful === true)
|
||||
if (usefulMessage) {
|
||||
messages.forEach((m) => {
|
||||
if (m.id !== usefulMessage.id) {
|
||||
remove(_messages, (o) => o.id === m.id)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
messages?.slice(0, -1).forEach((m) => {
|
||||
remove(_messages, (o) => o.id === m.id)
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') {
|
||||
_messages.pop()
|
||||
}
|
||||
|
||||
return _messages
|
||||
}
|
||||
|
||||
export function getGroupedMessages(messages: Message[]): { [key: string]: (Message & { index: number })[] } {
|
||||
const groups: { [key: string]: (Message & { index: number })[] } = {}
|
||||
messages.forEach((message, index) => {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user