From 0b9e9c4c136e974f0f4165248e137a34cc35536d Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Wed, 5 Mar 2025 11:11:02 +0800 Subject: [PATCH] feat: Improve message filtering across providers - Add new `filterUserRoleStartMessages` function in MessagesService - Update Anthropic, Gemini, and OpenAI providers to use new message filtering - Refactor message handling to ensure user messages start the conversation - Remove redundant message filtering logic from individual providers --- .../src/providers/AnthropicProvider.ts | 15 ++--- src/renderer/src/providers/GeminiProvider.ts | 10 +-- src/renderer/src/providers/OpenAIProvider.ts | 22 +------ src/renderer/src/services/MessagesService.ts | 66 +++++++++++-------- 4 files changed, 47 insertions(+), 66 deletions(-) diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index 179a955acf..f2962c6d8b 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -6,7 +6,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' -import { filterContextMessages } from '@renderer/services/MessagesService' +import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService' import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import { first, flatten, sum, takeRight } from 'lodash' @@ -124,7 +124,7 @@ export default class AnthropicProvider extends BaseProvider { const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) const userMessagesParams: MessageParam[] = [] - const _messages = filterContextMessages(takeRight(messages, contextCount + 2)) + const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2))) onFilterMessages(_messages) @@ -134,10 +134,6 @@ export default class AnthropicProvider extends BaseProvider { const userMessages = flatten(userMessagesParams) - if (first(userMessages)?.role === 'assistant') { - userMessages.shift() - } - const body: MessageCreateParamsNonStreaming = { model: model.id, messages: userMessages, @@ -145,12 +141,9 @@ export default class AnthropicProvider extends BaseProvider { temperature: this.getTemperature(assistant, model), top_p: this.getTopP(assistant, model), system: assistant.prompt, - ...this.getCustomParameters(assistant) - } - - if (isReasoningModel(model)) { // @ts-ignore thinking - body.thinking = this.getReasoningEffort(assistant, model) + thinking: this.getReasoningEffort(assistant, model), + ...this.getCustomParameters(assistant) } let time_first_token_millsec = 0 diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index d712c3d978..be3a0304f6 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -15,11 +15,11 @@ import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' -import { filterContextMessages } from '@renderer/services/MessagesService' +import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService' import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import axios from 'axios' -import { first, isEmpty, takeRight } from 'lodash' +import { isEmpty, takeRight } from 'lodash' import OpenAI from 'openai' import { CompletionsParams } from '.' @@ -146,13 +146,9 @@ export default class GeminiProvider extends BaseProvider { const model = assistant.model || defaultModel const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) - const userMessages = filterContextMessages(takeRight(messages, contextCount + 2)) + const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2))) onFilterMessages(userMessages) - if (first(userMessages)?.role === 'assistant') { - userMessages.shift() - } - const userLastMessage = userMessages.pop() const history: Content[] = [] diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 044a9bc6b1..f7c5515c89 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -10,7 +10,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' -import { filterContextMessages } from '@renderer/services/MessagesService' +import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService' import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import { takeRight } from 'lodash' @@ -213,18 +213,6 @@ export default class OpenAIProvider extends BaseProvider { return model.id.startsWith('o1') } - private isForceUserMessageStart(model: Model) { - if (model.id === 'deepseek-reasoner') { - return true - } - - if (model.provider === 'xirang') { - return true - } - - return false - } - async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel @@ -241,15 +229,9 @@ export default class OpenAIProvider extends BaseProvider { const userMessages: ChatCompletionMessageParam[] = [] - const _messages = filterContextMessages(takeRight(messages, contextCount + 1)) + const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 1))) onFilterMessages(_messages) - if (this.isForceUserMessageStart(model)) { - if (_messages[0]?.role !== 'user') { - userMessages.push({ role: 'user', content: '' }) - } - } - for (const message of _messages) { userMessages.push(await this.getMessageParam(message, model)) } diff --git a/src/renderer/src/services/MessagesService.ts b/src/renderer/src/services/MessagesService.ts index b8d8ccf335..967e48d838 100644 --- a/src/renderer/src/services/MessagesService.ts +++ b/src/renderer/src/services/MessagesService.ts @@ -29,6 +29,44 @@ export function filterContextMessages(messages: Message[]): Message[] { return messages.slice(clearIndex + 1) } +export function filterUserRoleStartMessages(messages: Message[]): Message[] { + const firstUserMessageIndex = messages.findIndex((message) => message.role === 'user') + + if (firstUserMessageIndex === -1) { + return messages + } + + return messages.slice(firstUserMessageIndex) +} + +export function filterUsefulMessages(messages: Message[]): Message[] { + const _messages = messages + const groupedMessages = getGroupedMessages(messages) + + Object.entries(groupedMessages).forEach(([key, messages]) => { + if (key.startsWith('assistant')) { + const usefulMessage = messages.find((m) => m.useful === true) + if (usefulMessage) { + messages.forEach((m) => { + if (m.id !== usefulMessage.id) { + remove(_messages, (o) => o.id === m.id) + } + }) + } else { + messages?.slice(0, -1).forEach((m) => { + remove(_messages, (o) => o.id === m.id) + }) + } + } + }) + + while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') { + _messages.pop() + } + + return _messages +} + export function getContextCount(assistant: Assistant, messages: Message[]) { const contextCount = assistant?.settings?.contextCount ?? DEFAULT_CONTEXTCOUNT const _messages = takeRight(messages, contextCount) @@ -111,34 +149,6 @@ export function getAssistantMessage({ assistant, topic }: { assistant: Assistant } } -export function filterUsefulMessages(messages: Message[]): Message[] { - const _messages = messages - const groupedMessages = getGroupedMessages(messages) - - Object.entries(groupedMessages).forEach(([key, messages]) => { - if (key.startsWith('assistant')) { - const usefulMessage = messages.find((m) => m.useful === true) - if (usefulMessage) { - messages.forEach((m) => { - if (m.id !== usefulMessage.id) { - remove(_messages, (o) => o.id === m.id) - } - }) - } else { - messages?.slice(0, -1).forEach((m) => { - remove(_messages, (o) => o.id === m.id) - }) - } - } - }) - - while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') { - _messages.pop() - } - - return _messages -} - export function getGroupedMessages(messages: Message[]): { [key: string]: (Message & { index: number })[] } { const groups: { [key: string]: (Message & { index: number })[] } = {} messages.forEach((message, index) => {