mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-06 21:35:52 +08:00
feat: Improve message filtering across providers
- Add new `filterUserRoleStartMessages` function in MessagesService - Update Anthropic, Gemini, and OpenAI providers to use new message filtering - Refactor message handling to ensure user messages start the conversation - Remove redundant message filtering logic from individual providers
This commit is contained in:
parent
452b0df476
commit
0b9e9c4c13
@ -6,7 +6,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
|||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
|
||||||
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import { first, flatten, sum, takeRight } from 'lodash'
|
import { first, flatten, sum, takeRight } from 'lodash'
|
||||||
@ -124,7 +124,7 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
const userMessagesParams: MessageParam[] = []
|
const userMessagesParams: MessageParam[] = []
|
||||||
const _messages = filterContextMessages(takeRight(messages, contextCount + 2))
|
const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||||
|
|
||||||
onFilterMessages(_messages)
|
onFilterMessages(_messages)
|
||||||
|
|
||||||
@ -134,10 +134,6 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
|
|
||||||
const userMessages = flatten(userMessagesParams)
|
const userMessages = flatten(userMessagesParams)
|
||||||
|
|
||||||
if (first(userMessages)?.role === 'assistant') {
|
|
||||||
userMessages.shift()
|
|
||||||
}
|
|
||||||
|
|
||||||
const body: MessageCreateParamsNonStreaming = {
|
const body: MessageCreateParamsNonStreaming = {
|
||||||
model: model.id,
|
model: model.id,
|
||||||
messages: userMessages,
|
messages: userMessages,
|
||||||
@ -145,12 +141,9 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
temperature: this.getTemperature(assistant, model),
|
temperature: this.getTemperature(assistant, model),
|
||||||
top_p: this.getTopP(assistant, model),
|
top_p: this.getTopP(assistant, model),
|
||||||
system: assistant.prompt,
|
system: assistant.prompt,
|
||||||
...this.getCustomParameters(assistant)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isReasoningModel(model)) {
|
|
||||||
// @ts-ignore thinking
|
// @ts-ignore thinking
|
||||||
body.thinking = this.getReasoningEffort(assistant, model)
|
thinking: this.getReasoningEffort(assistant, model),
|
||||||
|
...this.getCustomParameters(assistant)
|
||||||
}
|
}
|
||||||
|
|
||||||
let time_first_token_millsec = 0
|
let time_first_token_millsec = 0
|
||||||
|
|||||||
@ -15,11 +15,11 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
|||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
|
||||||
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import axios from 'axios'
|
import axios from 'axios'
|
||||||
import { first, isEmpty, takeRight } from 'lodash'
|
import { isEmpty, takeRight } from 'lodash'
|
||||||
import OpenAI from 'openai'
|
import OpenAI from 'openai'
|
||||||
|
|
||||||
import { CompletionsParams } from '.'
|
import { CompletionsParams } from '.'
|
||||||
@ -146,13 +146,9 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
const userMessages = filterContextMessages(takeRight(messages, contextCount + 2))
|
const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||||
onFilterMessages(userMessages)
|
onFilterMessages(userMessages)
|
||||||
|
|
||||||
if (first(userMessages)?.role === 'assistant') {
|
|
||||||
userMessages.shift()
|
|
||||||
}
|
|
||||||
|
|
||||||
const userLastMessage = userMessages.pop()
|
const userLastMessage = userMessages.pop()
|
||||||
|
|
||||||
const history: Content[] = []
|
const history: Content[] = []
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
|||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
|
||||||
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import { takeRight } from 'lodash'
|
import { takeRight } from 'lodash'
|
||||||
@ -213,18 +213,6 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
return model.id.startsWith('o1')
|
return model.id.startsWith('o1')
|
||||||
}
|
}
|
||||||
|
|
||||||
private isForceUserMessageStart(model: Model) {
|
|
||||||
if (model.id === 'deepseek-reasoner') {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if (model.provider === 'xirang') {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
@ -241,15 +229,9 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
|
|
||||||
const userMessages: ChatCompletionMessageParam[] = []
|
const userMessages: ChatCompletionMessageParam[] = []
|
||||||
|
|
||||||
const _messages = filterContextMessages(takeRight(messages, contextCount + 1))
|
const _messages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
|
||||||
onFilterMessages(_messages)
|
onFilterMessages(_messages)
|
||||||
|
|
||||||
if (this.isForceUserMessageStart(model)) {
|
|
||||||
if (_messages[0]?.role !== 'user') {
|
|
||||||
userMessages.push({ role: 'user', content: '' })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const message of _messages) {
|
for (const message of _messages) {
|
||||||
userMessages.push(await this.getMessageParam(message, model))
|
userMessages.push(await this.getMessageParam(message, model))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,6 +29,44 @@ export function filterContextMessages(messages: Message[]): Message[] {
|
|||||||
return messages.slice(clearIndex + 1)
|
return messages.slice(clearIndex + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function filterUserRoleStartMessages(messages: Message[]): Message[] {
|
||||||
|
const firstUserMessageIndex = messages.findIndex((message) => message.role === 'user')
|
||||||
|
|
||||||
|
if (firstUserMessageIndex === -1) {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages.slice(firstUserMessageIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function filterUsefulMessages(messages: Message[]): Message[] {
|
||||||
|
const _messages = messages
|
||||||
|
const groupedMessages = getGroupedMessages(messages)
|
||||||
|
|
||||||
|
Object.entries(groupedMessages).forEach(([key, messages]) => {
|
||||||
|
if (key.startsWith('assistant')) {
|
||||||
|
const usefulMessage = messages.find((m) => m.useful === true)
|
||||||
|
if (usefulMessage) {
|
||||||
|
messages.forEach((m) => {
|
||||||
|
if (m.id !== usefulMessage.id) {
|
||||||
|
remove(_messages, (o) => o.id === m.id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
messages?.slice(0, -1).forEach((m) => {
|
||||||
|
remove(_messages, (o) => o.id === m.id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') {
|
||||||
|
_messages.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return _messages
|
||||||
|
}
|
||||||
|
|
||||||
export function getContextCount(assistant: Assistant, messages: Message[]) {
|
export function getContextCount(assistant: Assistant, messages: Message[]) {
|
||||||
const contextCount = assistant?.settings?.contextCount ?? DEFAULT_CONTEXTCOUNT
|
const contextCount = assistant?.settings?.contextCount ?? DEFAULT_CONTEXTCOUNT
|
||||||
const _messages = takeRight(messages, contextCount)
|
const _messages = takeRight(messages, contextCount)
|
||||||
@ -111,34 +149,6 @@ export function getAssistantMessage({ assistant, topic }: { assistant: Assistant
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function filterUsefulMessages(messages: Message[]): Message[] {
|
|
||||||
const _messages = messages
|
|
||||||
const groupedMessages = getGroupedMessages(messages)
|
|
||||||
|
|
||||||
Object.entries(groupedMessages).forEach(([key, messages]) => {
|
|
||||||
if (key.startsWith('assistant')) {
|
|
||||||
const usefulMessage = messages.find((m) => m.useful === true)
|
|
||||||
if (usefulMessage) {
|
|
||||||
messages.forEach((m) => {
|
|
||||||
if (m.id !== usefulMessage.id) {
|
|
||||||
remove(_messages, (o) => o.id === m.id)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
messages?.slice(0, -1).forEach((m) => {
|
|
||||||
remove(_messages, (o) => o.id === m.id)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
while (_messages.length > 0 && _messages[_messages.length - 1].role === 'assistant') {
|
|
||||||
_messages.pop()
|
|
||||||
}
|
|
||||||
|
|
||||||
return _messages
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getGroupedMessages(messages: Message[]): { [key: string]: (Message & { index: number })[] } {
|
export function getGroupedMessages(messages: Message[]): { [key: string]: (Message & { index: number })[] } {
|
||||||
const groups: { [key: string]: (Message & { index: number })[] } = {}
|
const groups: { [key: string]: (Message & { index: number })[] } = {}
|
||||||
messages.forEach((message, index) => {
|
messages.forEach((message, index) => {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user