feat: Improve message filtering across providers

- Add new `filterUserRoleStartMessages` function in MessagesService
- Update Anthropic, Gemini, and OpenAI providers to use new message filtering
- Refactor message handling to ensure user messages start the conversation
- Remove redundant message filtering logic from individual providers
This commit is contained in:
kangfenmao 2025-03-05 11:11:02 +08:00
parent 452b0df476
commit 0b9e9c4c13
4 changed files with 47 additions and 66 deletions

View File

@ -6,7 +6,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService'
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeSpecialCharacters } from '@renderer/utils'
import { first, flatten, sum, takeRight } from 'lodash'
@ -124,7 +124,7 @@ export default class AnthropicProvider extends BaseProvider {
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const userMessagesParams: MessageParam[] = []
const _messages = filterContextMessages(takeRight(messages, contextCount + 2))
const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
onFilterMessages(_messages)
@ -134,10 +134,6 @@ export default class AnthropicProvider extends BaseProvider {
const userMessages = flatten(userMessagesParams)
if (first(userMessages)?.role === 'assistant') {
userMessages.shift()
}
const body: MessageCreateParamsNonStreaming = {
model: model.id,
messages: userMessages,
@ -145,12 +141,9 @@ export default class AnthropicProvider extends BaseProvider {
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
system: assistant.prompt,
...this.getCustomParameters(assistant)
}
if (isReasoningModel(model)) {
// @ts-ignore thinking
body.thinking = this.getReasoningEffort(assistant, model)
thinking: this.getReasoningEffort(assistant, model),
...this.getCustomParameters(assistant)
}
let time_first_token_millsec = 0

View File

@ -15,11 +15,11 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService'
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeSpecialCharacters } from '@renderer/utils'
import axios from 'axios'
import { first, isEmpty, takeRight } from 'lodash'
import { isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai'
import { CompletionsParams } from '.'
@ -146,13 +146,9 @@ export default class GeminiProvider extends BaseProvider {
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const userMessages = filterContextMessages(takeRight(messages, contextCount + 2))
const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
onFilterMessages(userMessages)
if (first(userMessages)?.role === 'assistant') {
userMessages.shift()
}
const userLastMessage = userMessages.pop()
const history: Content[] = []

View File

@ -10,7 +10,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService'
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeSpecialCharacters } from '@renderer/utils'
import { takeRight } from 'lodash'
@ -213,18 +213,6 @@ export default class OpenAIProvider extends BaseProvider {
return model.id.startsWith('o1')
}
/**
 * Whether the given model rejects conversations that do not begin with a
 * user-role message, requiring us to force-prepend one.
 */
private isForceUserMessageStart(model: Model) {
  // Models / providers known to require a leading user message.
  const forcedModelIds = ['deepseek-reasoner']
  const forcedProviders = ['xirang']
  return forcedModelIds.includes(model.id) || forcedProviders.includes(model.provider)
}
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
@ -241,15 +229,9 @@ export default class OpenAIProvider extends BaseProvider {
const userMessages: ChatCompletionMessageParam[] = []
const _messages = filterContextMessages(takeRight(messages, contextCount + 1))
const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
onFilterMessages(_messages)
if (this.isForceUserMessageStart(model)) {
if (_messages[0]?.role !== 'user') {
userMessages.push({ role: 'user', content: '' })
}
}
for (const message of _messages) {
userMessages.push(await this.getMessageParam(message, model))
}

View File

@ -29,6 +29,44 @@ export function filterContextMessages(messages: Message[]): Message[] {
return messages.slice(clearIndex + 1)
}
/**
 * Drops any leading non-user messages so the conversation sent to a provider
 * always starts with a user turn.
 *
 * @param messages ordered conversation messages
 * @returns the slice starting at the first user-role message, or the input
 *          array unchanged when it contains no user message
 */
export function filterUserRoleStartMessages(messages: Message[]): Message[] {
  for (let i = 0; i < messages.length; i++) {
    if (messages[i].role === 'user') {
      return messages.slice(i)
    }
  }
  // No user message at all — leave the list as-is.
  return messages
}
/**
 * For each assistant message group, keeps only the message flagged
 * `useful === true` (or the group's last message when none is flagged),
 * then strips trailing assistant messages so the context ends on a user turn.
 *
 * @param messages ordered conversation messages
 * @returns a new filtered array; the caller's array is NOT mutated
 */
export function filterUsefulMessages(messages: Message[]): Message[] {
  // Shallow copy: lodash `remove` and `pop` below mutate their target, and
  // the original code aliased the parameter, silently mutating the caller's
  // array as a side effect.
  const _messages = [...messages]
  const groupedMessages = getGroupedMessages(messages)

  // `groupMessages` renamed from `messages`, which shadowed the parameter.
  Object.entries(groupedMessages).forEach(([key, groupMessages]) => {
    if (key.startsWith('assistant')) {
      const usefulMessage = groupMessages.find((m) => m.useful === true)
      if (usefulMessage) {
        // Keep only the flagged message; drop its siblings in this group.
        groupMessages.forEach((m) => {
          if (m.id !== usefulMessage.id) {
            remove(_messages, (o) => o.id === m.id)
          }
        })
      } else {
        // No flagged message: keep only the group's last message.
        groupMessages.slice(0, -1).forEach((m) => {
          remove(_messages, (o) => o.id === m.id)
        })
      }
    }
  })

  // Trim trailing assistant messages so the context ends with a user turn.
  while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') {
    _messages.pop()
  }

  return _messages
}
export function getContextCount(assistant: Assistant, messages: Message[]) {
const contextCount = assistant?.settings?.contextCount ?? DEFAULT_CONTEXTCOUNT
const _messages = takeRight(messages, contextCount)
@ -111,34 +149,6 @@ export function getAssistantMessage({ assistant, topic }: { assistant: Assistant
}
}
/**
 * For each assistant message group, keeps only the message flagged
 * `useful === true` (or the group's last message when none is flagged),
 * then strips trailing assistant messages so the context ends on a user turn.
 *
 * NOTE(review): `_messages` aliases the `messages` parameter, and lodash
 * `remove` plus `pop` mutate it — so this function mutates the caller's
 * array as a side effect. Callers must not rely on the input staying intact.
 */
export function filterUsefulMessages(messages: Message[]): Message[] {
const _messages = messages
const groupedMessages = getGroupedMessages(messages)
// NOTE(review): the callback parameter `messages` shadows the outer
// parameter here — inside this block it refers to one group's messages.
Object.entries(groupedMessages).forEach(([key, messages]) => {
if (key.startsWith('assistant')) {
const usefulMessage = messages.find((m) => m.useful === true)
if (usefulMessage) {
// A message is flagged useful: drop every sibling in the group.
messages.forEach((m) => {
if (m.id !== usefulMessage.id) {
remove(_messages, (o) => o.id === m.id)
}
})
} else {
// No flagged message: keep only the group's last message.
messages?.slice(0, -1).forEach((m) => {
remove(_messages, (o) => o.id === m.id)
})
}
}
})
// Trim trailing assistant messages so the context ends with a user turn.
while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') {
_messages.pop()
}
return _messages
}
export function getGroupedMessages(messages: Message[]): { [key: string]: (Message & { index: number })[] } {
const groups: { [key: string]: (Message & { index: number })[] } = {}
messages.forEach((message, index) => {