@@ -661,6 +661,7 @@ export namespace Session {
661661 description : item . description ,
662662 inputSchema : item . parameters as ZodSchema ,
663663 async execute ( args , options ) {
664+ await processor . track ( options . toolCallId )
664665 const result = await item . execute ( args , {
665666 sessionID : input . sessionID ,
666667 abort : abort . signal ,
@@ -699,6 +700,7 @@ export namespace Session {
699700 const execute = item . execute
700701 if ( ! execute ) continue
701702 item . execute = async ( args , opts ) => {
703+ await processor . track ( opts . toolCallId )
702704 const result = await execute ( args , opts )
703705 const output = result . content
704706 . filter ( ( x : any ) => x . type === "text" )
@@ -814,7 +816,12 @@ export namespace Session {
814816
815817 function createProcessor ( assistantMsg : MessageV2 . Assistant , model : ModelsDev . Model ) {
816818 const toolCalls : Record < string , MessageV2 . ToolPart > = { }
819+ const snapshots : Record < string , string > = { }
817820 return {
821+ async track ( toolCallID : string ) {
822+ const hash = await Snapshot . track ( )
823+ if ( hash ) snapshots [ toolCallID ] = hash
824+ } ,
818825 partFromToolCall ( toolCallID : string ) {
819826 return toolCalls [ toolCallID ]
820827 } ,
@@ -828,15 +835,6 @@ export namespace Session {
828835 } )
829836 switch ( value . type ) {
830837 case "start" :
831- const snapshot = await Snapshot . create ( )
832- if ( snapshot )
833- await updatePart ( {
834- id : Identifier . ascending ( "part" ) ,
835- messageID : assistantMsg . id ,
836- sessionID : assistantMsg . sessionID ,
837- type : "snapshot" ,
838- snapshot,
839- } )
840838 break
841839
842840 case "tool-input-start" :
@@ -857,6 +855,9 @@ export namespace Session {
857855 case "tool-input-delta" :
858856 break
859857
858+ case "tool-input-end" :
859+ break
860+
860861 case "tool-call" : {
861862 const match = toolCalls [ value . toolCallId ]
862863 if ( match ) {
@@ -892,15 +893,20 @@ export namespace Session {
892893 } ,
893894 } )
894895 delete toolCalls [ value . toolCallId ]
895- const snapshot = await Snapshot . create ( )
896- if ( snapshot )
897- await updatePart ( {
898- id : Identifier . ascending ( "part" ) ,
899- messageID : assistantMsg . id ,
900- sessionID : assistantMsg . sessionID ,
901- type : "snapshot" ,
902- snapshot,
903- } )
896+ const snapshot = snapshots [ value . toolCallId ]
897+ if ( snapshot ) {
898+ const patch = await Snapshot . patch ( snapshot )
899+ if ( patch . files . length ) {
900+ await updatePart ( {
901+ id : Identifier . ascending ( "part" ) ,
902+ messageID : assistantMsg . id ,
903+ sessionID : assistantMsg . sessionID ,
904+ type : "patch" ,
905+ hash : patch . hash ,
906+ files : patch . files ,
907+ } )
908+ }
909+ }
904910 }
905911 break
906912 }
@@ -921,15 +927,18 @@ export namespace Session {
921927 } ,
922928 } )
923929 delete toolCalls [ value . toolCallId ]
924- const snapshot = await Snapshot . create ( )
925- if ( snapshot )
930+ const snapshot = snapshots [ value . toolCallId ]
931+ if ( snapshot ) {
932+ const patch = await Snapshot . patch ( snapshot )
926933 await updatePart ( {
927934 id : Identifier . ascending ( "part" ) ,
928935 messageID : assistantMsg . id ,
929936 sessionID : assistantMsg . sessionID ,
930- type : "snapshot" ,
931- snapshot,
937+ type : "patch" ,
938+ hash : patch . hash ,
939+ files : patch . files ,
932940 } )
941+ }
933942 }
934943 break
935944 }
@@ -1073,33 +1082,45 @@ export namespace Session {
10731082
10741083 export async function revert ( input : RevertInput ) {
10751084 const all = await messages ( input . sessionID )
1076- const session = await get ( input . sessionID )
10771085 let lastUser : MessageV2 . User | undefined
1078- let lastSnapshot : MessageV2 . SnapshotPart | undefined
1086+ const session = await get ( input . sessionID )
1087+
1088+ let revert : Info [ "revert" ]
1089+ const patches : Snapshot . Patch [ ] = [ ]
10791090 for ( const msg of all ) {
10801091 if ( msg . info . role === "user" ) lastUser = msg . info
10811092 const remaining = [ ]
10821093 for ( const part of msg . parts ) {
1083- if ( part . type === "snapshot" ) lastSnapshot = part
1084- if ( ( msg . info . id === input . messageID && ! input . partID ) || part . id === input . partID ) {
1085- // if no useful parts left in message, same as reverting whole message
1086- const partID = remaining . some ( ( item ) => [ "text" , "tool" ] . includes ( item . type ) ) ? input . partID : undefined
1087- const snapshot = session . revert ?. snapshot ?? ( await Snapshot . create ( ) )
1088- log . info ( "revert snapshot" , { snapshot } )
1089- if ( lastSnapshot ) await Snapshot . restore ( lastSnapshot . snapshot )
1090- const next = await update ( input . sessionID , ( draft ) => {
1091- draft . revert = {
1092- // if not part id jump to the last user message
1094+ if ( revert ) {
1095+ if ( part . type === "patch" ) {
1096+ patches . push ( part )
1097+ }
1098+ continue
1099+ }
1100+
1101+ if ( ! revert ) {
1102+ if ( ( msg . info . id === input . messageID && ! input . partID ) || part . id === input . partID ) {
1103+ // if no useful parts left in message, same as reverting whole message
1104+ const partID = remaining . some ( ( item ) => [ "text" , "tool" ] . includes ( item . type ) ) ? input . partID : undefined
1105+ revert = {
10931106 messageID : ! partID && lastUser ? lastUser . id : msg . info . id ,
10941107 partID,
1095- snapshot,
10961108 }
1097- } )
1098- return next
1109+ }
1110+ remaining . push ( part )
10991111 }
1100- remaining . push ( part )
11011112 }
11021113 }
1114+
1115+ if ( revert ) {
1116+ const session = await get ( input . sessionID )
1117+ revert . snapshot = session . revert ?. snapshot ?? ( await Snapshot . track ( ) )
1118+ await Snapshot . revert ( patches )
1119+ return update ( input . sessionID , ( draft ) => {
1120+ draft . revert = revert
1121+ } )
1122+ }
1123+ return session
11031124 }
11041125
11051126 export async function unrevert ( input : { sessionID : string } ) {
0 commit comments