mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 05:39:05 +08:00
fix: initialize messageContents and improve message handling in GeminiAPIClient; add new Gemini model to configuration (#7307)
* fix: initialize messageContents and improve message handling in GeminiAPIClient; add new Gemini model to configuration * refactor: streamline message handling in GeminiAPIClient; enhance message extraction from SDK payload
This commit is contained in:
parent
5ada8739bf
commit
e197a6f3e4
@ -466,7 +466,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
|
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
|
||||||
}
|
}
|
||||||
|
|
||||||
let messageContents: Content
|
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
|
||||||
const history: Content[] = []
|
const history: Content[] = []
|
||||||
// 3. 处理用户消息
|
// 3. 处理用户消息
|
||||||
if (typeof messages === 'string') {
|
if (typeof messages === 'string') {
|
||||||
@ -475,10 +475,12 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
parts: [{ text: messages }]
|
parts: [{ text: messages }]
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const userLastMessage = messages.pop()!
|
const userLastMessage = messages.pop()
|
||||||
messageContents = await this.convertMessageToSdkParam(userLastMessage)
|
if (userLastMessage) {
|
||||||
for (const message of messages) {
|
messageContents = await this.convertMessageToSdkParam(userLastMessage)
|
||||||
history.push(await this.convertMessageToSdkParam(message))
|
for (const message of messages) {
|
||||||
|
history.push(await this.convertMessageToSdkParam(message))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -491,6 +493,10 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
if (isGemmaModel(model) && assistant.prompt) {
|
if (isGemmaModel(model) && assistant.prompt) {
|
||||||
const isFirstMessage = history.length === 0
|
const isFirstMessage = history.length === 0
|
||||||
if (isFirstMessage && messageContents) {
|
if (isFirstMessage && messageContents) {
|
||||||
|
const userMessageText =
|
||||||
|
messageContents.parts && messageContents.parts.length > 0
|
||||||
|
? (messageContents.parts[0] as Part).text || ''
|
||||||
|
: ''
|
||||||
const systemMessage = [
|
const systemMessage = [
|
||||||
{
|
{
|
||||||
text:
|
text:
|
||||||
@ -498,7 +504,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
systemInstruction +
|
systemInstruction +
|
||||||
'<end_of_turn>\n' +
|
'<end_of_turn>\n' +
|
||||||
'<start_of_turn>user\n' +
|
'<start_of_turn>user\n' +
|
||||||
(messageContents?.parts?.[0] as Part).text +
|
userMessageText +
|
||||||
'<end_of_turn>'
|
'<end_of_turn>'
|
||||||
}
|
}
|
||||||
] as Part[]
|
] as Part[]
|
||||||
@ -515,13 +521,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
|
|
||||||
const newMessageContents =
|
const newMessageContents =
|
||||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||||
? {
|
? recursiveSdkMessages[recursiveSdkMessages.length - 1]
|
||||||
...messageContents,
|
|
||||||
parts: [
|
|
||||||
...(messageContents.parts || []),
|
|
||||||
...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
|
|
||||||
]
|
|
||||||
}
|
|
||||||
: messageContents
|
: messageContents
|
||||||
|
|
||||||
const generateContentConfig: GenerateContentConfig = {
|
const generateContentConfig: GenerateContentConfig = {
|
||||||
@ -555,7 +555,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
|
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
|
||||||
return () => ({
|
return () => ({
|
||||||
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||||
let toolCalls: FunctionCall[] = []
|
const toolCalls: FunctionCall[] = []
|
||||||
if (chunk.candidates && chunk.candidates.length > 0) {
|
if (chunk.candidates && chunk.candidates.length > 0) {
|
||||||
for (const candidate of chunk.candidates) {
|
for (const candidate of chunk.candidates) {
|
||||||
if (candidate.content) {
|
if (candidate.content) {
|
||||||
@ -583,6 +583,8 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
} else if (part.functionCall) {
|
||||||
|
toolCalls.push(part.functionCall)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -597,9 +599,6 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
}
|
}
|
||||||
} as LLMWebSearchCompleteChunk)
|
} as LLMWebSearchCompleteChunk)
|
||||||
}
|
}
|
||||||
if (chunk.functionCalls) {
|
|
||||||
toolCalls = toolCalls.concat(chunk.functionCalls)
|
|
||||||
}
|
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
response: {
|
response: {
|
||||||
@ -702,12 +701,11 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
.filter((p) => p !== undefined)
|
.filter((p) => p !== undefined)
|
||||||
)
|
)
|
||||||
|
|
||||||
const userMessage: Content = {
|
const lastMessage = currentReqMessages[currentReqMessages.length - 1]
|
||||||
role: 'user',
|
if (lastMessage) {
|
||||||
parts: parts
|
lastMessage.parts?.push(...parts)
|
||||||
}
|
}
|
||||||
|
return currentReqMessages
|
||||||
return [...currentReqMessages, userMessage]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
|
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
|
||||||
@ -734,7 +732,20 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
}
|
}
|
||||||
|
|
||||||
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
|
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
|
||||||
return sdkPayload.history || []
|
const messageParam: GeminiSdkMessageParam = {
|
||||||
|
role: 'user',
|
||||||
|
parts: []
|
||||||
|
}
|
||||||
|
if (Array.isArray(sdkPayload.message)) {
|
||||||
|
sdkPayload.message.forEach((part) => {
|
||||||
|
if (typeof part === 'string') {
|
||||||
|
messageParam.parts?.push({ text: part })
|
||||||
|
} else if (typeof part === 'object') {
|
||||||
|
messageParam.parts?.push(part)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return [messageParam, ...(sdkPayload.history || [])]
|
||||||
}
|
}
|
||||||
|
|
||||||
private async uploadFile(file: FileType): Promise<File> {
|
private async uploadFile(file: FileType): Promise<File> {
|
||||||
|
|||||||
@ -2313,6 +2313,7 @@ export const GEMINI_SEARCH_MODELS = [
|
|||||||
'gemini-2.0-flash-lite',
|
'gemini-2.0-flash-lite',
|
||||||
'gemini-2.0-flash-exp',
|
'gemini-2.0-flash-exp',
|
||||||
'gemini-2.0-flash-001',
|
'gemini-2.0-flash-001',
|
||||||
|
'gemini-2.5-pro',
|
||||||
'gemini-2.0-pro-exp-02-05',
|
'gemini-2.0-pro-exp-02-05',
|
||||||
'gemini-2.0-pro-exp',
|
'gemini-2.0-pro-exp',
|
||||||
'gemini-2.5-pro-exp',
|
'gemini-2.5-pro-exp',
|
||||||
@ -2322,7 +2323,8 @@ export const GEMINI_SEARCH_MODELS = [
|
|||||||
'gemini-2.5-pro-preview-05-06',
|
'gemini-2.5-pro-preview-05-06',
|
||||||
'gemini-2.5-flash-preview',
|
'gemini-2.5-flash-preview',
|
||||||
'gemini-2.5-flash-preview-04-17',
|
'gemini-2.5-flash-preview-04-17',
|
||||||
'gemini-2.5-flash-preview-05-20'
|
'gemini-2.5-flash-preview-05-20',
|
||||||
|
'gemini-2.5-flash-lite-preview-06-17'
|
||||||
]
|
]
|
||||||
|
|
||||||
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
|
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user