diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go index fc4a7b1dbf..159297ef0a 100644 --- a/bson/bsoncodec/default_value_decoders.go +++ b/bson/bsoncodec/default_value_decoders.go @@ -1521,6 +1521,12 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } + if vr.Type() == bsontype.Null { + val.Set(reflect.Zero(val.Type())) + + return vr.ReadNull() + } + if val.Kind() == reflect.Ptr && val.IsNil() { if !val.CanSet() { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} diff --git a/bson/unmarshaling_cases_test.go b/bson/unmarshaling_cases_test.go index 37d9ded318..dd38369bff 100644 --- a/bson/unmarshaling_cases_test.go +++ b/bson/unmarshaling_cases_test.go @@ -199,6 +199,25 @@ type unmarshalerNonPtrStruct struct { type myInt64 int64 +var _ ValueUnmarshaler = (*myInt64)(nil) + +func (mi *myInt64) UnmarshalBSONValue(t bsontype.Type, bytes []byte) error { + if len(bytes) == 0 { + return nil + } + + if t == bsontype.Int64 { + i, err := bsonrw.NewBSONValueReader(bsontype.Int64, bytes).ReadInt64() + if err != nil { + return err + } + + *mi = myInt64(i) + } + + return nil +} + func (mi *myInt64) UnmarshalBSON(bytes []byte) error { if len(bytes) == 0 { return nil