Skip to content

Commit c52dcb4

Browse files
authored
Add recursion guards for the following nested messages (#25807):
- map fields in pure Python
- message_set_extension in pure Python
- message_set_extension in UPB Python

#25335
PiperOrigin-RevId: 868863633
1 parent 5975f13 commit c52dcb4

File tree

6 files changed

+77
-6
lines changed

6 files changed

+77
-6
lines changed

python/google/protobuf/internal/decoder.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def MessageSetItemDecoder(descriptor):
835835
local_ReadTag = ReadTag
836836
local_DecodeVarint = _DecodeVarint
837837

838-
def DecodeItem(buffer, pos, end, message, field_dict):
838+
def DecodeItem(buffer, pos, end, message, field_dict, current_depth=0):
839839
"""Decode serialized message set to its value and new position.
840840
841841
Args:
@@ -888,10 +888,19 @@ def DecodeItem(buffer, pos, end, message, field_dict):
888888
message_factory.GetMessageClass(message_type)
889889
value = field_dict.setdefault(
890890
extension, message_type._concrete_class())
891-
if value._InternalParse(buffer, message_start,message_end) != message_end:
891+
current_depth += 1
892+
if current_depth > _recursion_limit:
893+
raise _DecodeError('Error parsing message: too many levels of nesting.')
894+
if (
895+
value._InternalParse(
896+
buffer, message_start, message_end, current_depth
897+
)
898+
!= message_end
899+
):
892900
# The only reason _InternalParse would return early is if it encountered
893901
# an end-group tag.
894902
raise _DecodeError('Unexpected end-group tag.')
903+
current_depth -= 1
895904
else:
896905
if not message._unknown_fields:
897906
message._unknown_fields = []
@@ -957,7 +966,6 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
957966
message_type = field_descriptor.message_type
958967

959968
def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
960-
del current_depth # Unused.
961969
submsg = message_type._concrete_class()
962970
value = field_dict.get(key)
963971
if value is None:
@@ -970,10 +978,14 @@ def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
970978
raise _DecodeError('Truncated message.')
971979
# Read sub-message.
972980
submsg.Clear()
973-
if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
981+
current_depth += 1
982+
if current_depth > _recursion_limit:
983+
raise _DecodeError('Error parsing message: too many levels of nesting.')
984+
if submsg._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
974985
# The only reason _InternalParse would return early is if it
975986
# encountered an end-group tag.
976987
raise _DecodeError('Unexpected end-group tag.')
988+
current_depth -= 1
977989

978990
if is_message_map:
979991
value[submsg.key].CopyFrom(submsg.value)

python/google/protobuf/internal/message_set_extensions.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ message TestMessageSetExtension1 {
2626
optional TestMessageSetExtension1 message_set_extension = 98418603;
2727
}
2828
optional int32 i = 15;
29+
optional TestMessageSet sub_msg = 16;
2930
}
3031

3132
message TestMessageSetExtension2 {

python/google/protobuf/internal/message_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
cmp = lambda x, y: (x > y) - (x < y)
3131

32+
from google.protobuf.internal import message_set_extensions_pb2
3233
from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top
3334
from google.protobuf.internal import decoder
3435
from google.protobuf.internal import encoder
@@ -3083,6 +3084,7 @@ def GenerateNestedProto(self, n):
30833084

30843085
def testSucceedOkSizedProto(self):
30853086
msg = unittest_pb2.TestRecursiveMessage()
3087+
decoder.SetRecursionLimit(100)
30863088
msg.ParseFromString(self.GenerateNestedProto(100))
30873089

30883090
def testAssertOversizeProto(self):
@@ -3104,6 +3106,50 @@ def testSucceedOversizeProto(self):
31043106
msg.ParseFromString(self.GenerateNestedProto(101))
31053107
decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
31063108

3109+
def testRecursionMap(self):
3110+
if api_implementation.Type() == 'python':
3111+
# Pure Python needs a smaller depth limit to avoid a test timeout.
3112+
depth = 10
3113+
decoder.SetRecursionLimit(depth * 2)
3114+
else:
3115+
depth = 50
3116+
msg = more_messages_pb2.TestRecursiveMapMessage()
3117+
sub = msg
3118+
for _ in range(depth):
3119+
sub.map_field[0].i = 123
3120+
sub = sub.map_field[0]
3121+
parsed_msg = more_messages_pb2.TestRecursiveMapMessage()
3122+
# message can be parsed with the max recursion depth
3123+
parsed_msg.ParseFromString(msg.SerializeToString())
3124+
# message can not be parsed with one more recursion
3125+
sub.map_field[0].i = 123
3126+
with self.assertRaises(message.DecodeError) as context:
3127+
parsed_msg.ParseFromString(msg.SerializeToString())
3128+
self.assertIn('Error parsing message', str(context.exception))
3129+
3130+
def testRecisionMessageSet(self):
3131+
msg = message_set_extensions_pb2.TestMessageSet()
3132+
test_msg = message_set_extensions_pb2.TestMessageSetExtension1
3133+
ext = test_msg.message_set_extension
3134+
sub = msg
3135+
if api_implementation.Type() == 'cpp':
3136+
# TODO: message_set_extension is double-counted toward the
3137+
# depth in C++. Should fix it to only count once.
3138+
depth = 33
3139+
else:
3140+
depth = 50
3141+
for _ in range(depth):
3142+
sub.Extensions[ext].i = 123
3143+
sub = sub.Extensions[ext].sub_msg
3144+
# message can be parsed with the max recursion depth
3145+
parsed_msg = message_set_extensions_pb2.TestMessageSet()
3146+
parsed_msg.ParseFromString(msg.SerializeToString())
3147+
# message cannot be parsed when it exceeds the max recursion depth
3148+
sub.Extensions[ext].i = 123
3149+
with self.assertRaises(message.DecodeError) as context:
3150+
msg.ParseFromString(msg.SerializeToString())
3151+
self.assertIn('Error parsing message', str(context.exception))
3152+
31073153

31083154
if __name__ == '__main__':
31093155
unittest.main()

python/google/protobuf/internal/more_messages.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ message ExtendClass {
6767
}
6868
}
6969

70+
message TestRecursiveMapMessage {
71+
optional TestRecursiveMapMessage a = 1;
72+
optional int32 i = 2;
73+
map<int32, TestRecursiveMapMessage> map_field = 3;
74+
}
75+
7076
message TestFullKeyword {
7177
optional google.protobuf.internal.OutOfOrderFields field1 = 1;
7278
optional google.protobuf.internal.class field2 = 2;

python/google/protobuf/internal/python_message.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1241,7 +1241,9 @@ def InternalParse(self, buffer, pos, end, current_depth=0):
12411241
tag_bytes, (None, None)
12421242
)
12431243
if field_decoder:
1244-
pos = field_decoder(buffer, new_pos, end, self, field_dict)
1244+
pos = field_decoder(
1245+
buffer, new_pos, end, self, field_dict, current_depth
1246+
)
12451247
continue
12461248
field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
12471249
if field_des is None:

upb/wire/decode.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,9 +648,13 @@ static void upb_Decoder_AddKnownMessageSetItem(
648648
upb_Message* submsg = _upb_Decoder_NewSubMessage2(
649649
d, ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(submsg),
650650
&ext->ext->UPB_PRIVATE(field), submsgp);
651+
// upb_Decode_LimitDepth() takes a uint32_t, so d->depth - 1 must not be negative.
652+
if (d->depth <= 1) {
653+
upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_MaxDepthExceeded);
654+
}
651655
upb_DecodeStatus status = upb_Decode(
652656
data, size, submsg, upb_MiniTableExtension_GetSubMessage(item_mt),
653-
d->extreg, d->options, &d->arena);
657+
d->extreg, upb_Decode_LimitDepth(d->options, d->depth - 1), &d->arena);
654658
if (status != kUpb_DecodeStatus_Ok) {
655659
upb_ErrorHandler_ThrowError(&d->err, status);
656660
}

0 commit comments

Comments
 (0)