Skip to content

Commit

Permalink
Improve the behavior of null decoding (#681).
Browse files Browse the repository at this point in the history
A null value in v3 is considered a request to maintain the default
untouched. If the value being decoded into is a map and there's no
prior value for the field, a new key will be added with the zero
map value type as its value.
  • Loading branch information
niemeyer committed Jan 7, 2021
1 parent c476de3 commit c8046fb
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 18 deletions.
6 changes: 4 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ func (d *decoder) callObsoleteUnmarshaler(n *Node, u obsoleteUnmarshaler) (good
//
// If n holds a null value, prepare returns before doing anything.
func (d *decoder) prepare(n *Node, out reflect.Value) (newout reflect.Value, unmarshaled, good bool) {
if n.ShortTag() == nullTag || n.Kind == 0 && n.IsZero() {
if n.ShortTag() == nullTag {
return out, false, false
}
again := true
Expand Down Expand Up @@ -808,8 +808,10 @@ func (d *decoder) mapping(n *Node, out reflect.Value) (good bool) {
}
}

mapIsNew := false
if out.IsNil() {
out.Set(reflect.MakeMap(outt))
mapIsNew = true
}
for i := 0; i < l; i += 2 {
if isMerge(n.Content[i]) {
Expand All @@ -826,7 +828,7 @@ func (d *decoder) mapping(n *Node, out reflect.Value) (good bool) {
failf("invalid map key: %#v", k.Interface())
}
e := reflect.New(et).Elem()
if d.unmarshal(n.Content[i+1], e) {
if d.unmarshal(n.Content[i+1], e) || n.Content[i+1].ShortTag() == nullTag && (mapIsNew || !out.MapIndex(k).IsValid()) {
out.SetMapIndex(k, e)
}
}
Expand Down
54 changes: 38 additions & 16 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ var unmarshalTests = []struct {
map[string]*string{"foo": nil},
}, {
"foo: null",
map[string]string{},
map[string]string{"foo": ""},
}, {
"foo: null",
map[string]interface{}{"foo": nil},
Expand All @@ -517,7 +517,7 @@ var unmarshalTests = []struct {
map[string]*string{"foo": nil},
}, {
"foo: ~",
map[string]string{},
map[string]string{"foo": ""},
}, {
"foo: ~",
map[string]interface{}{"foo": nil},
Expand Down Expand Up @@ -1436,29 +1436,51 @@ func (s *S) TestMergeStruct(c *C) {
}
}

var unmarshalNullTests = []func() interface{}{
var unmarshalNullTests = []struct{ input string; pristine, expected func() interface{} }{{
"null",
func() interface{} { var v interface{}; v = "v"; return &v },
func() interface{} { var v interface{}; v = nil; return &v },
}, {
"null",
func() interface{} { var s = "s"; return &s },
func() interface{} { var s = "s"; return &s },
}, {
"null",
func() interface{} { var s = "s"; sptr := &s; return &sptr },
func() interface{} { var sptr *string; return &sptr },
}, {
"null",
func() interface{} { var i = 1; return &i },
func() interface{} { var i = 1; return &i },
}, {
"null",
func() interface{} { var i = 1; iptr := &i; return &iptr },
func() interface{} { m := map[string]int{"s": 1}; return &m },
func() interface{} { m := map[string]int{"s": 1}; return m },
}
func() interface{} { var iptr *int; return &iptr },
}, {
"null",
func() interface{} { var m = map[string]int{"s": 1}; return &m },
func() interface{} { var m map[string]int; return &m },
}, {
"null",
func() interface{} { var m = map[string]int{"s": 1}; return m },
func() interface{} { var m = map[string]int{"s": 1}; return m },
}, {
"s2: null\ns3: null",
func() interface{} { var m = map[string]int{"s1": 1, "s2": 2}; return m },
func() interface{} { var m = map[string]int{"s1": 1, "s2": 2, "s3": 0}; return m },
}, {
"s2: null\ns3: null",
func() interface{} { var m = map[string]interface{}{"s1": 1, "s2": 2}; return m },
func() interface{} { var m = map[string]interface{}{"s1": 1, "s2": nil, "s3": nil}; return m },
}}

func (s *S) TestUnmarshalNull(c *C) {
for _, test := range unmarshalNullTests {
pristine := test()
decoded := test()
zero := reflect.Zero(reflect.TypeOf(decoded).Elem()).Interface()
err := yaml.Unmarshal([]byte("null"), decoded)
pristine := test.pristine()
expected := test.expected()
err := yaml.Unmarshal([]byte(test.input), pristine)
c.Assert(err, IsNil)
switch pristine.(type) {
case *interface{}, **string, **int, *map[string]int:
c.Assert(reflect.ValueOf(decoded).Elem().Interface(), DeepEquals, zero)
default:
c.Assert(reflect.ValueOf(decoded).Interface(), DeepEquals, pristine)
}
c.Assert(pristine, DeepEquals, expected)
}
}

Expand Down
3 changes: 3 additions & 0 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2688,6 +2688,9 @@ func (s *S) TestNodeZeroEncodeDecode(c *C) {
c.Assert(n.Decode(&v), IsNil)
c.Assert(v, IsNil)

// ... and even when looking for its tag.
c.Assert(n.ShortTag(), Equals, "!!null")

// Kind zero is still unknown, though.
n.Line = 1
_, err = yaml.Marshal(&n)
Expand Down
5 changes: 5 additions & 0 deletions yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,11 @@ func (n *Node) ShortTag() string {
case ScalarNode:
tag, _ := resolve("", n.Value)
return tag
case 0:
// Special case to make the zero value convenient.
if n.IsZero() {
return nullTag
}
}
return ""
}
Expand Down

0 comments on commit c8046fb

Please sign in to comment.