Skip to content

Commit cd4a807

Browse files
authored
Permit generic chat message with role (langchain-ai#2075)
* Permit generic chat message with role * Address code review notes, only warn if invalid roles found * Fix streaming * Fix for PaLM * Code review * Rename to `isInstance`
1 parent db9f000 commit cd4a807

9 files changed

+162
-36
lines changed

langchain/src/chat_models/anthropic.ts

+21-5
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,38 @@ import {
88
BaseMessage,
99
ChatGeneration,
1010
ChatGenerationChunk,
11+
ChatMessage,
1112
ChatResult,
12-
MessageType,
1313
} from "../schema/index.js";
1414
import { getEnvironmentVariable } from "../util/env.js";
1515
import { BaseChatModel, BaseChatModelParams } from "./base.js";
1616

17-
function getAnthropicPromptFromMessage(type: MessageType): string {
17+
function extractGenericMessageCustomRole(message: ChatMessage) {
18+
if (
19+
message.role !== AI_PROMPT &&
20+
message.role !== HUMAN_PROMPT &&
21+
message.role !== ""
22+
) {
23+
console.warn(`Unknown message role: ${message.role}`);
24+
}
25+
26+
return message.role;
27+
}
28+
29+
function getAnthropicPromptFromMessage(message: BaseMessage): string {
30+
const type = message._getType();
1831
switch (type) {
1932
case "ai":
2033
return AI_PROMPT;
2134
case "human":
2235
return HUMAN_PROMPT;
2336
case "system":
2437
return "";
38+
case "generic": {
39+
if (!ChatMessage.isInstance(message))
40+
throw new Error("Invalid generic chat message");
41+
return extractGenericMessageCustomRole(message);
42+
}
2543
default:
2644
throw new Error(`Unknown message type: ${type}`);
2745
}
@@ -250,9 +268,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
250268
return (
251269
messages
252270
.map((message) => {
253-
const messagePrompt = getAnthropicPromptFromMessage(
254-
message._getType()
255-
);
271+
const messagePrompt = getAnthropicPromptFromMessage(message);
256272
return `${messagePrompt} ${message.content}`;
257273
})
258274
.join("") + AI_PROMPT

langchain/src/chat_models/baiduwenxin.ts

+19-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import {
33
AIMessage,
44
BaseMessage,
55
ChatGeneration,
6+
ChatMessage,
67
ChatResult,
7-
MessageType,
88
} from "../schema/index.js";
99
import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
1010
import { getEnvironmentVariable } from "../util/env.js";
@@ -90,14 +90,30 @@ declare interface BaiduWenxinChatInput {
9090
penaltyScore?: number;
9191
}
9292

93-
function messageTypeToWenxinRole(type: MessageType): WenxinMessageRole {
93+
function extractGenericMessageCustomRole(message: ChatMessage) {
94+
if (message.role !== "assistant" && message.role !== "user") {
95+
console.warn(`Unknown message role: ${message.role}`);
96+
}
97+
98+
return message.role as WenxinMessageRole;
99+
}
100+
101+
function messageToWenxinRole(message: BaseMessage): WenxinMessageRole {
102+
const type = message._getType();
94103
switch (type) {
95104
case "ai":
96105
return "assistant";
97106
case "human":
98107
return "user";
99108
case "system":
100109
throw new Error("System messages not supported");
110+
case "function":
111+
throw new Error("Function messages not supported");
112+
case "generic": {
113+
if (!ChatMessage.isInstance(message))
114+
throw new Error("Invalid generic chat message");
115+
return extractGenericMessageCustomRole(message);
116+
}
101117
default:
102118
throw new Error(`Unknown message type: ${type}`);
103119
}
@@ -263,7 +279,7 @@ export class ChatBaiduWenxin
263279

264280
const params = this.invocationParams();
265281
const messagesMapped: WenxinMessage[] = messages.map((message) => ({
266-
role: messageTypeToWenxinRole(message._getType()),
282+
role: messageToWenxinRole(message),
267283
content: message.text,
268284
}));
269285

langchain/src/chat_models/googlepalm.ts

+22-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ import { DiscussServiceClient } from "@google-ai/generativelanguage";
22
import type { protos } from "@google-ai/generativelanguage";
33
import { GoogleAuth } from "google-auth-library";
44
import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
5-
import { AIMessage, BaseMessage, ChatResult } from "../schema/index.js";
5+
import {
6+
AIMessage,
7+
BaseMessage,
8+
ChatMessage,
9+
ChatResult,
10+
} from "../schema/index.js";
611
import { getEnvironmentVariable } from "../util/env.js";
712
import { BaseChatModel, BaseChatModelParams } from "./base.js";
813

@@ -60,6 +65,14 @@ export interface GooglePaLMChatInput extends BaseChatModelParams {
6065
apiKey?: string;
6166
}
6267

68+
function getMessageAuthor(message: BaseMessage) {
69+
const type = message._getType();
70+
if (ChatMessage.isInstance(message)) {
71+
return message.role;
72+
}
73+
return message.name ?? type;
74+
}
75+
6376
export class ChatGooglePaLM
6477
extends BaseChatModel
6578
implements GooglePaLMChatInput
@@ -175,7 +188,7 @@ export class ChatGooglePaLM
175188
): string | undefined {
176189
// get the first message and checks if it's a system 'system' messages
177190
const systemMessage =
178-
messages.length > 0 && messages[0]._getType() === "system"
191+
messages.length > 0 && getMessageAuthor(messages[0]) === "system"
179192
? messages[0]
180193
: undefined;
181194
return systemMessage?.content;
@@ -185,20 +198,24 @@ export class ChatGooglePaLM
185198
messages: BaseMessage[]
186199
): protos.google.ai.generativelanguage.v1beta2.IMessage[] {
187200
// remove all 'system' messages
188-
const nonSystemMessages = messages.filter((m) => m._getType() !== "system");
201+
const nonSystemMessages = messages.filter(
202+
(m) => getMessageAuthor(m) !== "system"
203+
);
189204

190205
// requires alternate human & ai messages. Throw error if two messages are consecutive
191206
nonSystemMessages.forEach((msg, index) => {
192207
if (index < 1) return;
193-
if (msg._getType() === nonSystemMessages[index - 1]._getType()) {
208+
if (
209+
getMessageAuthor(msg) === getMessageAuthor(nonSystemMessages[index - 1])
210+
) {
194211
throw new Error(
195212
`Google PaLM requires alternate messages between authors`
196213
);
197214
}
198215
});
199216

200217
return nonSystemMessages.map((m) => ({
201-
author: m.name ?? m._getType(),
218+
author: getMessageAuthor(m),
202219
content: m.content,
203220
citationMetadata: {
204221
citationSources: m.additional_kwargs.citationSources as

langchain/src/chat_models/googlevertexai.ts

+40-12
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ import {
33
AIMessage,
44
BaseMessage,
55
ChatGeneration,
6+
ChatMessage,
67
ChatResult,
78
LLMResult,
8-
MessageType,
99
} from "../schema/index.js";
1010
import { GoogleVertexAIConnection } from "../util/googlevertexai-connection.js";
1111
import {
@@ -54,11 +54,25 @@ export class GoogleVertexAIChatMessage {
5454
this.name = fields.name;
5555
}
5656

57+
static extractGenericMessageCustomRole(message: ChatMessage) {
58+
if (
59+
message.role !== "system" &&
60+
message.role !== "bot" &&
61+
message.role !== "user" &&
62+
message.role !== "context"
63+
) {
64+
console.warn(`Unknown message role: ${message.role}`);
65+
}
66+
67+
return message.role as GoogleVertexAIChatAuthor;
68+
}
69+
5770
static mapMessageTypeToVertexChatAuthor(
58-
baseMessageType: MessageType,
71+
message: BaseMessage,
5972
model: string
6073
): GoogleVertexAIChatAuthor {
61-
switch (baseMessageType) {
74+
const type = message._getType();
75+
switch (type) {
6276
case "ai":
6377
return model.startsWith("codechat-") ? "system" : "bot";
6478
case "human":
@@ -67,17 +81,22 @@ export class GoogleVertexAIChatMessage {
6781
throw new Error(
6882
`System messages are only supported as the first passed message for Google Vertex AI.`
6983
);
70-
default:
71-
throw new Error(
72-
`Unknown / unsupported message type: ${baseMessageType}`
84+
case "generic": {
85+
if (!ChatMessage.isInstance(message))
86+
throw new Error("Invalid generic chat message");
87+
return GoogleVertexAIChatMessage.extractGenericMessageCustomRole(
88+
message
7389
);
90+
}
91+
default:
92+
throw new Error(`Unknown / unsupported message type: ${message}`);
7493
}
7594
}
7695

7796
static fromChatMessage(message: BaseMessage, model: string) {
7897
return new GoogleVertexAIChatMessage({
7998
author: GoogleVertexAIChatMessage.mapMessageTypeToVertexChatAuthor(
80-
message._getType(),
99+
message,
81100
model
82101
),
83102
content: message.content,
@@ -211,16 +230,25 @@ export class ChatGoogleVertexAI
211230
);
212231
}
213232
const vertexChatMessages = conversationMessages.map((baseMessage, i) => {
233+
const currMessage = GoogleVertexAIChatMessage.fromChatMessage(
234+
baseMessage,
235+
this.model
236+
);
237+
const prevMessage =
238+
i > 0
239+
? GoogleVertexAIChatMessage.fromChatMessage(
240+
conversationMessages[i - 1],
241+
this.model
242+
)
243+
: null;
244+
214245
// https://cloud.google.com/vertex-ai/docs/generative-ai/chat/chat-prompts#messages
215-
if (
216-
i > 0 &&
217-
baseMessage._getType() === conversationMessages[i - 1]._getType()
218-
) {
246+
if (prevMessage && currMessage.author === prevMessage.author) {
219247
throw new Error(
220248
`Google Vertex AI requires AI and human messages to alternate.`
221249
);
222250
}
223-
return GoogleVertexAIChatMessage.fromChatMessage(baseMessage, this.model);
251+
return currMessage;
224252
});
225253

226254
const examples = this.examples.map((example) => ({

langchain/src/chat_models/openai.ts

+25-9
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import {
3232
FunctionMessageChunk,
3333
HumanMessage,
3434
HumanMessageChunk,
35-
MessageType,
3635
SystemMessage,
3736
SystemMessageChunk,
3837
} from "../schema/index.js";
@@ -56,9 +55,23 @@ interface OpenAILLMOutput {
5655
tokenUsage: TokenUsage;
5756
}
5857

59-
function messageTypeToOpenAIRole(
60-
type: MessageType
58+
function extractGenericMessageCustomRole(message: ChatMessage) {
59+
if (
60+
message.role !== "system" &&
61+
message.role !== "assistant" &&
62+
message.role !== "user" &&
63+
message.role !== "function"
64+
) {
65+
console.warn(`Unknown message role: ${message.role}`);
66+
}
67+
68+
return message.role as ChatCompletionResponseMessageRoleEnum;
69+
}
70+
71+
function messageToOpenAIRole(
72+
message: BaseMessage
6173
): ChatCompletionResponseMessageRoleEnum {
74+
const type = message._getType();
6275
switch (type) {
6376
case "system":
6477
return "system";
@@ -68,6 +81,11 @@ function messageTypeToOpenAIRole(
6881
return "user";
6982
case "function":
7083
return "function";
84+
case "generic": {
85+
if (!ChatMessage.isInstance(message))
86+
throw new Error("Invalid generic chat message");
87+
return extractGenericMessageCustomRole(message);
88+
}
7189
default:
7290
throw new Error(`Unknown message type: ${type}`);
7391
}
@@ -340,7 +358,7 @@ export class ChatOpenAI
340358
): AsyncGenerator<ChatGenerationChunk> {
341359
const messagesMapped: ChatCompletionRequestMessage[] = messages.map(
342360
(message) => ({
343-
role: messageTypeToOpenAIRole(message._getType()),
361+
role: messageToOpenAIRole(message),
344362
content: message.content,
345363
name: message.name,
346364
function_call: message.additional_kwargs
@@ -455,7 +473,7 @@ export class ChatOpenAI
455473
const params = this.invocationParams(options);
456474
const messagesMapped: ChatCompletionRequestMessage[] = messages.map(
457475
(message) => ({
458-
role: messageTypeToOpenAIRole(message._getType()),
476+
role: messageToOpenAIRole(message),
459477
content: message.content,
460478
name: message.name,
461479
function_call: message.additional_kwargs
@@ -661,9 +679,7 @@ export class ChatOpenAI
661679
const countPerMessage = await Promise.all(
662680
messages.map(async (message) => {
663681
const textCount = await this.getNumTokens(message.content);
664-
const roleCount = await this.getNumTokens(
665-
messageTypeToOpenAIRole(message._getType())
666-
);
682+
const roleCount = await this.getNumTokens(messageToOpenAIRole(message));
667683
const nameCount =
668684
message.name !== undefined
669685
? tokensPerName + (await this.getNumTokens(message.name))
@@ -865,7 +881,7 @@ export class PromptLayerChatOpenAI extends ChatOpenAI {
865881
const parsedResp = [
866882
{
867883
content: generation.text,
868-
role: messageTypeToOpenAIRole(generation.message._getType()),
884+
role: messageToOpenAIRole(generation.message),
869885
},
870886
];
871887

langchain/src/chat_models/tests/chatanthropic.int.test.ts

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { expect, test } from "@jest/globals";
2-
import { HumanMessage } from "../../schema/index.js";
2+
import { HUMAN_PROMPT } from "@anthropic-ai/sdk";
3+
import { ChatMessage, HumanMessage } from "../../schema/index.js";
34
import { ChatPromptValue } from "../../prompts/chat.js";
45
import {
56
PromptTemplate,
@@ -213,6 +214,16 @@ test("ChatAnthropic, Claude V2", async () => {
213214
console.log(responseA.generations);
214215
});
215216

217+
test("ChatAnthropic with specific roles in ChatMessage", async () => {
218+
const chat = new ChatAnthropic({
219+
modelName: "claude-instant-v1",
220+
maxTokensToSample: 10,
221+
});
222+
const user_message = new ChatMessage("Hello!", HUMAN_PROMPT);
223+
const res = await chat.call([user_message]);
224+
console.log({ res });
225+
});
226+
216227
test("Test ChatAnthropic stream method", async () => {
217228
const model = new ChatAnthropic({
218229
maxTokensToSample: 50,

langchain/src/chat_models/tests/chatgooglevertexai.int.test.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { test } from "@jest/globals";
2-
import { HumanMessage } from "../../schema/index.js";
2+
import { ChatMessage, HumanMessage } from "../../schema/index.js";
33
import {
44
PromptTemplate,
55
ChatPromptTemplate,
@@ -26,6 +26,12 @@ test("Test ChatGoogleVertexAI generate", async () => {
2626
console.log(JSON.stringify(res, null, 2));
2727
});
2828

29+
test("Google code messages with custom messages", async () => {
30+
const chat = new ChatGoogleVertexAI();
31+
const res = await chat.call([new ChatMessage("Hello!", "user")]);
32+
console.log(JSON.stringify(res, null, 2));
33+
});
34+
2935
test("ChatGoogleVertexAI, prompt templates", async () => {
3036
const chat = new ChatGoogleVertexAI();
3137

0 commit comments

Comments
 (0)