@@ -32,7 +32,6 @@ import {
32
32
FunctionMessageChunk ,
33
33
HumanMessage ,
34
34
HumanMessageChunk ,
35
- MessageType ,
36
35
SystemMessage ,
37
36
SystemMessageChunk ,
38
37
} from "../schema/index.js" ;
@@ -56,9 +55,23 @@ interface OpenAILLMOutput {
56
55
tokenUsage : TokenUsage ;
57
56
}
58
57
59
- function messageTypeToOpenAIRole (
60
- type : MessageType
58
+ function extractGenericMessageCustomRole ( message : ChatMessage ) {
59
+ if (
60
+ message . role !== "system" &&
61
+ message . role !== "assistant" &&
62
+ message . role !== "user" &&
63
+ message . role !== "function"
64
+ ) {
65
+ console . warn ( `Unknown message role: ${ message . role } ` ) ;
66
+ }
67
+
68
+ return message . role as ChatCompletionResponseMessageRoleEnum ;
69
+ }
70
+
71
+ function messageToOpenAIRole (
72
+ message : BaseMessage
61
73
) : ChatCompletionResponseMessageRoleEnum {
74
+ const type = message . _getType ( ) ;
62
75
switch ( type ) {
63
76
case "system" :
64
77
return "system" ;
@@ -68,6 +81,11 @@ function messageTypeToOpenAIRole(
68
81
return "user" ;
69
82
case "function" :
70
83
return "function" ;
84
+ case "generic" : {
85
+ if ( ! ChatMessage . isInstance ( message ) )
86
+ throw new Error ( "Invalid generic chat message" ) ;
87
+ return extractGenericMessageCustomRole ( message ) ;
88
+ }
71
89
default :
72
90
throw new Error ( `Unknown message type: ${ type } ` ) ;
73
91
}
@@ -340,7 +358,7 @@ export class ChatOpenAI
340
358
) : AsyncGenerator < ChatGenerationChunk > {
341
359
const messagesMapped : ChatCompletionRequestMessage [ ] = messages . map (
342
360
( message ) => ( {
343
- role : messageTypeToOpenAIRole ( message . _getType ( ) ) ,
361
+ role : messageToOpenAIRole ( message ) ,
344
362
content : message . content ,
345
363
name : message . name ,
346
364
function_call : message . additional_kwargs
@@ -455,7 +473,7 @@ export class ChatOpenAI
455
473
const params = this . invocationParams ( options ) ;
456
474
const messagesMapped : ChatCompletionRequestMessage [ ] = messages . map (
457
475
( message ) => ( {
458
- role : messageTypeToOpenAIRole ( message . _getType ( ) ) ,
476
+ role : messageToOpenAIRole ( message ) ,
459
477
content : message . content ,
460
478
name : message . name ,
461
479
function_call : message . additional_kwargs
@@ -661,9 +679,7 @@ export class ChatOpenAI
661
679
const countPerMessage = await Promise . all (
662
680
messages . map ( async ( message ) => {
663
681
const textCount = await this . getNumTokens ( message . content ) ;
664
- const roleCount = await this . getNumTokens (
665
- messageTypeToOpenAIRole ( message . _getType ( ) )
666
- ) ;
682
+ const roleCount = await this . getNumTokens ( messageToOpenAIRole ( message ) ) ;
667
683
const nameCount =
668
684
message . name !== undefined
669
685
? tokensPerName + ( await this . getNumTokens ( message . name ) )
@@ -865,7 +881,7 @@ export class PromptLayerChatOpenAI extends ChatOpenAI {
865
881
const parsedResp = [
866
882
{
867
883
content : generation . text ,
868
- role : messageTypeToOpenAIRole ( generation . message . _getType ( ) ) ,
884
+ role : messageToOpenAIRole ( generation . message ) ,
869
885
} ,
870
886
] ;
871
887
0 commit comments