From 71de844febff6ac5b6fc6ba165e24e9677d2ee05 Mon Sep 17 00:00:00 2001 From: Kris Brandow Date: Wed, 5 Sep 2018 16:18:39 -0400 Subject: [PATCH] Split bson package and update mongo package GODRIVER-537 GODRIVER-551 GODRIVER-527 GODRIVER-566 GODRIVER-494 Change-Id: I55bdfe7f48ded7fd9923d906d3cc6ecaec8156da --- benchmark/bson_map.go | 19 +- benchmark/bson_struct.go | 29 +- benchmark/multi.go | 4 +- bson/bson_test.go | 66 - bson/{ => bsoncodec}/benchmark_test.go | 52 +- bson/bsoncodec/bsoncodec.go | 146 + bson/bsoncodec/bsoncodec_test.go | 77 + bson/bsoncodec/codec.go | 114 + bson/bsoncodec/codec_test.go | 111 + bson/bsoncodec/copy.go | 369 ++ bson/bsoncodec/copy_test.go | 465 ++ bson/bsoncodec/decoder.go | 77 + bson/{ => bsoncodec}/decoder_test.go | 78 +- bson/bsoncodec/default_value_decoders.go | 1575 +++++++ bson/bsoncodec/default_value_decoders_test.go | 2690 ++++++++++++ bson/bsoncodec/default_value_encoders.go | 822 ++++ bson/bsoncodec/default_value_encoders_test.go | 1675 +++++++ bson/{ => bsoncodec}/document_value_reader.go | 73 +- .../document_value_reader_test.go | 93 +- bson/{ => bsoncodec}/document_value_writer.go | 65 +- .../document_value_writer_test.go | 89 +- bson/{ => bsoncodec}/encoder.go | 34 +- bson/bsoncodec/encoder_test.go | 96 + bson/{ => bsoncodec}/extjson_reader.go | 5 +- bson/{ => bsoncodec}/extjson_writer.go | 2 +- bson/{ => bsoncodec}/internal_reader.go | 29 +- .../llvalue_reader_writer_test.go | 17 +- bson/bsoncodec/marshal.go | 92 + bson/bsoncodec/marshal_test.go | 125 + bson/{ => bsoncodec}/marshaling_cases_test.go | 6 +- bson/{ => bsoncodec}/mode.go | 8 +- bson/bsoncodec/registry.go | 451 ++ bson/{ => bsoncodec}/registry_test.go | 167 +- bson/{ => bsoncodec}/struct_codec.go | 77 +- bson/bsoncodec/struct_codec_test.go | 32 + bson/{ => bsoncodec}/struct_tag_parser.go | 2 +- .../{ => bsoncodec}/struct_tag_parser_test.go | 2 +- bson/bsoncodec/types.go | 57 + bson/bsoncodec/unmarshal.go | 47 + bson/bsoncodec/unmarshal_test.go | 43 + .../unmarshaling_cases_test.go | 12 +- bson/{ => bsoncodec}/value_reader.go | 222 +- bson/{ => bsoncodec}/value_reader_test.go | 400 +- bson/{ => bsoncodec}/value_writer.go | 53 +- bson/{ => bsoncodec}/value_writer_test.go | 58 +- bson/{ => bsoncodec}/writer.go | 10 +- bson/codec.go | 2055 --------- bson/codec_test.go | 3402 --------------- bson/constructor.go | 56 +- bson/decode.go | 875 ---- bson/decode_test.go | 3862 ----------------- bson/decoder.go | 53 - bson/element.go | 13 +- bson/element_test.go | 2 +- bson/empty_interface_codec.go | 144 - bson/empty_interface_codec_test.go | 310 -- bson/encode.go | 780 ---- bson/encode_test.go | 1114 ----- bson/encoder_test.go | 25 - bson/internal/llbson/llbson.go | 67 +- bson/map_codec.go | 152 - bson/marshal.go | 180 - bson/marshal_test.go | 156 - bson/objectid/objectid.go | 6 + bson/reader.go | 11 + bson/reader_iterator.go | 1 + bson/registry.go | 254 -- bson/slice_codec.go | 182 - bson/unmarshal.go | 68 - bson/value.go | 27 +- bson/value_read_writer_copy.go | 239 - core/auth/mongodbcr.go | 3 +- core/auth/sasl.go | 5 +- core/command/abort_transaction.go | 3 +- core/command/buildinfo.go | 3 +- core/command/commit_transaction.go | 3 +- core/command/count_documents.go | 4 +- core/command/create_indexes.go | 3 +- core/command/delete.go | 3 +- core/command/distinct.go | 3 +- core/command/end_sessions.go | 3 +- core/command/getlasterror.go | 3 +- core/command/insert.go | 3 +- core/command/ismaster.go | 3 +- core/command/kill_cursors.go | 3 +- core/command/list_databases.go | 3 +- core/command/start_session.go | 3 +- core/command/update.go | 3 +- core/integration/aggregate_test.go | 14 +- core/integration/cursor_test.go | 7 +- core/integration/internal/israce/norace.go | 7 + core/integration/internal/israce/race.go | 7 + core/integration/list_collections_test.go | 4 +- core/integration/list_indexes_test.go | 12 +- core/option/options.go | 145 +- core/topology/cursor.go | 7 +- core/topology/server_options.go | 14 + core/topology/topology.go | 3 + mongo/change_stream.go | 6 +- mongo/change_stream_test.go | 12 +- mongo/client.go | 12 +- mongo/client_internal_test.go | 4 +- mongo/client_options_test.go | 3 +- mongo/clientopt/clientopt.go | 22 +- mongo/collection.go | 53 +- mongo/collection_internal_test.go | 46 +- mongo/crud_util_test.go | 8 +- mongo/database.go | 5 +- mongo/document_result.go | 4 +- mongo/findopt/findopt.go | 8 +- mongo/findopt/findopt_deleteone_test.go | 32 +- mongo/gridfs/gridfs_test.go | 23 +- mongo/index_view_internal_test.go | 13 +- mongo/mongo.go | 114 +- mongo/mongo_test.go | 26 +- mongo/read_write_concern_spec_test.go | 8 +- mongo/results.go | 4 +- mongo/results_test.go | 11 +- mongo/retryable_writes_test.go | 1 + mongo/transactions_test.go | 2 + mongo/updateopt/updateopt.go | 2 +- 121 files changed, 10585 insertions(+), 14878 deletions(-) rename bson/{ => bsoncodec}/benchmark_test.go (78%) create mode 100644 bson/bsoncodec/bsoncodec.go create mode 100644 bson/bsoncodec/bsoncodec_test.go create mode 100644 bson/bsoncodec/codec.go create mode 100644 bson/bsoncodec/codec_test.go create mode 100644 bson/bsoncodec/copy.go create mode 100644 bson/bsoncodec/copy_test.go create mode 100644 bson/bsoncodec/decoder.go rename bson/{ => bsoncodec}/decoder_test.go (59%) create mode 100644 bson/bsoncodec/default_value_decoders.go create mode 100644 bson/bsoncodec/default_value_decoders_test.go create mode 100644 bson/bsoncodec/default_value_encoders.go create mode 100644 bson/bsoncodec/default_value_encoders_test.go rename bson/{ => bsoncodec}/document_value_reader.go (78%) rename bson/{ => bsoncodec}/document_value_reader_test.go (84%) rename bson/{ => bsoncodec}/document_value_writer.go (77%) rename bson/{ => bsoncodec}/document_value_writer_test.go (75%) rename bson/{ => bsoncodec}/encoder.go (55%) create mode 100644 bson/bsoncodec/encoder_test.go rename bson/{ => bsoncodec}/extjson_reader.go (95%) rename bson/{ => bsoncodec}/extjson_writer.go (99%) rename bson/{ => bsoncodec}/internal_reader.go (77%) rename bson/{ => bsoncodec}/llvalue_reader_writer_test.go (97%) create mode 100644 bson/bsoncodec/marshal.go create mode 100644 bson/bsoncodec/marshal_test.go rename bson/{ => bsoncodec}/marshaling_cases_test.go (62%) rename bson/{ => bsoncodec}/mode.go (81%) create mode 100644 bson/bsoncodec/registry.go rename bson/{ => bsoncodec}/registry_test.go (57%) rename bson/{ => bsoncodec}/struct_codec.go (82%) create mode 100644 bson/bsoncodec/struct_codec_test.go rename bson/{ => bsoncodec}/struct_tag_parser.go (99%) rename bson/{ => bsoncodec}/struct_tag_parser_test.go (99%) create mode 100644 bson/bsoncodec/types.go create mode 100644 bson/bsoncodec/unmarshal.go create mode 100644 bson/bsoncodec/unmarshal_test.go rename bson/{ => bsoncodec}/unmarshaling_cases_test.go (68%) rename bson/{ => bsoncodec}/value_reader.go (73%) rename bson/{ => bsoncodec}/value_reader_test.go (73%) rename bson/{ => bsoncodec}/value_writer.go (95%) rename bson/{ => bsoncodec}/value_writer_test.go (84%) rename bson/{ => bsoncodec}/writer.go (85%) delete mode 100644 bson/codec.go delete mode 100644 bson/codec_test.go delete mode 100644 bson/decoder.go delete mode 100644 bson/empty_interface_codec.go delete mode 100644 bson/empty_interface_codec_test.go delete mode 100644 bson/encoder_test.go delete mode 100644 bson/map_codec.go delete mode 100644 bson/registry.go delete mode 100644 bson/slice_codec.go delete mode 100644 bson/unmarshal.go delete mode 100644 bson/value_read_writer_copy.go create mode 100644 core/integration/internal/israce/norace.go create mode 100644 core/integration/internal/israce/race.go diff --git a/benchmark/bson_map.go b/benchmark/bson_map.go index 940df53200..80569da892 100644 --- a/benchmark/bson_map.go +++ b/benchmark/bson_map.go @@ -1,12 +1,11 @@ package benchmark import ( - "bytes" "context" "errors" "fmt" - "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" ) func bsonMapDecoding(ctx context.Context, tm TimerManager, iters int, dataSet string) error { @@ -19,10 +18,9 @@ func bsonMapDecoding(ctx context.Context, tm TimerManager, iters int, dataSet st for i := 0; i < iters; i++ { out := make(map[string]interface{}) - dec := bson.NewDecoder(bytes.NewReader(r)) - err := dec.Decode(out) + err := bsoncodec.Unmarshal(r, &out) if err != nil { - return err + return nil } if len(out) == 0 { return fmt.Errorf("decoding failed") @@ -38,19 +36,20 @@ func bsonMapEncoding(ctx context.Context, tm TimerManager, iters int, dataSet st } doc := make(map[string]interface{}) - dec := bson.NewDecoder(bytes.NewReader(r)) - if err = dec.Decode(doc); err != nil { + err = bsoncodec.Unmarshal(r, &doc) + if err != nil { return err } - buf := bytes.NewBuffer([]byte{}) + var buf []byte tm.ResetTimer() for i := 0; i < iters; i++ { - if err = bson.NewEncoder(buf).Encode(&doc); err != nil { + buf, err = bsoncodec.MarshalAppend(buf[:0], doc) + if err != nil { return nil } - if buf.Len() == 0 { + if len(buf) == 0 { return errors.New("encoding failed") } } diff --git a/benchmark/bson_struct.go b/benchmark/bson_struct.go index b2b2c894c0..0ed08924cb 100644 --- a/benchmark/bson_struct.go +++ b/benchmark/bson_struct.go @@ -1,11 +1,10 @@ package benchmark import ( - "bytes" "context" "errors" - "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" ) func BSONFlatStructDecoding(ctx context.Context, tm TimerManager, iters int) error { @@ -18,8 +17,7 @@ func BSONFlatStructDecoding(ctx context.Context, tm TimerManager, iters int) err for i := 0; i < iters; i++ { out := flatBSON{} - dec := bson.NewDecoder(bytes.NewReader(r)) - err := dec.Decode(&out) + err := bsoncodec.Unmarshal(r, &out) if err != nil { return err } @@ -34,18 +32,20 @@ func BSONFlatStructEncoding(ctx context.Context, tm TimerManager, iters int) err } doc := flatBSON{} - if err = bson.NewDecoder(bytes.NewReader(r)).Decode(&doc); err != nil { + err = bsoncodec.Unmarshal(r, &doc) + if err != nil { return err } - buf := bytes.NewBuffer([]byte{}) + var buf []byte tm.ResetTimer() for i := 0; i < iters; i++ { - if err = bson.NewEncoder(buf).Encode(&doc); err != nil { + buf, err = bsoncodec.Marshal(doc) + if err != nil { return err } - if buf.Len() == 0 { + if len(buf) == 0 { return errors.New("encoding failed") } } @@ -59,18 +59,20 @@ func BSONFlatStructTagsEncoding(ctx context.Context, tm TimerManager, iters int) } doc := flatBSONTags{} - if err = bson.NewDecoder(bytes.NewReader(r)).Decode(&doc); err != nil { + err = bsoncodec.Unmarshal(r, &doc) + if err != nil { return err } - buf := bytes.NewBuffer([]byte{}) + var buf []byte tm.ResetTimer() for i := 0; i < iters; i++ { - if err = bson.NewEncoder(buf).Encode(&doc); err != nil { + buf, err = bsoncodec.MarshalAppend(buf[:0], doc) + if err != nil { return err } - if buf.Len() == 0 { + if len(buf) == 0 { return errors.New("encoding failed") } } @@ -86,8 +88,7 @@ func BSONFlatStructTagsDecoding(ctx context.Context, tm TimerManager, iters int) tm.ResetTimer() for i := 0; i < iters; i++ { out := flatBSONTags{} - dec := bson.NewDecoder(bytes.NewReader(r)) - err := dec.Decode(&out) + err := bsoncodec.Unmarshal(r, &out) if err != nil { return err } diff --git a/benchmark/multi.go b/benchmark/multi.go index aa2fb241ea..4e42314418 100644 --- a/benchmark/multi.go +++ b/benchmark/multi.go @@ -31,7 +31,7 @@ func MultiFindMany(ctx context.Context, tm TimerManager, iters int) error { payload := make([]interface{}, iters) for idx := range payload { - payload[idx] = *doc + payload[idx] = doc } if _, err = coll.InsertMany(ctx, payload); err != nil { @@ -109,7 +109,7 @@ func multiInsertCase(ctx context.Context, tm TimerManager, iters int, data strin payload := make([]interface{}, iters) for idx := range payload { - payload[idx] = *doc + payload[idx] = doc } coll := db.Collection("corpus") diff --git a/bson/bson_test.go b/bson/bson_test.go index 915e44170e..dd57f00d9d 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -7,14 +7,9 @@ package bson import ( - "bytes" "encoding/binary" "math" - "reflect" "testing" - "time" - - "github.com/stretchr/testify/assert" ) func TestValue(t *testing.T) { @@ -107,64 +102,3 @@ func TestValue(t *testing.T) { }) t.Run("document", func(t *testing.T) {}) } - -func TestTimeRoundTrip(t *testing.T) { - val := struct { - Value time.Time - ID string - }{ - ID: "time-rt-test", - } - - assert.True(t, val.Value.IsZero()) - - bsonOut, err := Marshal(val) - assert.NoError(t, err) - rtval := struct { - Value time.Time - ID string - }{} - - err = Unmarshal(bsonOut, &rtval) - assert.NoError(t, err) - assert.Equal(t, val, rtval) - assert.True(t, rtval.Value.IsZero()) - -} - -func TestBasicEncode(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - got := make(writer, 0, 1024) - vw := newValueWriter(&got) - reg := NewRegistryBuilder().Build() - codec, err := reg.Lookup(reflect.TypeOf(tc.val)) - noerr(t, err) - err = codec.EncodeValue(EncodeContext{Registry: reg}, vw, tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", Reader(got), Reader(tc.want)) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - -func TestBasicDecode(t *testing.T) { - for _, tc := range unmarshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - got := reflect.New(tc.sType).Interface() - vr := newValueReader(tc.data) - reg := NewRegistryBuilder().Build() - codec, err := reg.Lookup(reflect.TypeOf(got)) - noerr(t, err) - err = codec.DecodeValue(DecodeContext{Registry: reg}, vr, got) - noerr(t, err) - - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("Results do not match. got %+v; want %+v", got, tc.want) - } - }) - } -} diff --git a/bson/benchmark_test.go b/bson/bsoncodec/benchmark_test.go similarity index 78% rename from bson/benchmark_test.go rename to bson/bsoncodec/benchmark_test.go index 92e6aad8ba..dcf70bae31 100644 --- a/bson/benchmark_test.go +++ b/bson/bsoncodec/benchmark_test.go @@ -1,6 +1,10 @@ -package bson +package bsoncodec -import "testing" +import ( + "testing" + + "github.com/mongodb/mongo-go-driver/bson" +) type encodetest struct { Field1String string @@ -113,54 +117,42 @@ var nestedInstance = nestedtest1{ }, } -func BenchmarkEncodingv1(b *testing.B) { - for i := 0; i < b.N; i++ { - _, _ = Marshal(encodetestInstance) - } -} - func BenchmarkEncodingv2(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = Marshalv2(encodetestInstance) + _, _ = Marshal(encodetestInstance) } } func BenchmarkEncodingv2ToDocument(b *testing.B) { var buf []byte for i := 0; i < b.N; i++ { - buf, _ = Marshalv2(encodetestInstance) - _, _ = ReadDocument(buf) + buf, _ = Marshal(encodetestInstance) + _, _ = bson.ReadDocument(buf) } } -func BenchmarkEncodingDocument(b *testing.B) { - for i := 0; i < b.N; i++ { - _, _ = MarshalDocument(encodetestInstance) - } -} - -func BenchmarkEncodingv1Nested(b *testing.B) { - for i := 0; i < b.N; i++ { - _, _ = Marshal(nestedInstance) - } -} +// func BenchmarkEncodingDocument(b *testing.B) { +// for i := 0; i < b.N; i++ { +// _, _ = MarshalDocument(encodetestInstance) +// } +// } func BenchmarkEncodingv2Nested(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = Marshalv2(nestedInstance) + _, _ = Marshal(nestedInstance) } } func BenchmarkEncodingv2ToDocumentNested(b *testing.B) { var buf []byte for i := 0; i < b.N; i++ { - buf, _ = Marshalv2(nestedInstance) - _, _ = ReadDocument(buf) + buf, _ = Marshal(nestedInstance) + _, _ = bson.ReadDocument(buf) } } -func BenchmarkEncodingDocumentNested(b *testing.B) { - for i := 0; i < b.N; i++ { - _, _ = MarshalDocument(nestedInstance) - } -} +// func BenchmarkEncodingDocumentNested(b *testing.B) { +// for i := 0; i < b.N; i++ { +// _, _ = MarshalDocument(nestedInstance) +// } +// } diff --git a/bson/bsoncodec/bsoncodec.go b/bson/bsoncodec/bsoncodec.go new file mode 100644 index 0000000000..a5619d07c4 --- /dev/null +++ b/bson/bsoncodec/bsoncodec.go @@ -0,0 +1,146 @@ +package bsoncodec + +import ( + "errors" + "fmt" + "math" + + "github.com/mongodb/mongo-go-driver/bson" +) + +// ConstructElement will attempt to turn the provided key and value into an Element. +// For common types, type casting is used, if the type is more complex, such as +// a map or struct, reflection is used. If the value cannot be converted either +// by typecasting or through reflection, a null Element is constructed with the +// key. This method will never return a nil *Element. If an error turning the +// value into an Element is desired, use the InterfaceErr method. +func ConstructElement(key string, value interface{}) *bson.Element { + var elem *bson.Element + switch t := value.(type) { + case bool: + elem = bson.EC.Boolean(key, t) + case int8: + elem = bson.EC.Int32(key, int32(t)) + case int16: + elem = bson.EC.Int32(key, int32(t)) + case int32: + elem = bson.EC.Int32(key, int32(t)) + case int: + if t < math.MaxInt32 { + elem = bson.EC.Int32(key, int32(t)) + } + elem = bson.EC.Int64(key, int64(t)) + case int64: + if t < math.MaxInt32 { + elem = bson.EC.Int32(key, int32(t)) + } + elem = bson.EC.Int64(key, int64(t)) + case uint8: + elem = bson.EC.Int32(key, int32(t)) + case uint16: + elem = bson.EC.Int32(key, int32(t)) + case uint: + switch { + case t < math.MaxInt32: + elem = bson.EC.Int32(key, int32(t)) + case uint64(t) > math.MaxInt64: + elem = bson.EC.Null(key) + default: + elem = bson.EC.Int64(key, int64(t)) + } + case uint32: + if t < math.MaxInt32 { + elem = bson.EC.Int32(key, int32(t)) + } + elem = bson.EC.Int64(key, int64(t)) + case uint64: + switch { + case t < math.MaxInt32: + elem = bson.EC.Int32(key, int32(t)) + case t > math.MaxInt64: + elem = bson.EC.Null(key) + default: + elem = bson.EC.Int64(key, int64(t)) + } + case float32: + elem = bson.EC.Double(key, float64(t)) + case float64: + elem = bson.EC.Double(key, t) + case string: + elem = bson.EC.String(key, t) + case *bson.Element: + elem = t + case *bson.Document: + elem = bson.EC.SubDocument(key, t) + case bson.Reader: + elem = bson.EC.SubDocumentFromReader(key, t) + case *bson.Value: + elem = bson.EC.FromValue(key, t) + if elem == nil { + elem = bson.EC.Null(key) + } + default: + // TODO(skriptble): Allow users to provide registry + // TODO(skriptble): Use a pool of []byte + buf, err := marshalElement(defaultRegistry, nil, key, t) + if err != nil { + elem = bson.EC.Null(key) + } + elem, err = bson.EC.FromBytesErr(buf) + if err != nil { + elem = bson.EC.Null(key) + } + } + + return elem +} + +// ConstructElementErr does the same thing as ConstructElement but returns an +// error instead of returning a BSON Null element. +func ConstructElementErr(key string, value interface{}) (*bson.Element, error) { + var elem *bson.Element + var err error + switch t := value.(type) { + case bool, int8, int16, int32, int, int64, uint8, uint16, + uint32, float32, float64, string, *bson.Element, *bson.Document, bson.Reader: + elem = ConstructElement(key, value) + case uint: + switch { + case t < math.MaxInt32: + elem = bson.EC.Int32(key, int32(t)) + case uint64(t) > math.MaxInt64: + err = fmt.Errorf("BSON only has signed integer types and %d overflows an int64", t) + default: + elem = bson.EC.Int64(key, int64(t)) + } + case uint64: + switch { + case t < math.MaxInt32: + elem = bson.EC.Int32(key, int32(t)) + case uint64(t) > math.MaxInt64: + err = fmt.Errorf("BSON only has signed integer types and %d overflows an int64", t) + default: + elem = bson.EC.Int64(key, int64(t)) + } + case *bson.Value: + elem = bson.EC.FromValue(key, t) + if elem == nil { + err = errors.New("invalid *Value provided, cannot convert to *Element") + } + default: + // TODO(skriptble): Allow users to provide registry + // TODO(skriptble): Use a pool of []byte + var buf []byte + buf, err = marshalElement(defaultRegistry, nil, key, t) + if err != nil { + break + } + elem, err = bson.EC.FromBytesErr(buf) + } + + if err != nil { + return nil, err + } + + return elem, nil +} diff --git a/bson/bsoncodec/bsoncodec_test.go b/bson/bsoncodec/bsoncodec_test.go new file mode 100644 index 0000000000..df6d2bf7cd --- /dev/null +++ b/bson/bsoncodec/bsoncodec_test.go @@ -0,0 +1,77 @@ +package bsoncodec + +import ( + "bytes" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/mongodb/mongo-go-driver/bson" +) + +func TestBasicEncode(t *testing.T) { + for _, tc := range marshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + got := make(writer, 0, 1024) + vw := newValueWriter(&got) + reg := NewRegistryBuilder().Build() + encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val)) + noerr(t, err) + err = encoder.EncodeValue(EncodeContext{Registry: reg}, vw, tc.val) + noerr(t, err) + + if !bytes.Equal(got, tc.want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(tc.want)) + t.Errorf("Bytes:\n%v\n%v", got, tc.want) + } + }) + } +} + +func TestBasicDecode(t *testing.T) { + for _, tc := range unmarshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + got := reflect.New(tc.sType).Interface() + vr := newValueReader(tc.data) + reg := NewRegistryBuilder().Build() + decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) + noerr(t, err) + err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) + noerr(t, err) + + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("Results do not match. got %+v; want %+v", got, tc.want) + } + }) + } +} + +func TestTimeRoundTrip(t *testing.T) { + val := struct { + Value time.Time + ID string + }{ + ID: "time-rt-test", + } + + if !val.Value.IsZero() { + t.Errorf("Did not get zero time as expected.") + } + + bsonOut, err := Marshal(val) + noerr(t, err) + rtval := struct { + Value time.Time + ID string + }{} + + err = Unmarshal(bsonOut, &rtval) + noerr(t, err) + if !cmp.Equal(val, rtval) { + t.Errorf("Did not round trip properly. got %v; want %v", val, rtval) + } + if !rtval.Value.IsZero() { + t.Errorf("Did not get zero time as expected.") + } +} diff --git a/bson/bsoncodec/codec.go b/bson/bsoncodec/codec.go new file mode 100644 index 0000000000..f34f1b9d5f --- /dev/null +++ b/bson/bsoncodec/codec.go @@ -0,0 +1,114 @@ +package bsoncodec + +import ( + "fmt" + "reflect" + "strings" +) + +var ptBool = reflect.TypeOf((*bool)(nil)) +var ptInt8 = reflect.TypeOf((*int8)(nil)) +var ptInt16 = reflect.TypeOf((*int16)(nil)) +var ptInt32 = reflect.TypeOf((*int32)(nil)) +var ptInt64 = reflect.TypeOf((*int64)(nil)) +var ptInt = reflect.TypeOf((*int)(nil)) +var ptUint8 = reflect.TypeOf((*uint8)(nil)) +var ptUint16 = reflect.TypeOf((*uint16)(nil)) +var ptUint32 = reflect.TypeOf((*uint32)(nil)) +var ptUint64 = reflect.TypeOf((*uint64)(nil)) +var ptUint = reflect.TypeOf((*uint)(nil)) +var ptFloat32 = reflect.TypeOf((*float32)(nil)) +var ptFloat64 = reflect.TypeOf((*float64)(nil)) +var ptString = reflect.TypeOf((*string)(nil)) + +// ValueEncoderError is an error returned from a ValueEncoder when the provided +// value can't be encoded by the ValueEncoder. +type ValueEncoderError struct { + Name string + Types []interface{} + Received interface{} +} + +func (vee ValueEncoderError) Error() string { + types := make([]string, 0, len(vee.Types)) + for _, t := range vee.Types { + types = append(types, fmt.Sprintf("%T", t)) + } + return fmt.Sprintf("%s can only process %s, but got a %T", vee.Name, strings.Join(types, ", "), vee.Received) +} + +// ValueDecoderError is an error returned from a ValueDecoder when the provided +// value can't be decoded by the ValueDecoder. +type ValueDecoderError struct { + Name string + Types []interface{} + Received interface{} +} + +func (vde ValueDecoderError) Error() string { + types := make([]string, 0, len(vde.Types)) + for _, t := range vde.Types { + types = append(types, fmt.Sprintf("%T", t)) + } + return fmt.Sprintf("%s can only process %s, but got a %T", vde.Name, strings.Join(types, ", "), vde.Received) +} + +// EncodeContext is the contextual information required for a Codec to encode a +// value. +type EncodeContext struct { + *Registry + MinSize bool +} + +// DecodeContext is the contextual information required for a Codec to decode a +// value. +type DecodeContext struct { + *Registry + Truncate bool +} + +// ValueCodec is the interface that groups the methods to encode and decode +// values. +type ValueCodec interface { + ValueEncoder + ValueDecoder +} + +// ValueEncoder is the interface implemented by types that can handle the +// encoding of a value. Implementations must handle both values and +// pointers to values. +type ValueEncoder interface { + EncodeValue(EncodeContext, ValueWriter, interface{}) error +} + +// ValueEncoderFunc is an adapter function that allows a function with the +// correct signature to be used as a ValueEncoder. +type ValueEncoderFunc func(EncodeContext, ValueWriter, interface{}) error + +// EncodeValue implements the ValueEncoder interface. +func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val interface{}) error { + return fn(ec, vw, val) +} + +// ValueDecoder is the interface implemented by types that can handle the +// decoding of a value. Implementations must handle pointers to values, +// including pointers to pointer values. The implementation may create a new +// value and assign it to the pointer if necessary. +type ValueDecoder interface { + DecodeValue(DecodeContext, ValueReader, interface{}) error +} + +// ValueDecoderFunc is an adapter function that allows a function with the +// correct signature to be used as a ValueDecoder. +type ValueDecoderFunc func(DecodeContext, ValueReader, interface{}) error + +// DecodeValue implements the ValueDecoder interface. +func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr ValueReader, val interface{}) error { + return fn(dc, vr, val) +} + +// CodecZeroer is the interface implemented by Codecs that can also determine if +// a value of the type that would be encoded is zero. +type CodecZeroer interface { + IsTypeZero(interface{}) bool +} diff --git a/bson/bsoncodec/codec_test.go b/bson/bsoncodec/codec_test.go new file mode 100644 index 0000000000..3a7a0f7f8b --- /dev/null +++ b/bson/bsoncodec/codec_test.go @@ -0,0 +1,111 @@ +package bsoncodec + +import ( + "reflect" + "testing" + + "github.com/mongodb/mongo-go-driver/bson" +) + +func compareValues(v1, v2 *bson.Value) bool { return v1.Equal(v2) } +func compareElements(e1, e2 *bson.Element) bool { return e1.Equal(e2) } +func compareStrings(s1, s2 string) bool { return s1 == s2 } + +type noPrivateFields struct { + a string +} + +func compareNoPrivateFields(npf1, npf2 noPrivateFields) bool { + return npf1.a != npf2.a // We don't want these to be equal +} + +func docToBytes(d *bson.Document) []byte { + b, err := d.MarshalBSON() + if err != nil { + panic(err) + } + return b +} + +type zeroTest struct { + reportZero bool +} + +func (z zeroTest) IsZero() bool { return z.reportZero } + +func compareZeroTest(_, _ zeroTest) bool { return true } + +type nonZeroer struct { + value bool +} + +type llCodec struct { + t *testing.T + decodeval interface{} + encodeval interface{} + err error +} + +func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) error { + if llc.err != nil { + return llc.err + } + + llc.encodeval = i + return nil +} + +func (llc *llCodec) DecodeValue(_ DecodeContext, _ ValueReader, i interface{}) error { + if llc.err != nil { + return llc.err + } + + val := reflect.ValueOf(i) + if val.Type().Kind() != reflect.Ptr { + llc.t.Errorf("Value provided to DecodeValue must be a pointer, but got %T", i) + return nil + } + + switch val.Type() { + case tDocument: + decodeval, ok := llc.decodeval.(*bson.Document) + if !ok { + llc.t.Errorf("decodeval must be a *Document if the i is a *Document. decodeval %T", llc.decodeval) + return nil + } + + doc := i.(*bson.Document) + doc.Reset() + err := doc.Concat(decodeval) + if err != nil { + llc.t.Errorf("could not concatenate the decoded val to doc: %v", err) + return err + } + + return nil + case tArray: + decodeval, ok := llc.decodeval.(*bson.Array) + if !ok { + llc.t.Errorf("decodeval must be a *Array if the i is a *Array. decodeval %T", llc.decodeval) + return nil + } + + arr := i.(*bson.Array) + arr.Reset() + err := arr.Concat(decodeval) + if err != nil { + llc.t.Errorf("could not concatenate the decoded val to array: %v", err) + return err + } + + return nil + } + + if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type().Elem()) { + llc.t.Errorf("decodeval must be assignable to i provided to DecodeValue, but is not. decodeval %T; i %T", llc.decodeval, i) + return nil + } + + val.Elem().Set(reflect.ValueOf(llc.decodeval)) + return nil +} diff --git a/bson/bsoncodec/copy.go b/bson/bsoncodec/copy.go new file mode 100644 index 0000000000..60be793a23 --- /dev/null +++ b/bson/bsoncodec/copy.go @@ -0,0 +1,369 @@ +package bsoncodec + +import ( + "fmt" + + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/objectid" +) + +// Copier is a type that allows copying between ValueReaders, ValueWriters, and +// []byte values. +type Copier struct { + r *Registry +} + +// NewCopier creates a new copier with the given registry. If a nil registry is provided +// a default registry is used. +func NewCopier(r *Registry) Copier { + return Copier{r: r} +} + +func (c Copier) getRegistry() *Registry { + if c.r != nil { + return c.r + } + + return defaultRegistry +} + +// CopyDocument handles copying a document from src to dst. +func CopyDocument(dst ValueWriter, src ValueReader) error { + return Copier{}.CopyDocument(dst, src) +} + +// CopyDocument handles copying one document from the src to the dst. +func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error { + dr, err := src.ReadDocument() + if err != nil { + return err + } + + dw, err := dst.WriteDocument() + if err != nil { + return err + } + + return c.copyDocumentCore(dw, dr) +} + +// CopyDocumentFromBytes copies the values from a BSON document represented as a +// []byte to a ValueWriter. +func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error { + dw, err := dst.WriteDocument() + if err != nil { + return err + } + + itr, err := bson.Reader(src).Iterator() + if err != nil { + return err + } + + for itr.Next() { + elem := itr.Element() + dvw, err := dw.WriteDocumentElement(elem.Key()) + if err != nil { + return err + } + + val := elem.Value() + err = defaultValueEncoders.encodeValue(EncodeContext{Registry: c.getRegistry()}, dvw, val) + + if err != nil { + return err + } + } + + if err := itr.Err(); err != nil { + return err + } + + return dw.WriteDocumentEnd() +} + +// CopyDocumentToBytes copies an entire document from the ValueReader and +// returns it as bytes. +func (c Copier) CopyDocumentToBytes(src ValueReader) ([]byte, error) { + return c.AppendDocumentBytes(nil, src) +} + +// AppendDocumentBytes functions the same as CopyDocumentToBytes, but will +// append the result to dst. +func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) { + if vr, ok := src.(*valueReader); ok { + length, err := vr.peakLength() + if err != nil { + return dst, err + } + dst = append(dst, vr.d[vr.offset:vr.offset+int64(length)]...) + vr.offset += int64(length) + vr.pop() + return dst, nil + } + + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + vw.reset(dst) + + err := c.CopyDocument(vw, src) + dst = vw.buf + return dst, err +} + +// CopyValueFromBytes will write the value represtend by t and src to dst. +func (c Copier) CopyValueFromBytes(dst ValueWriter, t bson.Type, src []byte) error { + if wvb, ok := dst.(BytesWriter); ok { + return wvb.WriteValueBytes(t, src) + } + + vr := vrPool.Get().(*valueReader) + defer vrPool.Put(vr) + + vr.reset(src) + vr.pushElement(t) + + return c.CopyValue(dst, vr) +} + +// CopyValueToBytes copies a value from src and returns it as a bson.Type and a +// []byte. +func (c Copier) CopyValueToBytes(src ValueReader) (bson.Type, []byte, error) { + return c.AppendValueBytes(nil, src) +} + +// AppendValueBytes functions the same as CopyValueToBytes, but will append the +// result to dst. +func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bson.Type, []byte, error) { + if br, ok := src.(BytesReader); ok { + return br.ReadValueBytes(dst) + } + + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + start := len(dst) + + vw.reset(dst) + vw.push(mElement) + + err := c.CopyValue(vw, src) + if err != nil { + return 0, dst, err + } + + return bson.Type(vw.buf[start]), vw.buf[start+2:], nil +} + +// CopyValue will copy a single value from src to dst. +func (c Copier) CopyValue(dst ValueWriter, src ValueReader) error { + var err error + switch src.Type() { + case bson.TypeDouble: + var f64 float64 + f64, err = src.ReadDouble() + if err != nil { + break + } + err = dst.WriteDouble(f64) + case bson.TypeString: + var str string + str, err = src.ReadString() + if err != nil { + return err + } + err = dst.WriteString(str) + case bson.TypeEmbeddedDocument: + err = c.CopyDocument(dst, src) + case bson.TypeArray: + err = c.copyArray(dst, src) + case bson.TypeBinary: + var data []byte + var subtype byte + data, subtype, err = src.ReadBinary() + if err != nil { + break + } + err = dst.WriteBinaryWithSubtype(data, subtype) + case bson.TypeUndefined: + err = src.ReadUndefined() + if err != nil { + break + } + err = dst.WriteUndefined() + case bson.TypeObjectID: + var oid objectid.ObjectID + oid, err = src.ReadObjectID() + if err != nil { + break + } + err = dst.WriteObjectID(oid) + case bson.TypeBoolean: + var b bool + b, err = src.ReadBoolean() + if err != nil { + break + } + err = dst.WriteBoolean(b) + case bson.TypeDateTime: + var dt int64 + dt, err = src.ReadDateTime() + if err != nil { + break + } + err = dst.WriteDateTime(dt) + case bson.TypeNull: + err = src.ReadNull() + if err != nil { + break + } + err = dst.WriteNull() + case bson.TypeRegex: + var pattern, options string + pattern, options, err = src.ReadRegex() + if err != nil { + break + } + err = dst.WriteRegex(pattern, options) + case bson.TypeDBPointer: + var ns string + var pointer objectid.ObjectID + ns, pointer, err = src.ReadDBPointer() + if err != nil { + break + } + err = dst.WriteDBPointer(ns, pointer) + case bson.TypeJavaScript: + var js string + js, err = src.ReadJavascript() + if err != nil { + break + } + err = dst.WriteJavascript(js) + case bson.TypeSymbol: + var symbol string + symbol, err = src.ReadSymbol() + if err != nil { + break + } + err = dst.WriteSymbol(symbol) + case bson.TypeCodeWithScope: + var code string + var srcScope DocumentReader + code, srcScope, err = src.ReadCodeWithScope() + if err != nil { + break + } + + var dstScope DocumentWriter + dstScope, err = dst.WriteCodeWithScope(code) + if err != nil { + break + } + err = c.copyDocumentCore(dstScope, srcScope) + case bson.TypeInt32: + var i32 int32 + i32, err = src.ReadInt32() + if err != nil { + break + } + err = dst.WriteInt32(i32) + case bson.TypeTimestamp: + var t, i uint32 + t, i, err = src.ReadTimestamp() + if err != nil { + break + } + err = dst.WriteTimestamp(t, i) + case bson.TypeInt64: + var i64 int64 + i64, err = src.ReadInt64() + if err != nil { + break + } + err = dst.WriteInt64(i64) + case bson.TypeDecimal128: + var d128 decimal.Decimal128 + d128, err = src.ReadDecimal128() + if err != nil { + break + } + err = dst.WriteDecimal128(d128) + case bson.TypeMinKey: + err = src.ReadMinKey() + if err != nil { + break + } + err = dst.WriteMinKey() + case bson.TypeMaxKey: + err = src.ReadMaxKey() + if err != nil { + break + } + err = dst.WriteMaxKey() + default: + err = fmt.Errorf("Cannot copy unknown BSON type %s", src.Type()) + } + + return err +} + +func (c Copier) copyArray(dst ValueWriter, src ValueReader) error { + ar, err := src.ReadArray() + if err != nil { + return err + } + + aw, err := dst.WriteArray() + if err != nil { + return err + } + + for { + vr, err := ar.ReadValue() + if err == ErrEOA { + break + } + if err != nil { + return err + } + + vw, err := aw.WriteArrayElement() + if err != nil { + return err + } + + err = c.CopyValue(vw, vr) + if err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error { + for { + key, vr, err := dr.ReadElement() + if err == ErrEOD { + break + } + if err != nil { + return err + } + + vw, err := dw.WriteDocumentElement(key) + if err != nil { + return err + } + + err = c.CopyValue(vw, vr) + if err != nil { + return err + } + } + + return dw.WriteDocumentEnd() +} diff --git a/bson/bsoncodec/copy_test.go b/bson/bsoncodec/copy_test.go new file mode 100644 index 0000000000..56ecb9cf8e --- /dev/null +++ b/bson/bsoncodec/copy_test.go @@ -0,0 +1,465 @@ +package bsoncodec + +import ( + "bytes" + "errors" + "fmt" + "testing" + + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/internal/llbson" + "github.com/mongodb/mongo-go-driver/bson/objectid" +) + +func TestCopier(t *testing.T) { + t.Run("CopyDocument", func(t *testing.T) { + t.Run("ReadDocument Error", func(t *testing.T) { + want := errors.New("ReadDocumentError") + src := &llValueReaderWriter{t: t, err: want, errAfter: llvrwReadDocument} + got := Copier{}.CopyDocument(nil, src) + if !compareErrors(got, want) { + t.Errorf("Did not receive correct error. got %v; want %v", got, want) + } + }) + t.Run("WriteDocument Error", func(t *testing.T) { + want := errors.New("WriteDocumentError") + src := &llValueReaderWriter{} + dst := &llValueReaderWriter{t: t, err: want, errAfter: llvrwWriteDocument} + got := Copier{}.CopyDocument(dst, src) + if !compareErrors(got, want) { + t.Errorf("Did not receive correct error. got %v; want %v", got, want) + } + }) + t.Run("success", func(t *testing.T) { + doc := bson.NewDocument(bson.EC.String("Hello", "world")) + src := newDocumentValueReader(doc) + dst := newValueWriterFromSlice(make([]byte, 0)) + want, err := doc.MarshalBSON() + noerr(t, err) + err = Copier{}.CopyDocument(dst, src) + noerr(t, err) + got := dst.buf + if !bytes.Equal(got, want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(want)) + } + }) + }) + t.Run("copyArray", func(t *testing.T) { + t.Run("ReadArray Error", func(t *testing.T) { + want := errors.New("ReadArrayError") + src := &llValueReaderWriter{t: t, err: want, errAfter: llvrwReadArray} + got := Copier{}.copyArray(nil, src) + if !compareErrors(got, want) { + t.Errorf("Did not receive correct error. got %v; want %v", got, want) + } + }) + t.Run("WriteArray Error", func(t *testing.T) { + want := errors.New("WriteArrayError") + src := &llValueReaderWriter{} + dst := &llValueReaderWriter{t: t, err: want, errAfter: llvrwWriteArray} + got := Copier{}.copyArray(dst, src) + if !compareErrors(got, want) { + t.Errorf("Did not receive correct error. got %v; want %v", got, want) + } + }) + t.Run("success", func(t *testing.T) { + doc := bson.NewDocument(bson.EC.ArrayFromElements("foo", bson.VC.String("Hello, world!"))) + src := newDocumentValueReader(doc) + _, err := src.ReadDocument() + noerr(t, err) + _, _, err = src.ReadElement() + noerr(t, err) + + dst := newValueWriterFromSlice(make([]byte, 0)) + _, err = dst.WriteDocument() + noerr(t, err) + _, err = dst.WriteDocumentElement("foo") + noerr(t, err) + want, err := doc.MarshalBSON() + noerr(t, err) + + err = Copier{}.copyArray(dst, src) + noerr(t, err) + + err = dst.WriteDocumentEnd() + noerr(t, err) + + got := dst.buf + if !bytes.Equal(got, want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(want)) + } + }) + }) + t.Run("CopyValue", func(t *testing.T) { + testCases := []struct { + name string + dst *llValueReaderWriter + src *llValueReaderWriter + err error + }{ + { + "Double/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeDouble, err: errors.New("1"), errAfter: llvrwReadDouble}, + errors.New("1"), + }, + { + "Double/dst/error", + &llValueReaderWriter{bsontype: bson.TypeDouble, err: errors.New("2"), errAfter: llvrwWriteDouble}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14159)}, + errors.New("2"), + }, + { + "String/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeString, err: errors.New("1"), errAfter: llvrwReadString}, + errors.New("1"), + }, + { + "String/dst/error", + &llValueReaderWriter{bsontype: bson.TypeString, err: errors.New("2"), errAfter: llvrwWriteString}, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("hello, world")}, + errors.New("2"), + }, + { + "Document/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeEmbeddedDocument, err: errors.New("1"), errAfter: llvrwReadDocument}, + errors.New("1"), + }, + { + "Array/dst/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeArray, err: errors.New("2"), errAfter: llvrwReadArray}, + errors.New("2"), + }, + { + "Binary/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeBinary, err: errors.New("1"), errAfter: llvrwReadBinary}, + errors.New("1"), + }, + { + "Binary/dst/error", + &llValueReaderWriter{bsontype: bson.TypeBinary, err: errors.New("2"), errAfter: llvrwWriteBinaryWithSubtype}, + &llValueReaderWriter{bsontype: bson.TypeBinary, readval: bson.Binary{Subtype: 0xFF, Data: []byte{0x01, 0x02, 0x03}}}, + errors.New("2"), + }, + { + "Undefined/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeUndefined, err: errors.New("1"), errAfter: llvrwReadUndefined}, + errors.New("1"), + }, + { + "Undefined/dst/error", + &llValueReaderWriter{bsontype: bson.TypeUndefined, err: errors.New("2"), errAfter: llvrwWriteUndefined}, + &llValueReaderWriter{bsontype: bson.TypeUndefined}, + errors.New("2"), + }, + { + "ObjectID/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeObjectID, err: errors.New("1"), errAfter: llvrwReadObjectID}, + errors.New("1"), + }, + { + "ObjectID/dst/error", + &llValueReaderWriter{bsontype: bson.TypeObjectID, err: errors.New("2"), errAfter: llvrwWriteObjectID}, + &llValueReaderWriter{bsontype: bson.TypeObjectID, readval: objectid.ObjectID{0x01, 0x02, 0x03}}, + errors.New("2"), + }, + { + "Boolean/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeBoolean, err: errors.New("1"), errAfter: llvrwReadBoolean}, + errors.New("1"), + }, + { + "Boolean/dst/error", + &llValueReaderWriter{bsontype: bson.TypeBoolean, err: errors.New("2"), errAfter: llvrwWriteBoolean}, + &llValueReaderWriter{bsontype: bson.TypeBoolean, readval: bool(true)}, + errors.New("2"), + }, + { + "DateTime/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeDateTime, err: errors.New("1"), errAfter: llvrwReadDateTime}, + errors.New("1"), + }, + { + "DateTime/dst/error", + &llValueReaderWriter{bsontype: bson.TypeDateTime, err: errors.New("2"), errAfter: llvrwWriteDateTime}, + &llValueReaderWriter{bsontype: bson.TypeDateTime, readval: int64(1234567890)}, + errors.New("2"), + }, + { + "Null/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeNull, err: errors.New("1"), errAfter: llvrwReadNull}, + errors.New("1"), + }, + { + "Null/dst/error", + &llValueReaderWriter{bsontype: bson.TypeNull, err: errors.New("2"), errAfter: llvrwWriteNull}, + &llValueReaderWriter{bsontype: bson.TypeNull}, + errors.New("2"), + }, + { + "Regex/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeRegex, err: errors.New("1"), errAfter: llvrwReadRegex}, + errors.New("1"), + }, + { + "Regex/dst/error", + &llValueReaderWriter{bsontype: bson.TypeRegex, err: errors.New("2"), errAfter: llvrwWriteRegex}, + &llValueReaderWriter{bsontype: bson.TypeRegex, readval: bson.Regex{Pattern: "hello", Options: "world"}}, + errors.New("2"), + }, + { + "DBPointer/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeDBPointer, err: errors.New("1"), errAfter: llvrwReadDBPointer}, + errors.New("1"), + }, + { + "DBPointer/dst/error", + &llValueReaderWriter{bsontype: bson.TypeDBPointer, err: errors.New("2"), errAfter: llvrwWriteDBPointer}, + &llValueReaderWriter{bsontype: bson.TypeDBPointer, readval: bson.DBPointer{DB: "foo", Pointer: objectid.ObjectID{0x01, 0x02, 0x03}}}, + errors.New("2"), + }, + { + "Javascript/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeJavaScript, err: errors.New("1"), errAfter: llvrwReadJavascript}, + errors.New("1"), + }, + { + "Javascript/dst/error", + &llValueReaderWriter{bsontype: bson.TypeJavaScript, err: errors.New("2"), errAfter: llvrwWriteJavascript}, + &llValueReaderWriter{bsontype: bson.TypeJavaScript, readval: string("hello, world")}, + errors.New("2"), + }, + { + "Symbol/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeSymbol, err: errors.New("1"), errAfter: llvrwReadSymbol}, + errors.New("1"), + }, + { + "Symbol/dst/error", + &llValueReaderWriter{bsontype: bson.TypeSymbol, err: errors.New("2"), errAfter: llvrwWriteSymbol}, + &llValueReaderWriter{bsontype: bson.TypeSymbol, readval: bson.Symbol("hello, world")}, + errors.New("2"), + }, + { + "CodeWithScope/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeCodeWithScope, err: errors.New("1"), errAfter: llvrwReadCodeWithScope}, + errors.New("1"), + }, + { + "CodeWithScope/dst/error", + &llValueReaderWriter{bsontype: bson.TypeCodeWithScope, err: errors.New("2"), errAfter: llvrwWriteCodeWithScope}, + &llValueReaderWriter{bsontype: bson.TypeCodeWithScope}, + errors.New("2"), + }, + { + "CodeWithScope/dst/copyDocumentCore error", + &llValueReaderWriter{err: errors.New("3"), errAfter: llvrwWriteDocumentElement}, + &llValueReaderWriter{bsontype: bson.TypeCodeWithScope}, + errors.New("3"), + }, + { + "Int32/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeInt32, err: errors.New("1"), errAfter: llvrwReadInt32}, + errors.New("1"), + }, + { + "Int32/dst/error", + &llValueReaderWriter{bsontype: bson.TypeInt32, err: errors.New("2"), errAfter: llvrwWriteInt32}, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(12345)}, + errors.New("2"), + }, + { + "Timestamp/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeTimestamp, err: errors.New("1"), errAfter: llvrwReadTimestamp}, + errors.New("1"), + }, + { + "Timestamp/dst/error", + &llValueReaderWriter{bsontype: bson.TypeTimestamp, err: errors.New("2"), errAfter: llvrwWriteTimestamp}, + &llValueReaderWriter{bsontype: bson.TypeTimestamp, readval: bson.Timestamp{T: 12345, I: 67890}}, + errors.New("2"), + }, + { + "Int64/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeInt64, err: errors.New("1"), errAfter: llvrwReadInt64}, + errors.New("1"), + }, + { + "Int64/dst/error", + &llValueReaderWriter{bsontype: bson.TypeInt64, err: errors.New("2"), errAfter: llvrwWriteInt64}, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1234567890)}, + errors.New("2"), + }, + { + "Decimal128/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeDecimal128, err: errors.New("1"), errAfter: llvrwReadDecimal128}, + errors.New("1"), + }, + { + "Decimal128/dst/error", + &llValueReaderWriter{bsontype: bson.TypeDecimal128, err: errors.New("2"), errAfter: llvrwWriteDecimal128}, + &llValueReaderWriter{bsontype: bson.TypeDecimal128, readval: decimal.NewDecimal128(12345, 67890)}, + errors.New("2"), + }, + { + "MinKey/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeMinKey, err: errors.New("1"), errAfter: llvrwReadMinKey}, + errors.New("1"), + }, + { + "MinKey/dst/error", + &llValueReaderWriter{bsontype: bson.TypeMinKey, err: errors.New("2"), errAfter: llvrwWriteMinKey}, + &llValueReaderWriter{bsontype: bson.TypeMinKey}, + errors.New("2"), + }, + { + "MaxKey/src/error", + &llValueReaderWriter{}, + &llValueReaderWriter{bsontype: bson.TypeMaxKey, err: errors.New("1"), errAfter: llvrwReadMaxKey}, + errors.New("1"), + }, + { + "MaxKey/dst/error", + &llValueReaderWriter{bsontype: bson.TypeMaxKey, err: errors.New("2"), errAfter: llvrwWriteMaxKey}, + &llValueReaderWriter{bsontype: bson.TypeMaxKey}, + errors.New("2"), + }, + { + "Unknown BSON type error", + &llValueReaderWriter{}, + &llValueReaderWriter{}, + fmt.Errorf("Cannot copy unknown BSON type %s", bson.Type(0)), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.dst.t, tc.src.t = t, t + err := Copier{}.CopyValue(tc.dst, tc.src) + if !compareErrors(err, tc.err) { + t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) + } + }) + } + }) + t.Run("CopyValueFromBytes", func(t *testing.T) { + t.Run("BytesWriter", func(t *testing.T) { + vw := newValueWriterFromSlice(make([]byte, 0)) + _, err := vw.WriteDocument() + noerr(t, err) + _, err = vw.WriteDocumentElement("foo") + noerr(t, err) + err = Copier{}.CopyValueFromBytes(vw, bson.TypeString, llbson.AppendString(nil, "bar")) + noerr(t, err) + err = vw.WriteDocumentEnd() + noerr(t, err) + want, err := bson.NewDocument(bson.EC.String("foo", "bar")).MarshalBSON() + noerr(t, err) + got := vw.buf + if !bytes.Equal(got, want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(want)) + } + }) + t.Run("Non BytesWriter", func(t *testing.T) { + llvrw := &llValueReaderWriter{t: t} + err := Copier{}.CopyValueFromBytes(llvrw, bson.TypeString, llbson.AppendString(nil, "bar")) + noerr(t, err) + got, want := llvrw.invoked, llvrwWriteString + if got != want { + t.Errorf("Incorrect method invoked on llvrw. got %v; want %v", got, want) + } + }) + }) + t.Run("CopyValueToBytes", func(t *testing.T) { + t.Run("BytesReader", func(t *testing.T) { + b, err := bson.NewDocument(bson.EC.String("hello", "world")).MarshalBSON() + noerr(t, err) + vr := newValueReader(b) + _, err = vr.ReadDocument() + noerr(t, err) + _, _, err = vr.ReadElement() + noerr(t, err) + bsontype, got, err := Copier{}.CopyValueToBytes(vr) + noerr(t, err) + want := llbson.AppendString(nil, "world") + if bsontype != bson.TypeString { + t.Errorf("Incorrect type returned. got %v; want %v", bsontype, bson.TypeString) + } + if !bytes.Equal(got, want) { + t.Errorf("Bytes do not match. got %v; want %v", got, want) + } + }) + t.Run("Non BytesReader", func(t *testing.T) { + llvrw := &llValueReaderWriter{t: t, bsontype: bson.TypeString, readval: string("Hello, world!")} + bsontype, got, err := Copier{}.CopyValueToBytes(llvrw) + noerr(t, err) + want := llbson.AppendString(nil, "Hello, world!") + if bsontype != bson.TypeString { + t.Errorf("Incorrect type returned. got %v; want %v", bsontype, bson.TypeString) + } + if !bytes.Equal(got, want) { + t.Errorf("Bytes do not match. got %v; want %v", got, want) + } + }) + }) + t.Run("AppendValueBytes", func(t *testing.T) { + t.Run("BytesReader", func(t *testing.T) { + b, err := bson.NewDocument(bson.EC.String("hello", "world")).MarshalBSON() + noerr(t, err) + vr := newValueReader(b) + _, err = vr.ReadDocument() + noerr(t, err) + _, _, err = vr.ReadElement() + noerr(t, err) + bsontype, got, err := Copier{}.AppendValueBytes(nil, vr) + noerr(t, err) + want := llbson.AppendString(nil, "world") + if bsontype != bson.TypeString { + t.Errorf("Incorrect type returned. got %v; want %v", bsontype, bson.TypeString) + } + if !bytes.Equal(got, want) { + t.Errorf("Bytes do not match. got %v; want %v", got, want) + } + }) + t.Run("Non BytesReader", func(t *testing.T) { + llvrw := &llValueReaderWriter{t: t, bsontype: bson.TypeString, readval: string("Hello, world!")} + bsontype, got, err := Copier{}.AppendValueBytes(nil, llvrw) + noerr(t, err) + want := llbson.AppendString(nil, "Hello, world!") + if bsontype != bson.TypeString { + t.Errorf("Incorrect type returned. got %v; want %v", bsontype, bson.TypeString) + } + if !bytes.Equal(got, want) { + t.Errorf("Bytes do not match. got %v; want %v", got, want) + } + }) + t.Run("CopyValue error", func(t *testing.T) { + want := errors.New("CopyValue error") + llvrw := &llValueReaderWriter{t: t, bsontype: bson.TypeString, err: want, errAfter: llvrwReadString} + _, _, got := Copier{}.AppendValueBytes(make([]byte, 0), llvrw) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + }) +} diff --git a/bson/bsoncodec/decoder.go b/bson/bsoncodec/decoder.go new file mode 100644 index 0000000000..0879b97587 --- /dev/null +++ b/bson/bsoncodec/decoder.go @@ -0,0 +1,77 @@ +package bsoncodec + +import ( + "errors" + "fmt" + "reflect" + "sync" +) + +// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal* +// methods and is not consumable from outside of this package. The Encoders retrieved from this pool +// must have both Reset and SetRegistry called on them. +var decPool = sync.Pool{ + New: func() interface{} { + return new(Decoder) + }, +} + +// A Decoder reads and decodes BSON documents from a stream. +type Decoder struct { + r *Registry + vr ValueReader +} + +// NewDecoder returns a new decoder that uses Registry reg to read from r. +func NewDecoder(r *Registry, vr ValueReader) (*Decoder, error) { + if r == nil { + return nil, errors.New("cannot create a new Decoder with a nil Registry") + } + if vr == nil { + return nil, errors.New("cannot create a new Decoder with a nil ValueReader") + } + + return &Decoder{ + r: r, + vr: vr, + }, nil +} + +// Decode reads the next BSON document from the stream and decodes it into the +// value pointed to by val. +// +// The documentation for Unmarshal contains details about of BSON into a Go +// value. +func (d *Decoder) Decode(val interface{}) error { + if unmarshaler, ok := val.(Unmarshaler); ok { + // TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method. + buf, err := Copier{r: d.r}.CopyDocumentToBytes(d.vr) + if err != nil { + return err + } + return unmarshaler.UnmarshalBSON(buf) + } + + rval := reflect.TypeOf(val) + if rval.Kind() != reflect.Ptr { + return fmt.Errorf("argument to Decode must be a pointer to a type, but got %v", rval) + } + decoder, err := d.r.LookupDecoder(rval.Elem()) + if err != nil { + return err + } + return decoder.DecodeValue(DecodeContext{Registry: d.r}, d.vr, val) +} + +// Reset will reset the state of the decoder, using the same *Registry used in +// the original construction but using r for reading. +func (d *Decoder) Reset(vr ValueReader) error { + d.vr = vr + return nil +} + +// SetRegistry replaces the current registry of the decoder with r. +func (d *Decoder) SetRegistry(r *Registry) error { + d.r = r + return nil +} diff --git a/bson/decoder_test.go b/bson/bsoncodec/decoder_test.go similarity index 59% rename from bson/decoder_test.go rename to bson/bsoncodec/decoder_test.go index 9116b1a35b..788f573f44 100644 --- a/bson/decoder_test.go +++ b/bson/bsoncodec/decoder_test.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "errors" @@ -6,6 +6,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/mongodb/mongo-go-driver/bson" ) func TestDecoderv2(t *testing.T) { @@ -20,7 +21,7 @@ func TestDecoderv2(t *testing.T) { } else { reg = NewRegistryBuilder().Build() } - dec, err := NewDecoderv2(reg, vr) + dec, err := NewDecoder(reg, vr) noerr(t, err) err = dec.Decode(got) noerr(t, err) @@ -33,30 +34,77 @@ func TestDecoderv2(t *testing.T) { t.Run("lookup error", func(t *testing.T) { type certainlydoesntexistelsewhereihope func(string, string) string cdeih := func(string, string) string { return "certainlydoesntexistelsewhereihope" } - dec, err := NewDecoderv2(defaultRegistry, new(valueReader)) + dec, err := NewDecoder(defaultRegistry, new(valueReader)) noerr(t, err) - want := ErrNoCodec{Type: reflect.TypeOf(cdeih)} - got := dec.Decode(cdeih) + want := ErrNoDecoder{Type: reflect.TypeOf(cdeih)} + got := dec.Decode(&cdeih) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("Received unexpected error. got %v; want %v", got, want) } }) + t.Run("Unmarshaler", func(t *testing.T) { + testCases := []struct { + name string + err error + vr ValueReader + invoked bool + }{ + { + "error", + errors.New("Unmarshaler error"), + &llValueReaderWriter{bsontype: bson.TypeEmbeddedDocument, err: ErrEOD, errAfter: llvrwReadElement}, + true, + }, + { + "copy error", + errors.New("copy error"), + &llValueReaderWriter{err: errors.New("copy error"), errAfter: llvrwReadDocument}, + false, + }, + { + "success", + nil, + &llValueReaderWriter{bsontype: bson.TypeEmbeddedDocument, err: ErrEOD, errAfter: llvrwReadElement}, + true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + unmarshaler := &testUnmarshaler{err: tc.err} + dec, err := NewDecoder(defaultRegistry, tc.vr) + noerr(t, err) + got := dec.Decode(unmarshaler) + want := tc.err + if !compareErrors(got, want) { + t.Errorf("Did not receive expected error. got %v; want %v", got, want) + } + if unmarshaler.invoked != tc.invoked { + if tc.invoked { + t.Error("Expected to have UnmarshalBSON invoked, but it wasn't.") + } else { + t.Error("Expected UnmarshalBSON to not be invoked, but it was.") + } + } + }) + } + }) }) t.Run("NewDecoderv2", func(t *testing.T) { t.Run("errors", func(t *testing.T) { - _, got := NewDecoderv2(nil, &valueReader{}) + _, got := NewDecoder(nil, &valueReader{}) want := errors.New("cannot create a new Decoder with a nil Registry") if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("Was expecting error but got different error. got %v; want %v", got, want) } - _, got = NewDecoderv2(defaultRegistry, nil) + _, got = NewDecoder(defaultRegistry, nil) want = errors.New("cannot create a new Decoder with a nil ValueReader") if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("Was expecting error but got different error. got %v; want %v", got, want) } }) t.Run("success", func(t *testing.T) { - got, err := NewDecoderv2(defaultRegistry, &valueReader{}) + got, err := NewDecoder(defaultRegistry, &valueReader{}) noerr(t, err) if got == nil { t.Errorf("Was expecting a non-nil Decoder, but got ") @@ -65,7 +113,7 @@ func TestDecoderv2(t *testing.T) { }) t.Run("Reset", func(t *testing.T) { vr1, vr2 := new(valueReader), new(documentValueReader) - dec, err := NewDecoderv2(defaultRegistry, vr1) + dec, err := NewDecoder(defaultRegistry, vr1) noerr(t, err) if dec.vr != vr1 { t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1) @@ -78,7 +126,7 @@ func TestDecoderv2(t *testing.T) { }) t.Run("SetRegistry", func(t *testing.T) { reg1, reg2 := defaultRegistry, NewRegistryBuilder().Build() - dec, err := NewDecoderv2(reg1, new(valueReader)) + dec, err := NewDecoder(reg1, new(valueReader)) noerr(t, err) if dec.r != reg1 { t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.r, reg1) @@ -105,3 +153,13 @@ func (tdc *testDecoderCodec) DecodeValue(DecodeContext, ValueReader, interface{} tdc.DecodeValueCalled = true return nil } + +type testUnmarshaler struct { + invoked bool + err error +} + +func (tu *testUnmarshaler) UnmarshalBSON(_ []byte) error { + tu.invoked = true + return tu.err +} diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go new file mode 100644 index 0000000000..beb45df62a --- /dev/null +++ b/bson/bsoncodec/default_value_decoders.go @@ -0,0 +1,1575 @@ +package bsoncodec + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "time" + + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/objectid" +) + +var defaultValueDecoders DefaultValueDecoders + +// DefaultValueDecoders is a namespace type for the default ValueDecoders used +// when creating a registry. +type DefaultValueDecoders struct{} + +// BooleanDecodeValue is the ValueDecoderFunc for bool types. +func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeBoolean { + return fmt.Errorf("cannot decode %v into a boolean", vr.Type()) + } + + var err error + if target, ok := i.(*bool); ok && target != nil { // if it is nil, we go the slow path. + *target, err = vr.ReadBoolean() + return err + } + + val := reflect.ValueOf(i) + if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { + return errors.New("BooleanDecodeValue can only be used to decode settable (non-nil) values") + } + val = val.Elem() + if val.Type().Kind() != reflect.Bool { + return ValueDecoderError{Name: "BooleanDecodeValue", Types: []interface{}{bool(true)}, Received: i} + } + + b, err := vr.ReadBoolean() + val.SetBool(b) + return err +} + +// IntDecodeValue is the ValueDecoderFunc for bool types. +func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + var i64 int64 + var err error + switch vr.Type() { + case bson.TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + i64 = int64(i32) + case bson.TypeInt64: + i64, err = vr.ReadInt64() + if err != nil { + return err + } + case bson.TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + if !dc.Truncate && math.Floor(f64) != f64 { + return errors.New("IntDecodeValue can only truncate float64 to an integer type when truncation is enabled") + } + if f64 > float64(math.MaxInt64) { + return fmt.Errorf("%g overflows int64", f64) + } + i64 = int64(f64) + default: + return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) + } + + switch target := i.(type) { + case *int8: + if target == nil { + return errors.New("IntDecodeValue can only be used to decode non-nil *int8") + } + if i64 < math.MinInt8 || i64 > math.MaxInt8 { + return fmt.Errorf("%d overflows int8", i64) + } + *target = int8(i64) + return nil + case *int16: + if target == nil { + return errors.New("IntDecodeValue can only be used to decode non-nil *int16") + } + if i64 < math.MinInt16 || i64 > math.MaxInt16 { + return fmt.Errorf("%d overflows int16", i64) + } + *target = int16(i64) + return nil + case *int32: + if target == nil { + return errors.New("IntDecodeValue can only be used to decode non-nil *int32") + } + if i64 < math.MinInt32 || i64 > math.MaxInt32 { + return fmt.Errorf("%d overflows int32", i64) + } + *target = int32(i64) + return nil + case *int64: + if target == nil { + return errors.New("IntDecodeValue can only be used to decode non-nil *int64") + } + *target = int64(i64) + return nil + case *int: + if target == nil { + return errors.New("IntDecodeValue can only be used to decode non-nil *int") + } + if int64(int(i64)) != i64 { // Can we fit this inside of an int + return fmt.Errorf("%d overflows int", i64) + } + *target = int(i64) + return nil + } + + val := reflect.ValueOf(i) + if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { + return fmt.Errorf("IntDecodeValue can only be used to decode settable (non-nil) values") + } + val = val.Elem() + + switch val.Type().Kind() { + case reflect.Int8: + if i64 < math.MinInt8 || i64 > math.MaxInt8 { + return fmt.Errorf("%d overflows int8", i64) + } + case reflect.Int16: + if i64 < math.MinInt16 || i64 > math.MaxInt16 { + return fmt.Errorf("%d overflows int16", i64) + } + case reflect.Int32: + if i64 < math.MinInt32 || i64 > math.MaxInt32 { + return fmt.Errorf("%d overflows int32", i64) + } + case reflect.Int64: + case reflect.Int: + if int64(int(i64)) != i64 { // Can we fit this inside of an int + return fmt.Errorf("%d overflows int", i64) + } + default: + return ValueDecoderError{ + Name: "IntDecodeValue", + Types: []interface{}{(*int8)(nil), (*int16)(nil), (*int32)(nil), (*int64)(nil), (*int)(nil)}, + Received: i, + } + } + + val.SetInt(i64) + return nil +} + +// UintDecodeValue is the ValueDecoderFunc for uint types. +func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + var i64 int64 + var err error + switch vr.Type() { + case bson.TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + i64 = int64(i32) + case bson.TypeInt64: + i64, err = vr.ReadInt64() + if err != nil { + return err + } + case bson.TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + if !dc.Truncate && math.Floor(f64) != f64 { + return errors.New("UintDecodeValue can only truncate float64 to an integer type when truncation is enabled") + } + if f64 > float64(math.MaxInt64) { + return fmt.Errorf("%g overflows int64", f64) + } + i64 = int64(f64) + default: + return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) + } + + switch target := i.(type) { + case *uint8: + if target == nil { + return errors.New("UintDecodeValue can only be used to decode non-nil *uint8") + } + if i64 < 0 || i64 > math.MaxUint8 { + return fmt.Errorf("%d overflows uint8", i64) + } + *target = uint8(i64) + return nil + case *uint16: + if target == nil { + return errors.New("UintDecodeValue can only be used to decode non-nil *uint16") + } + if i64 < 0 || i64 > math.MaxUint16 { + return fmt.Errorf("%d overflows uint16", i64) + } + *target = uint16(i64) + return nil + case *uint32: + if target == nil { + return errors.New("UintDecodeValue can only be used to decode non-nil *uint32") + } + if i64 < 0 || i64 > math.MaxUint32 { + return fmt.Errorf("%d overflows uint32", i64) + } + *target = uint32(i64) + return nil + case *uint64: + if target == nil { + return errors.New("UintDecodeValue can only be used to decode non-nil *uint64") + } + if i64 < 0 { + return fmt.Errorf("%d overflows uint64", i64) + } + *target = uint64(i64) + return nil + case *uint: + if target == nil { + return errors.New("UintDecodeValue can only be used to decode non-nil *uint") + } + if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + return fmt.Errorf("%d overflows uint", i64) + } + *target = uint(i64) + return nil + } + + val := reflect.ValueOf(i) + if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { + return errors.New("UintDecodeValue can only be used to decode settable (non-nil) values") + } + val = val.Elem() + + switch val.Type().Kind() { + case reflect.Uint8: + if i64 < 0 || i64 > math.MaxUint8 { + return fmt.Errorf("%d overflows uint8", i64) + } + case reflect.Uint16: + if i64 < 0 || i64 > math.MaxUint16 { + return fmt.Errorf("%d overflows uint16", i64) + } + case reflect.Uint32: + if i64 < 0 || i64 > math.MaxUint32 { + return fmt.Errorf("%d overflows uint32", i64) + } + case reflect.Uint64: + if i64 < 0 { + return fmt.Errorf("%d overflows uint64", i64) + } + case reflect.Uint: + if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + return fmt.Errorf("%d overflows uint", i64) + } + default: + return ValueDecoderError{ + Name: "UintDecodeValue", + Types: []interface{}{(*uint8)(nil), (*uint16)(nil), (*uint32)(nil), (*uint64)(nil), (*uint)(nil)}, + Received: i, + } + } + + val.SetUint(uint64(i64)) + return nil +} + +// FloatDecodeValue is the ValueDecoderFunc for float types. +func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReader, i interface{}) error { + var f float64 + var err error + switch vr.Type() { + case bson.TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + f = float64(i32) + case bson.TypeInt64: + i64, err := vr.ReadInt64() + if err != nil { + return err + } + f = float64(i64) + case bson.TypeDouble: + f, err = vr.ReadDouble() + if err != nil { + return err + } + default: + return fmt.Errorf("cannot decode %v into a float32 or float64 type", vr.Type()) + } + + switch target := i.(type) { + case *float32: + if target == nil { + return errors.New("FloatDecodeValue can only be used to decode non-nil *float32") + } + if !ec.Truncate && float64(float32(f)) != f { + return errors.New("FloatDecodeValue can only convert float64 to float32 when truncation is allowed") + } + *target = float32(f) + return nil + case *float64: + if target == nil { + return errors.New("FloatDecodeValue can only be used to decode non-nil *float64") + } + *target = f + return nil + } + + val := reflect.ValueOf(i) + if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { + return errors.New("FloatDecodeValue can only be used to decode settable (non-nil) values") + } + val = val.Elem() + + switch val.Type().Kind() { + case reflect.Float32: + if !ec.Truncate && float64(float32(f)) != f { + return errors.New("FloatDecodeValue can only convert float64 to float32 when truncation is allowed") + } + case reflect.Float64: + default: + return ValueDecoderError{Name: "FloatDecodeValue", Types: []interface{}{(*float32)(nil), (*float64)(nil)}, Received: i} + } + + val.SetFloat(f) + return nil +} + +// StringDecodeValue is the ValueDecoderFunc for string types. +func (dvd DefaultValueDecoders) StringDecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { + var str string + var err error + switch vr.Type() { + case bson.TypeString: + str, err = vr.ReadString() + if err != nil { + return err + } + case bson.TypeJavaScript: + str, err = vr.ReadJavascript() + if err != nil { + return err + } + case bson.TypeSymbol: + str, err = vr.ReadSymbol() + if err != nil { + return err + } + default: + return fmt.Errorf("cannot decode %v into a string type", vr.Type()) + } + + switch t := i.(type) { + case *string: + if t == nil { + return errors.New("StringDecodeValue can only be used to decode non-nil *string") + } + *t = str + return nil + case *bson.JavaScriptCode: + if t == nil { + return errors.New("StringDecodeValue can only be used to decode non-nil *JavaScriptCode") + } + *t = bson.JavaScriptCode(str) + return nil + case *bson.Symbol: + if t == nil { + return errors.New("StringDecodeValue can only be used to decode non-nil *Symbol") + } + *t = bson.Symbol(str) + return nil + } + + val := reflect.ValueOf(i) + if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { + return errors.New("StringDecodeValue can only be used to decode settable (non-nil) values") + } + val = val.Elem() + + if val.Type().Kind() != reflect.String { + return ValueDecoderError{ + Name: "StringDecodeValue", + Types: []interface{}{(*string)(nil), (*bson.JavaScriptCode)(nil), (*bson.Symbol)(nil)}, + Received: i, + } + } + + val.SetString(str) + return nil +} + +// DocumentDecodeValue is the ValueDecoderFunc for *bson.Document. +func (dvd DefaultValueDecoders) DocumentDecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { + doc, ok := i.(**bson.Document) + if !ok { + return ValueDecoderError{Name: "DocumentDecodeValue", Types: []interface{}{(**bson.Document)(nil)}, Received: i} + } + + if doc == nil { + return errors.New("DocumentDecodeValue can only be used to decode non-nil **Document") + } + + dr, err := vr.ReadDocument() + if err != nil { + return err + } + + return dvd.decodeDocument(dctx, dr, doc) +} + +// ArrayDecodeValue is the ValueDecoderFunc for *bson.Array. +func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + parr, ok := i.(**bson.Array) + if !ok { + return ValueDecoderError{Name: "ArrayDecodeValue", Types: []interface{}{(**bson.Array)(nil)}, Received: i} + } + + if parr == nil { + return errors.New("ArrayDecodeValue can only be used to decode non-nil **Array") + } + + ar, err := vr.ReadArray() + if err != nil { + return err + } + + arr := bson.NewArray() + for { + vr, err := ar.ReadValue() + if err == ErrEOA { + break + } + if err != nil { + return err + } + + var val *bson.Value + err = dvd.valueDecodeValue(dc, vr, &val) + if err != nil { + return err + } + + arr.Append(val) + } + + *parr = arr + return nil +} + +// BinaryDecodeValue is the ValueDecoderFunc for bson.Binary. +func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeBinary { + return fmt.Errorf("cannot decode %v into a Binary", vr.Type()) + } + + data, subtype, err := vr.ReadBinary() + if err != nil { + return err + } + + if target, ok := i.(*bson.Binary); ok && target != nil { + *target = bson.Binary{Data: data, Subtype: subtype} + return nil + } + + if target, ok := i.(**bson.Binary); ok && target != nil { + pb := *target + if pb == nil { + pb = new(bson.Binary) + } + *pb = bson.Binary{Data: data, Subtype: subtype} + *target = pb + return nil + } + + return ValueDecoderError{Name: "BinaryDecodeValue", Types: []interface{}{(*bson.Binary)(nil)}, Received: i} +} + +// UndefinedDecodeValue is the ValueDecoderFunc for bool types. +func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeUndefined { + return fmt.Errorf("cannot decode %v into an Undefined", vr.Type()) + } + + target, ok := i.(*bson.Undefinedv2) + if !ok || target == nil { + return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []interface{}{(*bson.Undefinedv2)(nil)}, Received: i} + } + + *target = bson.Undefinedv2{} + return vr.ReadUndefined() +} + +// ObjectIDDecodeValue is the ValueDecoderFunc for objectid.ObjectID. +func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeObjectID { + return fmt.Errorf("cannot decode %v into an ObjectID", vr.Type()) + } + + target, ok := i.(*objectid.ObjectID) + if !ok || target == nil { + return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []interface{}{(*objectid.ObjectID)(nil)}, Received: i} + } + + oid, err := vr.ReadObjectID() + if err != nil { + return err + } + + *target = oid + return nil +} + +// DateTimeDecodeValue is the ValueDecoderFunc for bson.DateTime. +func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeDateTime { + return fmt.Errorf("cannot decode %v into a DateTime", vr.Type()) + } + + target, ok := i.(*bson.DateTime) + if !ok || target == nil { + return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []interface{}{(*bson.DateTime)(nil)}, Received: i} + } + + dt, err := vr.ReadDateTime() + if err != nil { + return err + } + + *target = bson.DateTime(dt) + return nil +} + +// NullDecodeValue is the ValueDecoderFunc for bson.Null. +func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeNull { + return fmt.Errorf("cannot decode %v into a Null", vr.Type()) + } + + target, ok := i.(*bson.Nullv2) + if !ok || target == nil { + return ValueDecoderError{Name: "NullDecodeValue", Types: []interface{}{(*bson.Nullv2)(nil)}, Received: i} + } + + *target = bson.Nullv2{} + return vr.ReadNull() +} + +// RegexDecodeValue is the ValueDecoderFunc for bson.Regex. +func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeRegex { + return fmt.Errorf("cannot decode %v into a Regex", vr.Type()) + } + + target, ok := i.(*bson.Regex) + if !ok || target == nil { + return ValueDecoderError{Name: "RegexDecodeValue", Types: []interface{}{(*bson.Regex)(nil)}, Received: i} + } + + pattern, options, err := vr.ReadRegex() + if err != nil { + return err + } + + *target = bson.Regex{Pattern: pattern, Options: options} + return nil +} + +// DBPointerDecodeValue is the ValueDecoderFunc for bson.DBPointer. +func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeDBPointer { + return fmt.Errorf("cannot decode %v into a DBPointer", vr.Type()) + } + + target, ok := i.(*bson.DBPointer) + if !ok || target == nil { + return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []interface{}{(*bson.DBPointer)(nil)}, Received: i} + } + + ns, pointer, err := vr.ReadDBPointer() + if err != nil { + return err + } + + *target = bson.DBPointer{DB: ns, Pointer: pointer} + return nil +} + +// CodeWithScopeDecodeValue is the ValueDecoderFunc for bson.CodeWithScope. +func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeCodeWithScope { + return fmt.Errorf("cannot decode %v into a CodeWithScope", vr.Type()) + } + + target, ok := i.(*bson.CodeWithScope) + if !ok || target == nil { + return ValueDecoderError{ + Name: "CodeWithScopeDecodeValue", + Types: []interface{}{(*bson.CodeWithScope)(nil)}, + Received: i, + } + } + + code, dr, err := vr.ReadCodeWithScope() + if err != nil { + return err + } + + var scope *bson.Document + err = dvd.decodeDocument(dc, dr, &scope) + if err != nil { + return err + } + + *target = bson.CodeWithScope{Code: code, Scope: scope} + return nil +} + +// TimestampDecodeValue is the ValueDecoderFunc for bson.Timestamp. +func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeTimestamp { + return fmt.Errorf("cannot decode %v into a Timestamp", vr.Type()) + } + + target, ok := i.(*bson.Timestamp) + if !ok || target == nil { + return ValueDecoderError{Name: "TimestampDecodeValue", Types: []interface{}{(*bson.Timestamp)(nil)}, Received: i} + } + + t, incr, err := vr.ReadTimestamp() + if err != nil { + return err + } + + *target = bson.Timestamp{T: t, I: incr} + return nil +} + +// Decimal128DecodeValue is the ValueDecoderFunc for decimal.Decimal128. +func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeDecimal128 { + return fmt.Errorf("cannot decode %v into a decimal.Decimal128", vr.Type()) + } + + target, ok := i.(*decimal.Decimal128) + if !ok || target == nil { + return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []interface{}{(*decimal.Decimal128)(nil)}, Received: i} + } + + d128, err := vr.ReadDecimal128() + if err != nil { + return err + } + + *target = d128 + return nil +} + +// MinKeyDecodeValue is the ValueDecoderFunc for bson.MinKey. +func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeMinKey { + return fmt.Errorf("cannot decode %v into a MinKey", vr.Type()) + } + + target, ok := i.(*bson.MinKeyv2) + if !ok || target == nil { + return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []interface{}{(*bson.MinKeyv2)(nil)}, Received: i} + } + + *target = bson.MinKeyv2{} + return vr.ReadMinKey() +} + +// MaxKeyDecodeValue is the ValueDecoderFunc for bson.MaxKey. +func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeMaxKey { + return fmt.Errorf("cannot decode %v into a MaxKey", vr.Type()) + } + + target, ok := i.(*bson.MaxKeyv2) + if !ok || target == nil { + return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []interface{}{(*bson.MaxKeyv2)(nil)}, Received: i} + } + + *target = bson.MaxKeyv2{} + return vr.ReadMaxKey() +} + +// ValueDecodeValue is the ValueDecoderFunc for *bson.Value. +func (dvd DefaultValueDecoders) ValueDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + pval, ok := i.(**bson.Value) + if !ok { + return ValueDecoderError{Name: "ValueDecodeValue", Types: []interface{}{(**bson.Value)(nil)}, Received: i} + } + + if pval == nil { + return errors.New("ValueDecodeValue can only be used to decode non-nil **Value") + } + + return dvd.valueDecodeValue(dc, vr, pval) +} + +// JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. +func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + target, ok := i.(*json.Number) + if !ok || target == nil { + return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []interface{}{(*json.Number)(nil)}, Received: i} + } + + switch vr.Type() { + case bson.TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + *target = json.Number(strconv.FormatFloat(f64, 'g', -1, 64)) + case bson.TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + *target = json.Number(strconv.FormatInt(int64(i32), 10)) + case bson.TypeInt64: + i64, err := vr.ReadInt64() + if err != nil { + return err + } + *target = json.Number(strconv.FormatInt(i64, 10)) + default: + return fmt.Errorf("cannot decode %v into a json.Number", vr.Type()) + } + + return nil +} + +// URLDecodeValue is the ValueDecoderFunc for url.URL. +func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeString { + return fmt.Errorf("cannot decode %v into a *url.URL", vr.Type()) + } + + str, err := vr.ReadString() + if err != nil { + return err + } + + u, err := url.Parse(str) + if err != nil { + return err + } + + err = ValueDecoderError{Name: "URLDecodeValue", Types: []interface{}{(*url.URL)(nil), (**url.URL)(nil)}, Received: i} + + // It's valid to use either a *url.URL or a url.URL + switch target := i.(type) { + case *url.URL: + if target == nil { + return err + } + *target = *u + case **url.URL: + if target == nil { + return err + } + *target = u + default: + return err + } + return nil +} + +// TimeDecodeValue is the ValueDecoderFunc for time.Time. +func (dvd DefaultValueDecoders) TimeDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeDateTime { + return fmt.Errorf("cannot decode %v into a time.Time", vr.Type()) + } + + dt, err := vr.ReadDateTime() + if err != nil { + return err + } + + if target, ok := i.(*time.Time); ok && target != nil { + *target = time.Unix(dt/1000, dt%1000*1000000) + return nil + } + + if target, ok := i.(**time.Time); ok && target != nil { + tt := *target + if tt == nil { + tt = new(time.Time) + } + *tt = time.Unix(dt/1000, dt%1000*1000000) + *target = tt + return nil + } + + return ValueDecoderError{ + Name: "TimeDecodeValue", + Types: []interface{}{(*time.Time)(nil), (**time.Time)(nil)}, + Received: i, + } +} + +// ReaderDecodeValue is the ValueDecoderFunc for bson.Reader. +func (dvd DefaultValueDecoders) ReaderDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + rdr, ok := i.(*bson.Reader) + if !ok { + return ValueDecoderError{Name: "ReaderDecodeValue", Types: []interface{}{(*bson.Reader)(nil)}, Received: i} + } + + if rdr == nil { + return errors.New("ReaderDecodeValue can only be used to decode non-nil *Reader") + } + + if *rdr == nil { + *rdr = make(bson.Reader, 0) + } else { + *rdr = (*rdr)[:0] + } + + var err error + *rdr, err = Copier{r: dc.Registry}.AppendDocumentBytes(*rdr, vr) + return err +} + +// ByteSliceDecodeValue is the ValueDecoderFunc for []byte. +func (dvd DefaultValueDecoders) ByteSliceDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + if vr.Type() != bson.TypeBinary { + return fmt.Errorf("cannot decode %v into a *[]byte", vr.Type()) + } + + target, ok := i.(*[]byte) + if !ok || target == nil { + return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []interface{}{(*[]byte)(nil)}, Received: i} + } + + data, subtype, err := vr.ReadBinary() + if err != nil { + return err + } + if subtype != 0x00 { + return fmt.Errorf("ByteSliceDecodeValue can only be used to decode subtype 0x00 for %s, got %v", bson.TypeBinary, subtype) + } + + *target = data + return nil +} + +// ElementSliceDecodeValue is the ValueDecoderFunc for []*bson.Element. +func (dvd DefaultValueDecoders) ElementSliceDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + dr, err := vr.ReadDocument() + if err != nil { + return err + } + elems := make([]*bson.Element, 0) + for { + key, vr, err := dr.ReadElement() + if err == ErrEOD { + break + } + if err != nil { + return err + } + + var elem *bson.Element + err = dvd.elementDecodeValue(dc, vr, key, &elem) + if err != nil { + return err + } + + elems = append(elems, elem) + } + + target, ok := i.(*[]*bson.Element) + if !ok || target == nil { + return ValueDecoderError{Name: "ElementSliceDecodeValue", Types: []interface{}{(*[]*bson.Element)(nil)}, Received: i} + } + + *target = elems + return nil +} + +// MapDecodeValue is the ValueDecoderFunc for map[string]* types. +func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + val := reflect.ValueOf(i) + if !val.IsValid() || val.Kind() != reflect.Ptr || val.IsNil() { + return fmt.Errorf("MapDecodeValue can only be used to decode non-nil pointers to map values, got %T", i) + } + + if val.Elem().Kind() != reflect.Map || val.Elem().Type().Key().Kind() != reflect.String || !val.Elem().CanSet() { + return errors.New("MapDecodeValue can only decode settable maps with string keys") + } + + dr, err := vr.ReadDocument() + if err != nil { + return err + } + + if val.Elem().IsNil() { + val.Elem().Set(reflect.MakeMap(val.Elem().Type())) + } + + mVal := val.Elem() + + dFn, err := dvd.decodeFn(dc, mVal) + if err != nil { + return err + } + + for { + var elem reflect.Value + key, vr, err := dr.ReadElement() + if err == ErrEOD { + break + } + if err != nil { + return err + } + key, elem, err = dFn(dc, vr, key) + if err != nil { + return err + } + + mVal.SetMapIndex(reflect.ValueOf(key), elem) + } + return err +} + +type decodeFn func(dc DecodeContext, vr ValueReader, key string) (updatedKey string, v reflect.Value, err error) + +// decodeFn returns a function that can be used to decode the values of a map. +// The mapVal parameter should be a map type, not a pointer to a map type. +// +// If error is nil, decodeFn will return a non-nil decodeFn. +func (dvd DefaultValueDecoders) decodeFn(dc DecodeContext, mapVal reflect.Value) (decodeFn, error) { + var dFn decodeFn + switch mapVal.Type().Elem() { + case tElement: + // TODO(skriptble): We have to decide if we want to support this. We have + // information loss because we can only store either the map key or the element + // key. We could add a struct tag field that allows the user to make a decision. + dFn = func(dc DecodeContext, vr ValueReader, key string) (string, reflect.Value, error) { + var elem *bson.Element + err := dvd.elementDecodeValue(dc, vr, key, &elem) + if err != nil { + return key, reflect.Value{}, err + } + return key, reflect.ValueOf(elem), nil + } + default: + eType := mapVal.Type().Elem() + decoder, err := dc.LookupDecoder(eType) + if err != nil { + return nil, err + } + + dFn = func(dc DecodeContext, vr ValueReader, key string) (string, reflect.Value, error) { + ptr := reflect.New(eType) + + err = decoder.DecodeValue(dc, vr, ptr.Interface()) + if err != nil { + return key, reflect.Value{}, err + } + return key, ptr.Elem(), nil + } + } + + return dFn, nil +} + +// SliceDecodeValue is the ValueDecoderFunc for []* types. +func (dvd DefaultValueDecoders) SliceDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + val := reflect.ValueOf(i) + if !val.IsValid() || val.Kind() != reflect.Ptr || val.IsNil() { + return fmt.Errorf("SliceDecodeValue can only be used to decode non-nil pointers to slice or array values, got %T", i) + } + + switch val.Elem().Kind() { + case reflect.Slice, reflect.Array: + if !val.Elem().CanSet() { + return errors.New("SliceDecodeValue can only decode settable slice and array values") + } + default: + return fmt.Errorf("SliceDecodeValue can only decode settable slice and array values, got %T", i) + } + + switch vr.Type() { + case bson.TypeArray: + case bson.TypeNull: + if val.Elem().Kind() != reflect.Slice { + return fmt.Errorf("cannot decode %v into an array", vr.Type()) + } + null := reflect.Zero(val.Elem().Type()) + val.Elem().Set(null) + return vr.ReadNull() + default: + return fmt.Errorf("cannot decode %v into a slice", vr.Type()) + } + + eType := val.Type().Elem().Elem() + + ar, err := vr.ReadArray() + if err != nil { + return err + } + + var elems []reflect.Value + switch eType { + case tElement: + elems, err = dvd.decodeElement(dc, ar) + default: + elems, err = dvd.decodeDefault(dc, ar, eType) + } + + if err != nil { + return err + } + + switch val.Elem().Kind() { + case reflect.Slice: + slc := reflect.MakeSlice(val.Elem().Type(), len(elems), len(elems)) + + for idx, elem := range elems { + slc.Index(idx).Set(elem) + } + + val.Elem().Set(slc) + case reflect.Array: + if len(elems) > val.Elem().Len() { + return fmt.Errorf("more elements returned in array than can fit inside %s", val.Elem().Type()) + } + + for idx, elem := range elems { + val.Elem().Index(idx).Set(elem) + } + } + + return nil +} + +// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. +func (dvd DefaultValueDecoders) EmptyInterfaceDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + target, ok := i.(*interface{}) + if !ok || target == nil { + return fmt.Errorf("EmptyInterfaceDecodeValue can only be used to decode non-nil *interface{} values, provided type if %T", i) + } + + // fn is a function we call to assign val back to the target, we do this so + // we can keep down on the repeated code in this method. In all of the + // implementations this is a closure, so we don't need to provide the + // target as a parameter. + var fn func() + var val interface{} + var rtype reflect.Type + + switch vr.Type() { + case bson.TypeDouble: + val = new(float64) + rtype = tFloat64 + fn = func() { *target = *(val.(*float64)) } + case bson.TypeString: + val = new(string) + rtype = tString + fn = func() { *target = *(val.(*string)) } + case bson.TypeEmbeddedDocument: + val = new(*bson.Document) + rtype = tDocument + fn = func() { *target = *val.(**bson.Document) } + case bson.TypeArray: + val = new(*bson.Array) + rtype = tArray + fn = func() { *target = *val.(**bson.Array) } + case bson.TypeBinary: + val = new(bson.Binary) + rtype = tBinary + fn = func() { *target = *(val.(*bson.Binary)) } + case bson.TypeUndefined: + val = new(bson.Undefinedv2) + rtype = tUndefined + fn = func() { *target = *(val.(*bson.Undefinedv2)) } + case bson.TypeObjectID: + val = new(objectid.ObjectID) + rtype = tOID + fn = func() { *target = *(val.(*objectid.ObjectID)) } + case bson.TypeBoolean: + val = new(bool) + rtype = tBool + fn = func() { *target = *(val.(*bool)) } + case bson.TypeDateTime: + val = new(bson.DateTime) + rtype = tDateTime + fn = func() { *target = *(val.(*bson.DateTime)) } + case bson.TypeNull: + val = new(bson.Nullv2) + rtype = tNull + fn = func() { *target = *(val.(*bson.Nullv2)) } + case bson.TypeRegex: + val = new(bson.Regex) + rtype = tRegex + fn = func() { *target = *(val.(*bson.Regex)) } + case bson.TypeDBPointer: + val = new(bson.DBPointer) + rtype = tDBPointer + fn = func() { *target = *(val.(*bson.DBPointer)) } + case bson.TypeJavaScript: + val = new(bson.JavaScriptCode) + rtype = tJavaScriptCode + fn = func() { *target = *(val.(*bson.JavaScriptCode)) } + case bson.TypeSymbol: + val = new(bson.Symbol) + rtype = tSymbol + fn = func() { *target = *(val.(*bson.Symbol)) } + case bson.TypeCodeWithScope: + val = new(bson.CodeWithScope) + rtype = tCodeWithScope + fn = func() { *target = *(val.(*bson.CodeWithScope)) } + case bson.TypeInt32: + val = new(int32) + rtype = tInt32 + fn = func() { *target = *(val.(*int32)) } + case bson.TypeInt64: + val = new(int64) + rtype = tInt64 + fn = func() { *target = *(val.(*int64)) } + case bson.TypeTimestamp: + val = new(bson.Timestamp) + rtype = tTimestamp + fn = func() { *target = *(val.(*bson.Timestamp)) } + case bson.TypeDecimal128: + val = new(decimal.Decimal128) + rtype = tDecimal + fn = func() { *target = *(val.(*decimal.Decimal128)) } + case bson.TypeMinKey: + val = new(bson.MinKeyv2) + rtype = tMinKey + fn = func() { *target = *(val.(*bson.MinKeyv2)) } + case bson.TypeMaxKey: + val = new(bson.MaxKeyv2) + rtype = tMaxKey + fn = func() { *target = *(val.(*bson.MaxKeyv2)) } + default: + return fmt.Errorf("Type %s is not a valid BSON type and has no default Go type to decode into", vr.Type()) + } + + decoder, err := dc.LookupDecoder(rtype) + if err != nil { + return err + } + err = decoder.DecodeValue(dc, vr, val) + if err != nil { + return err + } + + fn() + return nil +} + +// ValueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. +func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { + val := reflect.ValueOf(i) + var valueUnmarshaler ValueUnmarshaler + if val.Kind() == reflect.Ptr && val.IsNil() { + return fmt.Errorf("ValueUnmarshalerDecodeValue can only unmarshal into non-nil ValueUnmarshaler values, got %T", i) + } + if val.Type().Implements(tValueUnmarshaler) { + valueUnmarshaler = val.Interface().(ValueUnmarshaler) + } else if val.Type().Kind() == reflect.Ptr && val.Elem().Type().Implements(tValueUnmarshaler) { + if val.Elem().Kind() == reflect.Ptr && val.Elem().IsNil() { + val.Elem().Set(reflect.New(val.Type().Elem().Elem())) + } + valueUnmarshaler = val.Elem().Interface().(ValueUnmarshaler) + } else { + return fmt.Errorf("ValueUnmarshalerDecodeValue can only handle types or pointers to types that are a ValueUnmarshaler, got %T", i) + } + + t, src, err := Copier{r: dc.Registry}.CopyValueToBytes(vr) + if err != nil { + return err + } + + return valueUnmarshaler.UnmarshalBSONValue(t, src) +} + +func (dvd DefaultValueDecoders) decodeElement(dc DecodeContext, ar ArrayReader) ([]reflect.Value, error) { + elems := make([]reflect.Value, 0) + for { + vr, err := ar.ReadValue() + if err == ErrEOA { + break + } + if err != nil { + return nil, err + } + + var elem *bson.Element + err = dvd.elementDecodeValue(dc, vr, "", &elem) + if err != nil { + return nil, err + } + elems = append(elems, reflect.ValueOf(elem)) + } + + return elems, nil +} + +func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, ar ArrayReader, eType reflect.Type) ([]reflect.Value, error) { + elems := make([]reflect.Value, 0) + + decoder, err := dc.LookupDecoder(eType) + if err != nil { + return nil, err + } + + for { + vr, err := ar.ReadValue() + if err == ErrEOA { + break + } + if err != nil { + return nil, err + } + + ptr := reflect.New(eType) + + err = decoder.DecodeValue(dc, vr, ptr.Interface()) + if err != nil { + return nil, err + } + elems = append(elems, ptr.Elem()) + } + + return elems, nil +} + +func (dvd DefaultValueDecoders) decodeDocument(dctx DecodeContext, dr DocumentReader, pdoc **bson.Document) error { + doc := bson.NewDocument() + for { + key, vr, err := dr.ReadElement() + if err == ErrEOD { + break + } + if err != nil { + return err + } + + var elem *bson.Element + err = dvd.elementDecodeValue(dctx, vr, key, &elem) + if err != nil { + return err + } + + doc.Append(elem) + } + + *pdoc = doc + return nil +} + +func (dvd DefaultValueDecoders) elementDecodeValue(dc DecodeContext, vr ValueReader, key string, elem **bson.Element) error { + switch vr.Type() { + case bson.TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + *elem = bson.EC.Double(key, f64) + case bson.TypeString: + str, err := vr.ReadString() + if err != nil { + return err + } + *elem = bson.EC.String(key, str) + case bson.TypeEmbeddedDocument: + decoder, err := dc.LookupDecoder(tDocument) + if err != nil { + return err + } + var embeddedDoc *bson.Document + err = decoder.DecodeValue(dc, vr, &embeddedDoc) + if err != nil { + return err + } + *elem = bson.EC.SubDocument(key, embeddedDoc) + case bson.TypeArray: + decoder, err := dc.LookupDecoder(tArray) + if err != nil { + return err + } + var arr *bson.Array + err = decoder.DecodeValue(dc, vr, &arr) + if err != nil { + return err + } + *elem = bson.EC.Array(key, arr) + case bson.TypeBinary: + data, subtype, err := vr.ReadBinary() + if err != nil { + return err + } + *elem = bson.EC.BinaryWithSubtype(key, data, subtype) + case bson.TypeUndefined: + err := vr.ReadUndefined() + if err != nil { + return err + } + *elem = bson.EC.Undefined(key) + case bson.TypeObjectID: + oid, err := vr.ReadObjectID() + if err != nil { + return err + } + *elem = bson.EC.ObjectID(key, oid) + case bson.TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return err + } + *elem = bson.EC.Boolean(key, b) + case bson.TypeDateTime: + dt, err := vr.ReadDateTime() + if err != nil { + return err + } + *elem = bson.EC.DateTime(key, dt) + case bson.TypeNull: + err := vr.ReadNull() + if err != nil { + return err + } + *elem = bson.EC.Null(key) + case bson.TypeRegex: + pattern, options, err := vr.ReadRegex() + if err != nil { + return err + } + *elem = bson.EC.Regex(key, pattern, options) + case bson.TypeDBPointer: + ns, pointer, err := vr.ReadDBPointer() + if err != nil { + return err + } + *elem = bson.EC.DBPointer(key, ns, pointer) + case bson.TypeJavaScript: + js, err := vr.ReadJavascript() + if err != nil { + return err + } + *elem = bson.EC.JavaScript(key, js) + case bson.TypeSymbol: + symbol, err := vr.ReadSymbol() + if err != nil { + return err + } + *elem = bson.EC.Symbol(key, symbol) + case bson.TypeCodeWithScope: + code, scope, err := vr.ReadCodeWithScope() + if err != nil { + return err + } + scopeDoc := new(*bson.Document) + err = dvd.decodeDocument(dc, scope, scopeDoc) + if err != nil { + return err + } + *elem = bson.EC.CodeWithScope(key, code, *scopeDoc) + case bson.TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + *elem = bson.EC.Int32(key, i32) + case bson.TypeTimestamp: + t, i, err := vr.ReadTimestamp() + if err != nil { + return err + } + *elem = bson.EC.Timestamp(key, t, i) + case bson.TypeInt64: + i64, err := vr.ReadInt64() + if err != nil { + return err + } + *elem = bson.EC.Int64(key, i64) + case bson.TypeDecimal128: + d128, err := vr.ReadDecimal128() + if err != nil { + return err + } + *elem = bson.EC.Decimal128(key, d128) + case bson.TypeMinKey: + err := vr.ReadMinKey() + if err != nil { + return err + } + *elem = bson.EC.MinKey(key) + case bson.TypeMaxKey: + err := vr.ReadMaxKey() + if err != nil { + return err + } + *elem = bson.EC.MaxKey(key) + default: + return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type()) + } + + return nil +} + +func (dvd DefaultValueDecoders) valueDecodeValue(dc DecodeContext, vr ValueReader, val **bson.Value) error { + switch vr.Type() { + case bson.TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + *val = bson.VC.Double(f64) + case bson.TypeString: + str, err := vr.ReadString() + if err != nil { + return err + } + *val = bson.VC.String(str) + case bson.TypeEmbeddedDocument: + decoder, err := dc.LookupDecoder(tDocument) + if err != nil { + return err + } + var embeddedDoc *bson.Document + err = decoder.DecodeValue(dc, vr, &embeddedDoc) + if err != nil { + return err + } + *val = bson.VC.Document(embeddedDoc) + case bson.TypeArray: + decoder, err := dc.LookupDecoder(tArray) + if err != nil { + return err + } + var arr *bson.Array + err = decoder.DecodeValue(dc, vr, &arr) + if err != nil { + return err + } + *val = bson.VC.Array(arr) + case bson.TypeBinary: + data, subtype, err := vr.ReadBinary() + if err != nil { + return err + } + *val = bson.VC.BinaryWithSubtype(data, subtype) + case bson.TypeUndefined: + err := vr.ReadUndefined() + if err != nil { + return err + } + *val = bson.VC.Undefined() + case bson.TypeObjectID: + oid, err := vr.ReadObjectID() + if err != nil { + return err + } + *val = bson.VC.ObjectID(oid) + case bson.TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return err + } + *val = bson.VC.Boolean(b) + case bson.TypeDateTime: + dt, err := vr.ReadDateTime() + if err != nil { + return err + } + *val = bson.VC.DateTime(dt) + case bson.TypeNull: + err := vr.ReadNull() + if err != nil { + return err + } + *val = bson.VC.Null() + case bson.TypeRegex: + pattern, options, err := vr.ReadRegex() + if err != nil { + return err + } + *val = bson.VC.Regex(pattern, options) + case bson.TypeDBPointer: + ns, pointer, err := vr.ReadDBPointer() + if err != nil { + return err + } + *val = bson.VC.DBPointer(ns, pointer) + case bson.TypeJavaScript: + js, err := vr.ReadJavascript() + if err != nil { + return err + } + *val = bson.VC.JavaScript(js) + case bson.TypeSymbol: + symbol, err := vr.ReadSymbol() + if err != nil { + return err + } + *val = bson.VC.Symbol(symbol) + case bson.TypeCodeWithScope: + code, scope, err := vr.ReadCodeWithScope() + if err != nil { + return err + } + scopeDoc := new(*bson.Document) + err = dvd.decodeDocument(dc, scope, scopeDoc) + if err != nil { + return err + } + *val = bson.VC.CodeWithScope(code, *scopeDoc) + case bson.TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + *val = bson.VC.Int32(i32) + case bson.TypeTimestamp: + t, i, err := vr.ReadTimestamp() + if err != nil { + return err + } + *val = bson.VC.Timestamp(t, i) + case bson.TypeInt64: + i64, err := vr.ReadInt64() + if err != nil { + return err + } + *val = bson.VC.Int64(i64) + case bson.TypeDecimal128: + d128, err := vr.ReadDecimal128() + if err != nil { + return err + } + *val = bson.VC.Decimal128(d128) + case bson.TypeMinKey: + err := vr.ReadMinKey() + if err != nil { + return err + } + *val = bson.VC.MinKey() + case bson.TypeMaxKey: + err := vr.ReadMaxKey() + if err != nil { + return err + } + *val = bson.VC.MaxKey() + default: + return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type()) + } + + return nil +} diff --git a/bson/bsoncodec/default_value_decoders_test.go b/bson/bsoncodec/default_value_decoders_test.go new file mode 100644 index 0000000000..b9cb4422c8 --- /dev/null +++ b/bson/bsoncodec/default_value_decoders_test.go @@ -0,0 +1,2690 @@ +package bsoncodec + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "math" + "net/url" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/internal/llbson" + "github.com/mongodb/mongo-go-driver/bson/objectid" +) + +func TestDefaultValueDecoders(t *testing.T) { + var dvd DefaultValueDecoders + var wrong = func(string, string) string { return "wrong" } + + type mybool bool + type myint8 int8 + type myint16 int16 + type myint32 int32 + type myint64 int64 + type myint int + type myuint8 uint8 + type myuint16 uint16 + type myuint32 uint32 + type myuint64 uint64 + type myuint uint + type myfloat32 float32 + type myfloat64 float64 + type mystring string + + const cansetreflectiontest = "cansetreflectiontest" + + intAllowedDecodeTypes := []interface{}{(*int8)(nil), (*int16)(nil), (*int32)(nil), (*int64)(nil), (*int)(nil)} + uintAllowedDecodeTypes := []interface{}{(*uint8)(nil), (*uint16)(nil), (*uint32)(nil), (*uint64)(nil), (*uint)(nil)} + now := time.Now().Truncate(time.Millisecond) + d128 := decimal.NewDecimal128(12345, 67890) + var ptrPtrValueUnmarshaler **testValueUnmarshaler + + type subtest struct { + name string + val interface{} + dctx *DecodeContext + llvrw *llValueReaderWriter + invoke llvrwInvoked + err error + } + + testCases := []struct { + name string + vd ValueDecoder + subtests []subtest + }{ + { + "BooleanDecodeValue", + ValueDecoderFunc(dvd.BooleanDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeBoolean}, + llvrwNothing, + ValueDecoderError{Name: "BooleanDecodeValue", Types: []interface{}{bool(true)}, Received: &wrong}, + }, + { + "type not boolean", + bool(false), + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a boolean", bson.TypeString), + }, + { + "fast path", + bool(true), + nil, + &llValueReaderWriter{bsontype: bson.TypeBoolean, readval: bool(true)}, + llvrwReadBoolean, + nil, + }, + { + "reflection path", + mybool(true), + nil, + &llValueReaderWriter{bsontype: bson.TypeBoolean, readval: bool(true)}, + llvrwReadBoolean, + nil, + }, + { + "reflection path error", + mybool(true), + nil, + &llValueReaderWriter{bsontype: bson.TypeBoolean, readval: bool(true), err: errors.New("ReadBoolean Error")}, + llvrwReadBoolean, errors.New("ReadBoolean Error"), + }, + { + "can set false", + cansetreflectiontest, + nil, + &llValueReaderWriter{bsontype: bson.TypeBoolean}, + llvrwNothing, + errors.New("BooleanDecodeValue can only be used to decode settable (non-nil) values"), + }, + }, + }, + { + "IntDecodeValue", + ValueDecoderFunc(dvd.IntDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, + llvrwReadInt32, + ValueDecoderError{Name: "IntDecodeValue", Types: intAllowedDecodeTypes, Received: &wrong}, + }, + { + "type not int32/int64", + 0, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into an integer type", bson.TypeString), + }, + { + "ReadInt32 error", + 0, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0), err: errors.New("ReadInt32 error"), errAfter: llvrwReadInt32}, + llvrwReadInt32, + errors.New("ReadInt32 error"), + }, + { + "ReadInt64 error", + 0, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(0), err: errors.New("ReadInt64 error"), errAfter: llvrwReadInt64}, + llvrwReadInt64, + errors.New("ReadInt64 error"), + }, + { + "ReadDouble error", + 0, + nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(0), err: errors.New("ReadDouble error"), errAfter: llvrwReadDouble}, + llvrwReadDouble, + errors.New("ReadDouble error"), + }, + { + "ReadDouble", int64(3), &DecodeContext{}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.00)}, llvrwReadDouble, + nil, + }, + { + "ReadDouble (truncate)", int64(3), &DecodeContext{Truncate: true}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + nil, + }, + { + "ReadDouble (no truncate)", int64(0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + errors.New("IntDecodeValue can only truncate float64 to an integer type when truncation is enabled"), + }, + { + "ReadDouble overflows int64", int64(0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: math.MaxFloat64}, llvrwReadDouble, + fmt.Errorf("%g overflows int64", math.MaxFloat64), + }, + {"int8/fast path", int8(127), nil, &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(127)}, llvrwReadInt32, nil}, + {"int16/fast path", int16(32676), nil, &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(32676)}, llvrwReadInt32, nil}, + {"int32/fast path", int32(1234), nil, &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1234)}, llvrwReadInt32, nil}, + {"int64/fast path", int64(1234), nil, &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, + {"int/fast path", int(1234), nil, &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, + { + "int8/fast path - nil", (*int8)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("IntDecodeValue can only be used to decode non-nil *int8"), + }, + { + "int16/fast path - nil", (*int16)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("IntDecodeValue can only be used to decode non-nil *int16"), + }, + { + "int32/fast path - nil", (*int32)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("IntDecodeValue can only be used to decode non-nil *int32"), + }, + { + "int64/fast path - nil", (*int64)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("IntDecodeValue can only be used to decode non-nil *int64"), + }, + { + "int/fast path - nil", (*int)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("IntDecodeValue can only be used to decode non-nil *int"), + }, + { + "int8/fast path - overflow", int8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(129)}, llvrwReadInt32, + fmt.Errorf("%d overflows int8", 129), + }, + { + "int16/fast path - overflow", int16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(32768)}, llvrwReadInt32, + fmt.Errorf("%d overflows int16", 32768), + }, + { + "int32/fast path - overflow", int32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(2147483648)}, llvrwReadInt64, + fmt.Errorf("%d overflows int32", 2147483648), + }, + { + "int8/fast path - overflow (negative)", int8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-129)}, llvrwReadInt32, + fmt.Errorf("%d overflows int8", -129), + }, + { + "int16/fast path - overflow (negative)", int16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-32769)}, llvrwReadInt32, + fmt.Errorf("%d overflows int16", -32769), + }, + { + "int32/fast path - overflow (negative)", int32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-2147483649)}, llvrwReadInt64, + fmt.Errorf("%d overflows int32", -2147483649), + }, + { + "int8/reflection path", myint8(127), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(127)}, llvrwReadInt32, + nil, + }, + { + "int16/reflection path", myint16(255), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(255)}, llvrwReadInt32, + nil, + }, + { + "int32/reflection path", myint32(511), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(511)}, llvrwReadInt32, + nil, + }, + { + "int64/reflection path", myint64(1023), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1023)}, llvrwReadInt32, + nil, + }, + { + "int/reflection path", myint(2047), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(2047)}, llvrwReadInt32, + nil, + }, + { + "int8/reflection path - overflow", myint8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(129)}, llvrwReadInt32, + fmt.Errorf("%d overflows int8", 129), + }, + { + "int16/reflection path - overflow", myint16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(32768)}, llvrwReadInt32, + fmt.Errorf("%d overflows int16", 32768), + }, + { + "int32/reflection path - overflow", myint32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(2147483648)}, llvrwReadInt64, + fmt.Errorf("%d overflows int32", 2147483648), + }, + { + "int8/reflection path - overflow (negative)", myint8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-129)}, llvrwReadInt32, + fmt.Errorf("%d overflows int8", -129), + }, + { + "int16/reflection path - overflow (negative)", myint16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-32769)}, llvrwReadInt32, + fmt.Errorf("%d overflows int16", -32769), + }, + { + "int32/reflection path - overflow (negative)", myint32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-2147483649)}, llvrwReadInt64, + fmt.Errorf("%d overflows int32", -2147483649), + }, + { + "can set false", + cansetreflectiontest, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, + llvrwNothing, + errors.New("IntDecodeValue can only be used to decode settable (non-nil) values"), + }, + }, + }, + { + "UintDecodeValue", + ValueDecoderFunc(dvd.UintDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, + llvrwReadInt32, + ValueDecoderError{Name: "UintDecodeValue", Types: uintAllowedDecodeTypes, Received: &wrong}, + }, + { + "type not int32/int64", + 0, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into an integer type", bson.TypeString), + }, + { + "ReadInt32 error", + uint(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0), err: errors.New("ReadInt32 error"), errAfter: llvrwReadInt32}, + llvrwReadInt32, + errors.New("ReadInt32 error"), + }, + { + "ReadInt64 error", + uint(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(0), err: errors.New("ReadInt64 error"), errAfter: llvrwReadInt64}, + llvrwReadInt64, + errors.New("ReadInt64 error"), + }, + { + "ReadDouble error", + 0, + nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(0), err: errors.New("ReadDouble error"), errAfter: llvrwReadDouble}, + llvrwReadDouble, + errors.New("ReadDouble error"), + }, + { + "ReadDouble", uint64(3), &DecodeContext{}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.00)}, llvrwReadDouble, + nil, + }, + { + "ReadDouble (truncate)", uint64(3), &DecodeContext{Truncate: true}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + nil, + }, + { + "ReadDouble (no truncate)", uint64(0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + errors.New("UintDecodeValue can only truncate float64 to an integer type when truncation is enabled"), + }, + { + "ReadDouble overflows int64", uint64(0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: math.MaxFloat64}, llvrwReadDouble, + fmt.Errorf("%g overflows int64", math.MaxFloat64), + }, + {"uint8/fast path", uint8(127), nil, &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(127)}, llvrwReadInt32, nil}, + {"uint16/fast path", uint16(255), nil, &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(255)}, llvrwReadInt32, nil}, + {"uint32/fast path", uint32(1234), nil, &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1234)}, llvrwReadInt32, nil}, + {"uint64/fast path", uint64(1234), nil, &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, + {"uint/fast path", uint(1234), nil, &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, + { + "uint8/fast path - nil", (*uint8)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("UintDecodeValue can only be used to decode non-nil *uint8"), + }, + { + "uint16/fast path - nil", (*uint16)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("UintDecodeValue can only be used to decode non-nil *uint16"), + }, + { + "uint32/fast path - nil", (*uint32)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("UintDecodeValue can only be used to decode non-nil *uint32"), + }, + { + "uint64/fast path - nil", (*uint64)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("UintDecodeValue can only be used to decode non-nil *uint64"), + }, + { + "uint/fast path - nil", (*uint)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, llvrwReadInt32, + errors.New("UintDecodeValue can only be used to decode non-nil *uint"), + }, + { + "uint8/fast path - overflow", uint8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1 << 8)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint8", 1<<8), + }, + { + "uint16/fast path - overflow", uint16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1 << 16)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint16", 1<<16), + }, + { + "uint32/fast path - overflow", uint32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1 << 32)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint32", 1<<32), + }, + { + "uint8/fast path - overflow (negative)", uint8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-1)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint8", -1), + }, + { + "uint16/fast path - overflow (negative)", uint16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-1)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint16", -1), + }, + { + "uint32/fast path - overflow (negative)", uint32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-1)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint32", -1), + }, + { + "uint64/fast path - overflow (negative)", uint64(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-1)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint64", -1), + }, + { + "uint/fast path - overflow (negative)", uint(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-1)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint", -1), + }, + { + "uint8/reflection path", myuint8(127), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(127)}, llvrwReadInt32, + nil, + }, + { + "uint16/reflection path", myuint16(255), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(255)}, llvrwReadInt32, + nil, + }, + { + "uint32/reflection path", myuint32(511), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(511)}, llvrwReadInt32, + nil, + }, + { + "uint64/reflection path", myuint64(1023), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1023)}, llvrwReadInt32, + nil, + }, + { + "uint/reflection path", myuint(2047), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(2047)}, llvrwReadInt32, + nil, + }, + { + "uint8/reflection path - overflow", myuint8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1 << 8)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint8", 1<<8), + }, + { + "uint16/reflection path - overflow", myuint16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(1 << 16)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint16", 1<<16), + }, + { + "uint32/reflection path - overflow", myuint32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1 << 32)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint32", 1<<32), + }, + { + "uint8/reflection path - overflow (negative)", myuint8(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-1)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint8", -1), + }, + { + "uint16/reflection path - overflow (negative)", myuint16(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(-1)}, llvrwReadInt32, + fmt.Errorf("%d overflows uint16", -1), + }, + { + "uint32/reflection path - overflow (negative)", myuint32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-1)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint32", -1), + }, + { + "uint64/reflection path - overflow (negative)", myuint64(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-1)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint64", -1), + }, + { + "uint/reflection path - overflow (negative)", myuint(0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(-1)}, llvrwReadInt64, + fmt.Errorf("%d overflows uint", -1), + }, + { + "can set false", + cansetreflectiontest, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, + llvrwNothing, + errors.New("UintDecodeValue can only be used to decode settable (non-nil) values"), + }, + }, + }, + { + "FloatDecodeValue", + ValueDecoderFunc(dvd.FloatDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(0)}, + llvrwReadDouble, + ValueDecoderError{Name: "FloatDecodeValue", Types: []interface{}{(*float32)(nil), (*float64)(nil)}, Received: &wrong}, + }, + { + "type not double", + 0, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a float32 or float64 type", bson.TypeString), + }, + { + "ReadDouble error", + float64(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(0), err: errors.New("ReadDouble error"), errAfter: llvrwReadDouble}, + llvrwReadDouble, + errors.New("ReadDouble error"), + }, + { + "ReadInt32 error", + float64(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0), err: errors.New("ReadInt32 error"), errAfter: llvrwReadInt32}, + llvrwReadInt32, + errors.New("ReadInt32 error"), + }, + { + "ReadInt64 error", + float64(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(0), err: errors.New("ReadInt64 error"), errAfter: llvrwReadInt64}, + llvrwReadInt64, + errors.New("ReadInt64 error"), + }, + { + "float64/int32", float32(32.0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(32)}, llvrwReadInt32, + nil, + }, + { + "float64/int64", float32(64.0), nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(64)}, llvrwReadInt64, + nil, + }, + { + "float32/fast path (equal)", float32(3.0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.0)}, llvrwReadDouble, + nil, + }, + { + "float64/fast path", float64(3.14159), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14159)}, llvrwReadDouble, + nil, + }, + { + "float32/fast path (truncate)", float32(3.14), &DecodeContext{Truncate: true}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + nil, + }, + { + "float32/fast path (no truncate)", float32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + errors.New("FloatDecodeValue can only convert float64 to float32 when truncation is allowed"), + }, + { + "float32/fast path - nil", (*float32)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(0)}, llvrwReadDouble, + errors.New("FloatDecodeValue can only be used to decode non-nil *float32"), + }, + { + "float64/fast path - nil", (*float64)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(0)}, llvrwReadDouble, + errors.New("FloatDecodeValue can only be used to decode non-nil *float64"), + }, + { + "float32/reflection path (equal)", myfloat32(3.0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.0)}, llvrwReadDouble, + nil, + }, + { + "float64/reflection path", myfloat64(3.14159), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14159)}, llvrwReadDouble, + nil, + }, + { + "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{Truncate: true}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + nil, + }, + { + "float32/reflection path (no truncate)", myfloat32(0), nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14)}, llvrwReadDouble, + errors.New("FloatDecodeValue can only convert float64 to float32 when truncation is allowed"), + }, + { + "can set false", + cansetreflectiontest, + nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(0)}, + llvrwNothing, + errors.New("FloatDecodeValue can only be used to decode settable (non-nil) values"), + }, + }, + }, + { + "StringDecodeValue", + ValueDecoderFunc(dvd.StringDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("")}, + llvrwReadString, + ValueDecoderError{Name: "StringDecodeValue", Types: []interface{}{(*string)(nil), (*bson.JavaScriptCode)(nil), (*bson.Symbol)(nil)}, Received: &wrong}, + }, + { + "type not string", + string(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeBoolean}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a string type", bson.TypeBoolean), + }, + { + "ReadString error", + string(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string(""), err: errors.New("ReadString error"), errAfter: llvrwReadString}, + llvrwReadString, + errors.New("ReadString error"), + }, + { + "ReadJavaScript error", + string(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeJavaScript, readval: string(""), err: errors.New("ReadJS error"), errAfter: llvrwReadJavascript}, + llvrwReadJavascript, + errors.New("ReadJS error"), + }, + { + "ReadSymbol error", + string(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeSymbol, readval: string(""), err: errors.New("ReadSymbol error"), errAfter: llvrwReadSymbol}, + llvrwReadSymbol, + errors.New("ReadSymbol error"), + }, + { + "string/fast path", + string("foobar"), + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("foobar")}, + llvrwReadString, + nil, + }, + { + "JavaScript/fast path", + bson.JavaScriptCode("var hello = 'world';"), + nil, + &llValueReaderWriter{bsontype: bson.TypeJavaScript, readval: string("var hello = 'world';")}, + llvrwReadJavascript, + nil, + }, + { + "Symbol/fast path", + bson.Symbol("foobarbaz"), + nil, + &llValueReaderWriter{bsontype: bson.TypeSymbol, readval: bson.Symbol("foobarbaz")}, + llvrwReadSymbol, + nil, + }, + { + "string/fast path - nil", (*string)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("")}, llvrwReadString, + errors.New("StringDecodeValue can only be used to decode non-nil *string"), + }, + { + "JavaScript/fast path - nil", (*bson.JavaScriptCode)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeJavaScript, readval: string("")}, llvrwReadJavascript, + errors.New("StringDecodeValue can only be used to decode non-nil *JavaScriptCode"), + }, + { + "Symbol/fast path - nil", (*bson.Symbol)(nil), nil, + &llValueReaderWriter{bsontype: bson.TypeSymbol, readval: bson.Symbol("")}, llvrwReadSymbol, + errors.New("StringDecodeValue can only be used to decode non-nil *Symbol"), + }, + { + "reflection path", + mystring("foobarbaz"), + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("foobarbaz")}, + llvrwReadString, + nil, + }, + { + "reflection path error", + mystring("foobarbazqux"), + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("foobarbazqux"), err: errors.New("ReadString Error"), errAfter: llvrwReadString}, + llvrwReadString, errors.New("ReadString Error"), + }, + { + "can set false", + cansetreflectiontest, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("")}, + llvrwNothing, + errors.New("StringDecodeValue can only be used to decode settable (non-nil) values"), + }, + }, + }, + { + "TimeDecodeValue", + ValueDecoderFunc(dvd.TimeDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(0)}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a time.Time", bson.TypeInt32), + }, + { + "type not *time.Time", + int64(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeDateTime, readval: int64(1234567890)}, + llvrwReadDateTime, + ValueDecoderError{ + Name: "TimeDecodeValue", + Types: []interface{}{(*time.Time)(nil), (**time.Time)(nil)}, + Received: (*int64)(nil), + }, + }, + { + "ReadDateTime error", + time.Time{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeDateTime, readval: int64(0), err: errors.New("ReadDateTime error"), errAfter: llvrwReadDateTime}, + llvrwReadDateTime, + errors.New("ReadDateTime error"), + }, + { + "time.Time", + now, + nil, + &llValueReaderWriter{bsontype: bson.TypeDateTime, readval: int64(now.UnixNano() / int64(time.Millisecond))}, + llvrwReadDateTime, + nil, + }, + { + "*time.Time", + &now, + nil, + &llValueReaderWriter{bsontype: bson.TypeDateTime, readval: int64(now.UnixNano() / int64(time.Millisecond))}, + llvrwReadDateTime, + nil, + }, + }, + }, + { + "MapDecodeValue", + ValueDecoderFunc(dvd.MapDecodeValue), + []subtest{ + { + "wrong kind", + wrong, + nil, + &llValueReaderWriter{}, + llvrwNothing, + errors.New("MapDecodeValue can only decode settable maps with string keys"), + }, + { + "wrong kind (non-string key)", + map[int]interface{}{}, + nil, + &llValueReaderWriter{}, + llvrwNothing, + errors.New("MapDecodeValue can only decode settable maps with string keys"), + }, + { + "ReadDocument Error", + make(map[string]interface{}), + nil, + &llValueReaderWriter{err: errors.New("rd error"), errAfter: llvrwReadDocument}, + llvrwReadDocument, + errors.New("rd error"), + }, + { + "Lookup Error", + map[string]string{}, + &DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{}, + llvrwReadDocument, + ErrNoDecoder{Type: reflect.TypeOf(string(""))}, + }, + { + "ReadElement Error", + make(map[string]interface{}), + &DecodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{err: errors.New("re error"), errAfter: llvrwReadElement}, + llvrwReadElement, + errors.New("re error"), + }, + { + "DecodeValue Error", + map[string]string{"foo": "bar"}, + &DecodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{bsontype: bson.TypeString, err: errors.New("dv error"), errAfter: llvrwReadString}, + llvrwReadString, + errors.New("dv error"), + }, + }, + }, + { + "SliceDecodeValue", + ValueDecoderFunc(dvd.SliceDecodeValue), + []subtest{ + { + "wrong kind", + wrong, + nil, + &llValueReaderWriter{}, + llvrwNothing, + fmt.Errorf("SliceDecodeValue can only decode settable slice and array values, got %T", &wrong), + }, + { + "can set false", + (*[]string)(nil), + nil, + &llValueReaderWriter{}, + llvrwNothing, + fmt.Errorf("SliceDecodeValue can only be used to decode non-nil pointers to slice or array values, got %T", (*[]string)(nil)), + }, + { + "Not Type Array", + []interface{}{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + errors.New("cannot decode string into a slice"), + }, + { + "ReadArray Error", + []interface{}{}, + nil, + &llValueReaderWriter{err: errors.New("ra error"), errAfter: llvrwReadArray, bsontype: bson.TypeArray}, + llvrwReadArray, + errors.New("ra error"), + }, + { + "Lookup Error", + []string{}, + &DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{bsontype: bson.TypeArray}, + llvrwReadArray, + ErrNoDecoder{Type: reflect.TypeOf(string(""))}, + }, + { + "ReadValue Error", + []string{}, + &DecodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{err: errors.New("rv error"), errAfter: llvrwReadValue, bsontype: bson.TypeArray}, + llvrwReadValue, + errors.New("rv error"), + }, + { + "DecodeValue Error", + []string{}, + &DecodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{bsontype: bson.TypeArray}, + llvrwReadValue, + errors.New("cannot decode array into a string type"), + }, + }, + }, + { + "BinaryDecodeValue", + ValueDecoderFunc(dvd.BinaryDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeBinary, readval: bson.Binary{}}, + llvrwReadBinary, + ValueDecoderError{Name: "BinaryDecodeValue", Types: []interface{}{(*bson.Binary)(nil)}, Received: &wrong}, + }, + { + "type not binary", + bson.Binary{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a Binary", bson.TypeString), + }, + { + "ReadBinary Error", + bson.Binary{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeBinary, err: errors.New("rb error"), errAfter: llvrwReadBinary}, + llvrwReadBinary, + errors.New("rb error"), + }, + { + "Binary/success", + bson.Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}, + nil, + &llValueReaderWriter{bsontype: bson.TypeBinary, readval: bson.Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}}, + llvrwReadBinary, + nil, + }, + { + "*Binary/success", + &bson.Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}, + nil, + &llValueReaderWriter{bsontype: bson.TypeBinary, readval: bson.Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}}, + llvrwReadBinary, + nil, + }, + }, + }, + { + "UndefinedDecodeValue", + ValueDecoderFunc(dvd.UndefinedDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeUndefined}, + llvrwNothing, + ValueDecoderError{Name: "UndefinedDecodeValue", Types: []interface{}{(*bson.Undefinedv2)(nil)}, Received: &wrong}, + }, + { + "type not undefined", + bson.Undefinedv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into an Undefined", bson.TypeString), + }, + { + "ReadUndefined Error", + bson.Undefinedv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeUndefined, err: errors.New("ru error"), errAfter: llvrwReadUndefined}, + llvrwReadUndefined, + errors.New("ru error"), + }, + { + "ReadUndefined/success", + bson.Undefinedv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeUndefined}, + llvrwReadUndefined, + nil, + }, + }, + }, + { + "ObjectIDDecodeValue", + ValueDecoderFunc(dvd.ObjectIDDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeObjectID}, + llvrwNothing, + ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []interface{}{(*objectid.ObjectID)(nil)}, Received: &wrong}, + }, + { + "type not objectID", + objectid.ObjectID{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into an ObjectID", bson.TypeString), + }, + { + "ReadObjectID Error", + objectid.ObjectID{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeObjectID, err: errors.New("roid error"), errAfter: llvrwReadObjectID}, + llvrwReadObjectID, + errors.New("roid error"), + }, + { + "success", + objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + nil, + &llValueReaderWriter{ + bsontype: bson.TypeObjectID, + readval: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + }, + llvrwReadObjectID, + nil, + }, + }, + }, + { + "DateTimeDecodeValue", + ValueDecoderFunc(dvd.DateTimeDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeDateTime}, + llvrwNothing, + ValueDecoderError{Name: "DateTimeDecodeValue", Types: []interface{}{(*bson.DateTime)(nil)}, Received: &wrong}, + }, + { + "type not datetime", + bson.DateTime(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a DateTime", bson.TypeString), + }, + { + "ReadDateTime Error", + bson.DateTime(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeDateTime, err: errors.New("rdt error"), errAfter: llvrwReadDateTime}, + llvrwReadDateTime, + errors.New("rdt error"), + }, + { + "success", + bson.DateTime(1234567890), + nil, + &llValueReaderWriter{bsontype: bson.TypeDateTime, readval: int64(1234567890)}, + llvrwReadDateTime, + nil, + }, + }, + }, + { + "NullDecodeValue", + ValueDecoderFunc(dvd.NullDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeNull}, + llvrwNothing, + ValueDecoderError{Name: "NullDecodeValue", Types: []interface{}{(*bson.Nullv2)(nil)}, Received: &wrong}, + }, + { + "type not null", + bson.Nullv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a Null", bson.TypeString), + }, + { + "ReadNull Error", + bson.Nullv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeNull, err: errors.New("rn error"), errAfter: llvrwReadNull}, + llvrwReadNull, + errors.New("rn error"), + }, + { + "success", + bson.Nullv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeNull}, + llvrwReadNull, + nil, + }, + }, + }, + { + "RegexDecodeValue", + ValueDecoderFunc(dvd.RegexDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeRegex}, + llvrwNothing, + ValueDecoderError{Name: "RegexDecodeValue", Types: []interface{}{(*bson.Regex)(nil)}, Received: &wrong}, + }, + { + "type not regex", + bson.Regex{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a Regex", bson.TypeString), + }, + { + "ReadRegex Error", + bson.Regex{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeRegex, err: errors.New("rr error"), errAfter: llvrwReadRegex}, + llvrwReadRegex, + errors.New("rr error"), + }, + { + "success", + bson.Regex{Pattern: "foo", Options: "bar"}, + nil, + &llValueReaderWriter{bsontype: bson.TypeRegex, readval: bson.Regex{Pattern: "foo", Options: "bar"}}, + llvrwReadRegex, + nil, + }, + }, + }, + { + "DBPointerDecodeValue", + ValueDecoderFunc(dvd.DBPointerDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeDBPointer}, + llvrwNothing, + ValueDecoderError{Name: "DBPointerDecodeValue", Types: []interface{}{(*bson.DBPointer)(nil)}, Received: &wrong}, + }, + { + "type not dbpointer", + bson.DBPointer{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a DBPointer", bson.TypeString), + }, + { + "ReadDBPointer Error", + bson.DBPointer{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeDBPointer, err: errors.New("rdbp error"), errAfter: llvrwReadDBPointer}, + llvrwReadDBPointer, + errors.New("rdbp error"), + }, + { + "success", + bson.DBPointer{ + DB: "foobar", + Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + }, + nil, + &llValueReaderWriter{ + bsontype: bson.TypeDBPointer, + readval: bson.DBPointer{ + DB: "foobar", + Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + }, + }, + llvrwReadDBPointer, + nil, + }, + }, + }, + { + "CodeWithScopeDecodeValue", + ValueDecoderFunc(dvd.CodeWithScopeDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeCodeWithScope}, + llvrwNothing, + ValueDecoderError{ + Name: "CodeWithScopeDecodeValue", + Types: []interface{}{(*bson.CodeWithScope)(nil)}, + Received: &wrong, + }, + }, + { + "type not codewithscope", + bson.CodeWithScope{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a CodeWithScope", bson.TypeString), + }, + { + "ReadCodeWithScope Error", + bson.CodeWithScope{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeCodeWithScope, err: errors.New("rcws error"), errAfter: llvrwReadCodeWithScope}, + llvrwReadCodeWithScope, + errors.New("rcws error"), + }, + { + "decodeDocument Error", + bson.CodeWithScope{ + Code: "var hello = 'world';", + Scope: bson.NewDocument(bson.EC.Null("foo")), + }, + nil, + &llValueReaderWriter{bsontype: bson.TypeCodeWithScope, err: errors.New("dd error"), errAfter: llvrwReadElement}, + llvrwReadElement, + errors.New("dd error"), + }, + }, + }, + { + "TimestampDecodeValue", + ValueDecoderFunc(dvd.TimestampDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeTimestamp}, + llvrwNothing, + ValueDecoderError{Name: "TimestampDecodeValue", Types: []interface{}{(*bson.Timestamp)(nil)}, Received: &wrong}, + }, + { + "type not timestamp", + bson.Timestamp{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a Timestamp", bson.TypeString), + }, + { + "ReadTimestamp Error", + bson.Timestamp{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeTimestamp, err: errors.New("rt error"), errAfter: llvrwReadTimestamp}, + llvrwReadTimestamp, + errors.New("rt error"), + }, + { + "success", + bson.Timestamp{T: 12345, I: 67890}, + nil, + &llValueReaderWriter{bsontype: bson.TypeTimestamp, readval: bson.Timestamp{T: 12345, I: 67890}}, + llvrwReadTimestamp, + nil, + }, + }, + }, + { + "Decimal128DecodeValue", + ValueDecoderFunc(dvd.Decimal128DecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeDecimal128}, + llvrwNothing, + ValueDecoderError{Name: "Decimal128DecodeValue", Types: []interface{}{(*decimal.Decimal128)(nil)}, Received: &wrong}, + }, + { + "type not decimal128", + decimal.Decimal128{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a decimal.Decimal128", bson.TypeString), + }, + { + "ReadDecimal128 Error", + decimal.Decimal128{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeDecimal128, err: errors.New("rd128 error"), errAfter: llvrwReadDecimal128}, + llvrwReadDecimal128, + errors.New("rd128 error"), + }, + { + "success", + d128, + nil, + &llValueReaderWriter{bsontype: bson.TypeDecimal128, readval: d128}, + llvrwReadDecimal128, + nil, + }, + }, + }, + { + "MinKeyDecodeValue", + ValueDecoderFunc(dvd.MinKeyDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeMinKey}, + llvrwNothing, + ValueDecoderError{Name: "MinKeyDecodeValue", Types: []interface{}{(*bson.MinKeyv2)(nil)}, Received: &wrong}, + }, + { + "type not null", + bson.MinKeyv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a MinKey", bson.TypeString), + }, + { + "ReadMinKey Error", + bson.MinKeyv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeMinKey, err: errors.New("rn error"), errAfter: llvrwReadMinKey}, + llvrwReadMinKey, + errors.New("rn error"), + }, + { + "success", + bson.MinKeyv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeMinKey}, + llvrwReadMinKey, + nil, + }, + }, + }, + { + "MaxKeyDecodeValue", + ValueDecoderFunc(dvd.MaxKeyDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeMaxKey}, + llvrwNothing, + ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []interface{}{(*bson.MaxKeyv2)(nil)}, Received: &wrong}, + }, + { + "type not null", + bson.MaxKeyv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a MaxKey", bson.TypeString), + }, + { + "ReadMaxKey Error", + bson.MaxKeyv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeMaxKey, err: errors.New("rn error"), errAfter: llvrwReadMaxKey}, + llvrwReadMaxKey, + errors.New("rn error"), + }, + { + "success", + bson.MaxKeyv2{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeMaxKey}, + llvrwReadMaxKey, + nil, + }, + }, + }, + { + "ValueDecodeValue", + ValueDecoderFunc(dvd.ValueDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueDecoderError{Name: "ValueDecodeValue", Types: []interface{}{(**bson.Value)(nil)}, Received: &wrong}, + }, + {"invalid value", (**bson.Value)(nil), nil, nil, llvrwNothing, errors.New("ValueDecodeValue can only be used to decode non-nil **Value")}, + { + "success", + bson.VC.Double(3.14159), + &DecodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14159)}, + llvrwReadDouble, + nil, + }, + }, + }, + { + "JSONNumberDecodeValue", + ValueDecoderFunc(dvd.JSONNumberDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeObjectID}, + llvrwNothing, + ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []interface{}{(*json.Number)(nil)}, Received: &wrong}, + }, + { + "type not double/int32/int64", + json.Number(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeString}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a json.Number", bson.TypeString), + }, + { + "ReadDouble Error", + json.Number(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, err: errors.New("rd error"), errAfter: llvrwReadDouble}, + llvrwReadDouble, + errors.New("rd error"), + }, + { + "ReadInt32 Error", + json.Number(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, err: errors.New("ri32 error"), errAfter: llvrwReadInt32}, + llvrwReadInt32, + errors.New("ri32 error"), + }, + { + "ReadInt64 Error", + json.Number(""), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, err: errors.New("ri64 error"), errAfter: llvrwReadInt64}, + llvrwReadInt64, + errors.New("ri64 error"), + }, + { + "success/double", + json.Number("3.14159"), + nil, + &llValueReaderWriter{bsontype: bson.TypeDouble, readval: float64(3.14159)}, + llvrwReadDouble, + nil, + }, + { + "success/int32", + json.Number("12345"), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32, readval: int32(12345)}, + llvrwReadInt32, + nil, + }, + { + "success/int64", + json.Number("1234567890"), + nil, + &llValueReaderWriter{bsontype: bson.TypeInt64, readval: int64(1234567890)}, + llvrwReadInt64, + nil, + }, + }, + }, + { + "URLDecodeValue", + ValueDecoderFunc(dvd.URLDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a *url.URL", bson.TypeInt32), + }, + { + "type not *url.URL", + int64(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("http://example.com")}, + llvrwReadString, + ValueDecoderError{Name: "URLDecodeValue", Types: []interface{}{(*url.URL)(nil), (**url.URL)(nil)}, Received: (*int64)(nil)}, + }, + { + "ReadString error", + url.URL{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, err: errors.New("rs error"), errAfter: llvrwReadString}, + llvrwReadString, + errors.New("rs error"), + }, + { + "url.Parse error", + url.URL{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("not-valid-%%%%://")}, + llvrwReadString, + errors.New("parse not-valid-%%%%://: first path segment in URL cannot contain colon"), + }, + { + "nil *url.URL", + (*url.URL)(nil), + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("http://example.com")}, + llvrwReadString, + ValueDecoderError{Name: "URLDecodeValue", Types: []interface{}{(*url.URL)(nil), (**url.URL)(nil)}, Received: (*url.URL)(nil)}, + }, + { + "nil **url.URL", + (**url.URL)(nil), + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("http://example.com")}, + llvrwReadString, + ValueDecoderError{Name: "URLDecodeValue", Types: []interface{}{(*url.URL)(nil), (**url.URL)(nil)}, Received: (**url.URL)(nil)}, + }, + { + "url.URL", + url.URL{Scheme: "http", Host: "example.com"}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("http://example.com")}, + llvrwReadString, + nil, + }, + { + "*url.URL", + &url.URL{Scheme: "http", Host: "example.com"}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("http://example.com")}, + llvrwReadString, + nil, + }, + }, + }, + { + "ReaderDecodeValue", + ValueDecoderFunc(dvd.ReaderDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{}, + llvrwNothing, + ValueDecoderError{Name: "ReaderDecodeValue", Types: []interface{}{(*bson.Reader)(nil)}, Received: &wrong}, + }, + { + "*Reader is nil", + (*bson.Reader)(nil), + nil, + nil, + llvrwNothing, + errors.New("ReaderDecodeValue can only be used to decode non-nil *Reader"), + }, + { + "Copy error", + bson.Reader{}, + nil, + &llValueReaderWriter{err: errors.New("copy error"), errAfter: llvrwReadDocument}, + llvrwReadDocument, + errors.New("copy error"), + }, + }, + }, + { + "ByteSliceDecodeValue", + ValueDecoderFunc(dvd.ByteSliceDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + &llValueReaderWriter{bsontype: bson.TypeInt32}, + llvrwNothing, + fmt.Errorf("cannot decode %v into a *[]byte", bson.TypeInt32), + }, + { + "type not *[]byte", + int64(0), + nil, + &llValueReaderWriter{bsontype: bson.TypeBinary, readval: bson.Binary{}}, + llvrwNothing, + ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []interface{}{(*[]byte)(nil)}, Received: (*int64)(nil)}, + }, + { + "ReadBinary error", + []byte{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeBinary, err: errors.New("rb error"), errAfter: llvrwReadBinary}, + llvrwReadBinary, + errors.New("rb error"), + }, + { + "incorrect subtype", + []byte{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeBinary, readval: bson.Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}}, + llvrwReadBinary, + fmt.Errorf("ByteSliceDecodeValue can only be used to decode subtype 0x00 for %s, got %v", bson.TypeBinary, byte(0xFF)), + }, + }, + }, + { + "ValueUnmarshalerDecodeValue", + ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + fmt.Errorf("ValueUnmarshalerDecodeValue can only handle types or pointers to types that are a ValueUnmarshaler, got %T", &wrong), + }, + { + "copy error", + testValueUnmarshaler{}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, err: errors.New("copy error"), errAfter: llvrwReadString}, + llvrwReadString, + errors.New("copy error"), + }, + { + "ValueUnmarshaler", + testValueUnmarshaler{t: bson.TypeString, val: llbson.AppendString(nil, "hello, world")}, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("hello, world")}, + llvrwReadString, + nil, + }, + { + "nil pointer to ValueUnmarshaler", + ptrPtrValueUnmarshaler, + nil, + &llValueReaderWriter{bsontype: bson.TypeString, readval: string("hello, world")}, + llvrwNothing, + fmt.Errorf("ValueUnmarshalerDecodeValue can only unmarshal into non-nil ValueUnmarshaler values, got %T", ptrPtrValueUnmarshaler), + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, rc := range tc.subtests { + t.Run(rc.name, func(t *testing.T) { + var dc DecodeContext + if rc.dctx != nil { + dc = *rc.dctx + } + llvrw := new(llValueReaderWriter) + if rc.llvrw != nil { + llvrw = rc.llvrw + } + llvrw.t = t + var got interface{} + if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test + err := tc.vd.DecodeValue(dc, llvrw, nil) + if !compareErrors(err, rc.err) { + t.Errorf("Errors do not match. got %v; want %v", err, rc.err) + } + + val := reflect.New(reflect.TypeOf(rc.val)).Elem().Interface() + err = tc.vd.DecodeValue(dc, llvrw, val) + if !compareErrors(err, rc.err) { + t.Errorf("Errors do not match. got %v; want %v", err, rc.err) + } + return + } + var unwrap bool + rtype := reflect.TypeOf(rc.val) + if rtype.Kind() == reflect.Ptr { + if reflect.ValueOf(rc.val).IsNil() { + got = rc.val + } else { + val := reflect.New(rtype).Elem() + elem := reflect.New(rtype.Elem()) + val.Set(elem) + got = val.Addr().Interface() + unwrap = true + } + } else { + unwrap = true + got = reflect.New(reflect.TypeOf(rc.val)).Interface() + } + want := rc.val + err := tc.vd.DecodeValue(dc, llvrw, got) + if !compareErrors(err, rc.err) { + t.Errorf("Errors do not match. got %v; want %v", err, rc.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, rc.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, rc.invoke) + } + if unwrap { + got = reflect.ValueOf(got).Elem().Interface() + } + if rc.err == nil && !cmp.Equal(got, want, cmp.Comparer(compareDecimal128), cmp.Comparer(compareValues)) { + t.Errorf("Values do not match. got (%T)%v; want (%T)%v", got, got, want, want) + } + }) + } + }) + } + + t.Run("ValueUnmarshalerDecodeValue/UnmarshalBSONValue error", func(t *testing.T) { + var dc DecodeContext + llvrw := &llValueReaderWriter{bsontype: bson.TypeString, readval: string("hello, world!")} + llvrw.t = t + + want := errors.New("ubsonv error") + valUnmarshaler := &testValueUnmarshaler{err: want} + got := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, valUnmarshaler) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("ValueUnmarshalerDecodeValue/Pointer to ValueUnmarshaler", func(t *testing.T) { + var dc DecodeContext + llvrw := &llValueReaderWriter{bsontype: bson.TypeString, readval: string("hello, world!")} + llvrw.t = t + + var want = new(*testValueUnmarshaler) + var got = new(*testValueUnmarshaler) + *want = &testValueUnmarshaler{t: bson.TypeString, val: llbson.AppendString(nil, "hello, world!")} + *got = new(testValueUnmarshaler) + + err := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, got) + noerr(t, err) + if !cmp.Equal(*got, *want) { + t.Errorf("Unmarshaled values do not match. got %v; want %v", *got, *want) + } + }) + t.Run("ValueUnmarshalerDecodeValue/nil pointer inside non-nil pointer", func(t *testing.T) { + var dc DecodeContext + llvrw := &llValueReaderWriter{bsontype: bson.TypeString, readval: string("hello, world!")} + llvrw.t = t + + var got = new(*testValueUnmarshaler) + var want = new(*testValueUnmarshaler) + *want = &testValueUnmarshaler{t: bson.TypeString, val: llbson.AppendString(nil, "hello, world!")} + + err := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, got) + noerr(t, err) + if !cmp.Equal(got, want) { + t.Errorf("Results do not match. got %v; want %v", got, want) + } + }) + t.Run("MapCodec/DecodeValue/non-settable", func(t *testing.T) { + var dc DecodeContext + llvrw := new(llValueReaderWriter) + llvrw.t = t + + want := fmt.Errorf("MapDecodeValue can only be used to decode non-nil pointers to map values, got %T", nil) + got := dvd.MapDecodeValue(dc, llvrw, nil) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + + want = fmt.Errorf("MapDecodeValue can only be used to decode non-nil pointers to map values, got %T", string("")) + + val := reflect.New(reflect.TypeOf(string(""))).Elem().Interface() + got = dvd.MapDecodeValue(dc, llvrw, val) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + return + }) + + t.Run("CodeWithScopeCodec/DecodeValue/success", func(t *testing.T) { + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + dvr := newDocumentValueReader(bson.NewDocument(bson.EC.CodeWithScope("foo", "var hello = 'world';", bson.NewDocument(bson.EC.Null("bar"))))) + dr, err := dvr.ReadDocument() + noerr(t, err) + _, vr, err := dr.ReadElement() + noerr(t, err) + + want := bson.CodeWithScope{ + Code: "var hello = 'world';", + Scope: bson.NewDocument(bson.EC.Null("bar")), + } + var got bson.CodeWithScope + err = dvd.CodeWithScopeDecodeValue(dc, vr, &got) + noerr(t, err) + + if !cmp.Equal(got, want) { + t.Errorf("CodeWithScopes do not match. got %v; want %v", got, want) + } + }) + t.Run("DocumentDecodeValue", func(t *testing.T) { + t.Run("CodecDecodeError", func(t *testing.T) { + val := bool(true) + want := ValueDecoderError{Name: "DocumentDecodeValue", Types: []interface{}{(**bson.Document)(nil)}, Received: val} + got := dvd.DocumentDecodeValue(DecodeContext{}, &llValueReaderWriter{bsontype: bson.TypeEmbeddedDocument}, val) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("ReadDocument Error", func(t *testing.T) { + want := errors.New("ReadDocument Error") + llvrw := &llValueReaderWriter{ + t: t, + err: want, + errAfter: llvrwReadDocument, + bsontype: bson.TypeEmbeddedDocument, + } + got := dvd.DocumentDecodeValue(DecodeContext{}, llvrw, new(*bson.Document)) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("decodeDocument errors", func(t *testing.T) { + dc := DecodeContext{} + err := errors.New("decodeDocument error") + testCases := []struct { + name string + dc DecodeContext + llvrw *llValueReaderWriter + err error + }{ + { + "ReadElement", + dc, + &llValueReaderWriter{t: t, err: errors.New("re error"), errAfter: llvrwReadElement}, + errors.New("re error"), + }, + {"ReadDouble", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDouble, bsontype: bson.TypeDouble}, err}, + {"ReadString", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadString, bsontype: bson.TypeString}, err}, + { + "ReadDocument (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, bsontype: bson.TypeEmbeddedDocument}, + ErrNoDecoder{Type: tDocument}, + }, + { + "ReadArray (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, bsontype: bson.TypeArray}, + ErrNoDecoder{Type: tArray}, + }, + {"ReadBinary", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBinary, bsontype: bson.TypeBinary}, err}, + {"ReadUndefined", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadUndefined, bsontype: bson.TypeUndefined}, err}, + {"ReadObjectID", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadObjectID, bsontype: bson.TypeObjectID}, err}, + {"ReadBoolean", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBoolean, bsontype: bson.TypeBoolean}, err}, + {"ReadDateTime", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDateTime, bsontype: bson.TypeDateTime}, err}, + {"ReadNull", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadNull, bsontype: bson.TypeNull}, err}, + {"ReadRegex", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadRegex, bsontype: bson.TypeRegex}, err}, + {"ReadDBPointer", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDBPointer, bsontype: bson.TypeDBPointer}, err}, + {"ReadJavascript", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadJavascript, bsontype: bson.TypeJavaScript}, err}, + {"ReadSymbol", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadSymbol, bsontype: bson.TypeSymbol}, err}, + { + "ReadCodeWithScope (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadCodeWithScope, bsontype: bson.TypeCodeWithScope}, + err, + }, + {"ReadInt32", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt32, bsontype: bson.TypeInt32}, err}, + {"ReadInt64", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt64, bsontype: bson.TypeInt64}, err}, + {"ReadTimestamp", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadTimestamp, bsontype: bson.TypeTimestamp}, err}, + {"ReadDecimal128", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDecimal128, bsontype: bson.TypeDecimal128}, err}, + {"ReadMinKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMinKey, bsontype: bson.TypeMinKey}, err}, + {"ReadMaxKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMaxKey, bsontype: bson.TypeMaxKey}, err}, + {"Invalid Type", dc, &llValueReaderWriter{t: t, bsontype: bson.Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", bson.Type(0))}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := dvd.decodeDocument(tc.dc, tc.llvrw, new(*bson.Document)) + if !compareErrors(err, tc.err) { + t.Errorf("Errors do not match. got %v; want %v", err, tc.err) + } + }) + } + }) + + t.Run("success", func(t *testing.T) { + oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} + d128 := decimal.NewDecimal128(10, 20) + want := bson.NewDocument( + bson.EC.Double("a", 3.14159), bson.EC.String("b", "foo"), bson.EC.SubDocumentFromElements("c", bson.EC.Null("aa")), + bson.EC.ArrayFromElements("d", bson.VC.Null()), + bson.EC.BinaryWithSubtype("e", []byte{0x01, 0x02, 0x03}, 0xFF), bson.EC.Undefined("f"), + bson.EC.ObjectID("g", oid), bson.EC.Boolean("h", true), bson.EC.DateTime("i", 1234567890), bson.EC.Null("j"), bson.EC.Regex("k", "foo", "bar"), + bson.EC.DBPointer("l", "foobar", oid), bson.EC.JavaScript("m", "var hello = 'world';"), bson.EC.Symbol("n", "bazqux"), + bson.EC.CodeWithScope("o", "var hello = 'world';", bson.NewDocument(bson.EC.Null("ab"))), bson.EC.Int32("p", 12345), + bson.EC.Timestamp("q", 10, 20), bson.EC.Int64("r", 1234567890), bson.EC.Decimal128("s", d128), bson.EC.MinKey("t"), bson.EC.MaxKey("u"), + ) + var got *bson.Document + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + err := dvd.DocumentDecodeValue(dc, newDocumentValueReader(want), &got) + noerr(t, err) + if !got.Equal(want) { + t.Error("Documents do not match") + t.Errorf("\ngot :%v\nwant:%v", got, want) + } + }) + }) + t.Run("ArrayDecodeValue", func(t *testing.T) { + t.Run("CodecDecodeError", func(t *testing.T) { + val := bool(true) + want := ValueDecoderError{Name: "ArrayDecodeValue", Types: []interface{}{(**bson.Array)(nil)}, Received: val} + got := dvd.ArrayDecodeValue(DecodeContext{}, &llValueReaderWriter{bsontype: bson.TypeArray}, val) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("ReadArray Error", func(t *testing.T) { + want := errors.New("ReadArray Error") + llvrw := &llValueReaderWriter{ + t: t, + err: want, + errAfter: llvrwReadArray, + bsontype: bson.TypeArray, + } + got := dvd.ArrayDecodeValue(DecodeContext{}, llvrw, new(*bson.Array)) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("decode array errors", func(t *testing.T) { + dc := DecodeContext{} + err := errors.New("decode array error") + testCases := []struct { + name string + dc DecodeContext + llvrw *llValueReaderWriter + err error + }{ + { + "ReadValue", + dc, + &llValueReaderWriter{t: t, err: errors.New("re error"), errAfter: llvrwReadValue}, + errors.New("re error"), + }, + {"ReadDouble", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDouble, bsontype: bson.TypeDouble}, err}, + {"ReadString", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadString, bsontype: bson.TypeString}, err}, + { + "ReadDocument (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, bsontype: bson.TypeEmbeddedDocument}, + ErrNoDecoder{Type: tDocument}, + }, + { + "ReadArray (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, bsontype: bson.TypeArray}, + ErrNoDecoder{Type: tArray}, + }, + {"ReadBinary", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBinary, bsontype: bson.TypeBinary}, err}, + {"ReadUndefined", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadUndefined, bsontype: bson.TypeUndefined}, err}, + {"ReadObjectID", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadObjectID, bsontype: bson.TypeObjectID}, err}, + {"ReadBoolean", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBoolean, bsontype: bson.TypeBoolean}, err}, + {"ReadDateTime", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDateTime, bsontype: bson.TypeDateTime}, err}, + {"ReadNull", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadNull, bsontype: bson.TypeNull}, err}, + {"ReadRegex", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadRegex, bsontype: bson.TypeRegex}, err}, + {"ReadDBPointer", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDBPointer, bsontype: bson.TypeDBPointer}, err}, + {"ReadJavascript", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadJavascript, bsontype: bson.TypeJavaScript}, err}, + {"ReadSymbol", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadSymbol, bsontype: bson.TypeSymbol}, err}, + { + "ReadCodeWithScope (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadCodeWithScope, bsontype: bson.TypeCodeWithScope}, + err, + }, + {"ReadInt32", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt32, bsontype: bson.TypeInt32}, err}, + {"ReadInt64", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt64, bsontype: bson.TypeInt64}, err}, + {"ReadTimestamp", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadTimestamp, bsontype: bson.TypeTimestamp}, err}, + {"ReadDecimal128", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDecimal128, bsontype: bson.TypeDecimal128}, err}, + {"ReadMinKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMinKey, bsontype: bson.TypeMinKey}, err}, + {"ReadMaxKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMaxKey, bsontype: bson.TypeMaxKey}, err}, + {"Invalid Type", dc, &llValueReaderWriter{t: t, bsontype: bson.Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", bson.Type(0))}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := dvd.ArrayDecodeValue(tc.dc, tc.llvrw, new(*bson.Array)) + if !compareErrors(err, tc.err) { + t.Errorf("Errors do not match. got %v; want %v", err, tc.err) + } + }) + } + }) + + t.Run("success", func(t *testing.T) { + oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} + d128 := decimal.NewDecimal128(10, 20) + want := bson.NewArray( + bson.VC.Double(3.14159), bson.VC.String("foo"), bson.VC.DocumentFromElements(bson.EC.Null("aa")), + bson.VC.ArrayFromValues(bson.VC.Null()), + bson.VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xFF), bson.VC.Undefined(), + bson.VC.ObjectID(oid), bson.VC.Boolean(true), bson.VC.DateTime(1234567890), bson.VC.Null(), bson.VC.Regex("foo", "bar"), + bson.VC.DBPointer("foobar", oid), bson.VC.JavaScript("var hello = 'world';"), bson.VC.Symbol("bazqux"), + bson.VC.CodeWithScope("var hello = 'world';", bson.NewDocument(bson.EC.Null("ab"))), bson.VC.Int32(12345), + bson.VC.Timestamp(10, 20), bson.VC.Int64(1234567890), bson.VC.Decimal128(d128), bson.VC.MinKey(), bson.VC.MaxKey(), + ) + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + + dvr := newDocumentValueReader(bson.NewDocument(bson.EC.Array("", want))) + dr, err := dvr.ReadDocument() + noerr(t, err) + _, vr, err := dr.ReadElement() + noerr(t, err) + + var got *bson.Array + err = dvd.ArrayDecodeValue(dc, vr, &got) + noerr(t, err) + if !got.Equal(want) { + t.Error("Documents do not match") + t.Errorf("\ngot :%v\nwant:%v", got, want) + } + }) + }) + t.Run("SliceCodec/DecodeValue/can't set slice", func(t *testing.T) { + var val []string + want := fmt.Errorf("SliceDecodeValue can only be used to decode non-nil pointers to slice or array values, got %T", val) + got := dvd.SliceDecodeValue(DecodeContext{}, nil, val) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("SliceCodec/DecodeValue/too many elements", func(t *testing.T) { + dvr := newDocumentValueReader(bson.NewDocument(bson.EC.ArrayFromElements("foo", bson.VC.String("foo"), bson.VC.String("bar")))) + dr, err := dvr.ReadDocument() + noerr(t, err) + _, vr, err := dr.ReadElement() + noerr(t, err) + var val [1]string + want := fmt.Errorf("more elements returned in array than can fit inside %T", val) + + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + got := dvd.SliceDecodeValue(dc, vr, &val) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + + t.Run("success path", func(t *testing.T) { + oid := objectid.New() + oids := []objectid.ObjectID{objectid.New(), objectid.New(), objectid.New()} + var str = new(string) + *str = "bar" + now := time.Now().Truncate(time.Millisecond) + murl, err := url.Parse("https://mongodb.com/random-url?hello=world") + if err != nil { + t.Errorf("Error parsing URL: %v", err) + t.FailNow() + } + decimal128, err := decimal.ParseDecimal128("1.5e10") + if err != nil { + t.Errorf("Error parsing decimal128: %v", err) + t.FailNow() + } + + testCases := []struct { + name string + value interface{} + b []byte + err error + }{ + { + "map[string]int", + map[string]int32{"foo": 1}, + []byte{ + 0x0E, 0x00, 0x00, 0x00, + 0x10, 'f', 'o', 'o', 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + }, + nil, + }, + { + "map[string]objectid.ObjectID", + map[string]objectid.ObjectID{"foo": oid}, + docToBytes(bson.NewDocument(bson.EC.ObjectID("foo", oid))), + nil, + }, + { + "map[string][]*Element", + map[string][]*bson.Element{"Z": {bson.EC.Int32("A", 1), bson.EC.Int32("B", 2), bson.EC.Int32("EC", 3)}}, + docToBytes(bson.NewDocument( + bson.EC.SubDocumentFromElements("Z", bson.EC.Int32("A", 1), bson.EC.Int32("B", 2), bson.EC.Int32("EC", 3)), + )), + nil, + }, + { + "map[string][]*Value", + map[string][]*bson.Value{"Z": {bson.VC.Int32(1), bson.VC.Int32(2), bson.VC.Int32(3)}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int32(1), bson.VC.Int32(2), bson.VC.Int32(3)), + )), + nil, + }, + { + "map[string]*Element", + map[string]*bson.Element{"Z": bson.EC.Int32("Z", 12345)}, + docToBytes(bson.NewDocument( + bson.EC.Int32("Z", 12345), + )), + nil, + }, + { + "map[string]*Document", + map[string]*bson.Document{"Z": bson.NewDocument(bson.EC.Null("foo"))}, + docToBytes(bson.NewDocument( + bson.EC.SubDocumentFromElements("Z", bson.EC.Null("foo")), + )), + nil, + }, + { + "map[string]Reader", + map[string]bson.Reader{"Z": {0x05, 0x00, 0x00, 0x00, 0x00}}, + docToBytes(bson.NewDocument( + bson.EC.SubDocumentFromReader("Z", bson.Reader{0x05, 0x00, 0x00, 0x00, 0x00}), + )), + nil, + }, + { + "map[string][]int32", + map[string][]int32{"Z": {1, 2, 3}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int32(1), bson.VC.Int32(2), bson.VC.Int32(3)), + )), + nil, + }, + { + "map[string][]objectid.ObjectID", + map[string][]objectid.ObjectID{"Z": oids}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.ObjectID(oids[0]), bson.VC.ObjectID(oids[1]), bson.VC.ObjectID(oids[2])), + )), + nil, + }, + { + "map[string][]json.Number(int64)", + map[string][]json.Number{"Z": {json.Number("5"), json.Number("10")}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int64(5), bson.VC.Int64(10)), + )), + nil, + }, + { + "map[string][]json.Number(float64)", + map[string][]json.Number{"Z": {json.Number("5"), json.Number("10.1")}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int64(5), bson.VC.Double(10.1)), + )), + nil, + }, + { + "map[string][]*url.URL", + map[string][]*url.URL{"Z": {murl}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.String(murl.String())), + )), + nil, + }, + { + "map[string][]decimal.Decimal128", + map[string][]decimal.Decimal128{"Z": {decimal128}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Decimal128(decimal128)), + )), + nil, + }, + { + "-", + struct { + A string `bson:"-"` + }{ + A: "", + }, + docToBytes(bson.NewDocument()), + nil, + }, + { + "omitempty", + struct { + A string `bson:",omitempty"` + }{ + A: "", + }, + docToBytes(bson.NewDocument()), + nil, + }, + { + "omitempty, empty time", + struct { + A time.Time `bson:",omitempty"` + }{ + A: time.Time{}, + }, + docToBytes(bson.NewDocument()), + nil, + }, + { + "no private fields", + noPrivateFields{a: "should be empty"}, + docToBytes(bson.NewDocument()), + nil, + }, + { + "minsize", + struct { + A int64 `bson:",minsize"` + }{ + A: 12345, + }, + docToBytes(bson.NewDocument(bson.EC.Int32("a", 12345))), + nil, + }, + { + "inline", + struct { + Foo struct { + A int64 `bson:",minsize"` + } `bson:",inline"` + }{ + Foo: struct { + A int64 `bson:",minsize"` + }{ + A: 12345, + }, + }, + docToBytes(bson.NewDocument(bson.EC.Int32("a", 12345))), + nil, + }, + { + "inline map", + struct { + Foo map[string]string `bson:",inline"` + }{ + Foo: map[string]string{"foo": "bar"}, + }, + docToBytes(bson.NewDocument(bson.EC.String("foo", "bar"))), + nil, + }, + { + "alternate name bson:name", + struct { + A string `bson:"foo"` + }{ + A: "bar", + }, + docToBytes(bson.NewDocument(bson.EC.String("foo", "bar"))), + nil, + }, + { + "alternate name", + struct { + A string `bson:"foo"` + }{ + A: "bar", + }, + docToBytes(bson.NewDocument(bson.EC.String("foo", "bar"))), + nil, + }, + { + "inline, omitempty", + struct { + A string + Foo zeroTest `bson:"omitempty,inline"` + }{ + A: "bar", + Foo: zeroTest{true}, + }, + docToBytes(bson.NewDocument(bson.EC.String("a", "bar"))), + nil, + }, + { + "struct{}", + struct { + A bool + B int32 + C int64 + D uint16 + E uint64 + F float64 + G string + H map[string]string + I []byte + K [2]string + L struct { + M string + } + N *bson.Element + O *bson.Document + P bson.Reader + Q objectid.ObjectID + T []struct{} + Y json.Number + Z time.Time + AA json.Number + AB *url.URL + AC decimal.Decimal128 + AD *time.Time + AE *testValueUnmarshaler + }{ + A: true, + B: 123, + C: 456, + D: 789, + E: 101112, + F: 3.14159, + G: "Hello, world", + H: map[string]string{"foo": "bar"}, + I: []byte{0x01, 0x02, 0x03}, + K: [2]string{"baz", "qux"}, + L: struct { + M string + }{ + M: "foobar", + }, + N: bson.EC.Null("n"), + O: bson.NewDocument(bson.EC.Int64("countdown", 9876543210)), + P: bson.Reader{0x05, 0x00, 0x00, 0x00, 0x00}, + Q: oid, + T: nil, + Y: json.Number("5"), + Z: now, + AA: json.Number("10.1"), + AB: murl, + AC: decimal128, + AD: &now, + AE: &testValueUnmarshaler{t: bson.TypeString, val: llbson.AppendString(nil, "hello, world!")}, + }, + docToBytes(bson.NewDocument( + bson.EC.Boolean("a", true), + bson.EC.Int32("b", 123), + bson.EC.Int64("c", 456), + bson.EC.Int32("d", 789), + bson.EC.Int64("e", 101112), + bson.EC.Double("f", 3.14159), + bson.EC.String("g", "Hello, world"), + bson.EC.SubDocumentFromElements("h", bson.EC.String("foo", "bar")), + bson.EC.Binary("i", []byte{0x01, 0x02, 0x03}), + bson.EC.ArrayFromElements("k", bson.VC.String("baz"), bson.VC.String("qux")), + bson.EC.SubDocumentFromElements("l", bson.EC.String("m", "foobar")), + bson.EC.Null("n"), + bson.EC.SubDocumentFromElements("o", bson.EC.Int64("countdown", 9876543210)), + bson.EC.SubDocumentFromElements("p"), + bson.EC.ObjectID("q", oid), + bson.EC.Null("t"), + bson.EC.Int64("y", 5), + bson.EC.DateTime("z", now.UnixNano()/int64(time.Millisecond)), + bson.EC.Double("aa", 10.1), + bson.EC.String("ab", murl.String()), + bson.EC.Decimal128("ac", decimal128), + bson.EC.DateTime("ad", now.UnixNano()/int64(time.Millisecond)), + bson.EC.String("ae", "hello, world!"), + )), + nil, + }, + { + "struct{[]interface{}}", + struct { + A []bool + B []int32 + C []int64 + D []uint16 + E []uint64 + F []float64 + G []string + H []map[string]string + I [][]byte + K [1][2]string + L []struct { + M string + } + N [][]string + O []*bson.Element + P []*bson.Document + Q []bson.Reader + R []objectid.ObjectID + T []struct{} + W []map[string]struct{} + X []map[string]struct{} + Y []map[string]struct{} + Z []time.Time + AA []json.Number + AB []*url.URL + AC []decimal.Decimal128 + AD []*time.Time + AE []*testValueUnmarshaler + }{ + A: []bool{true}, + B: []int32{123}, + C: []int64{456}, + D: []uint16{789}, + E: []uint64{101112}, + F: []float64{3.14159}, + G: []string{"Hello, world"}, + H: []map[string]string{{"foo": "bar"}}, + I: [][]byte{{0x01, 0x02, 0x03}}, + K: [1][2]string{{"baz", "qux"}}, + L: []struct { + M string + }{ + { + M: "foobar", + }, + }, + N: [][]string{{"foo", "bar"}}, + O: []*bson.Element{bson.EC.Null("N")}, + P: []*bson.Document{bson.NewDocument(bson.EC.Int64("countdown", 9876543210))}, + Q: []bson.Reader{{0x05, 0x00, 0x00, 0x00, 0x00}}, + R: oids, + T: nil, + W: nil, + X: []map[string]struct{}{}, // Should be empty BSON Array + Y: []map[string]struct{}{{}}, // Should be BSON array with one element, an empty BSON SubDocument + Z: []time.Time{now, now}, + AA: []json.Number{json.Number("5"), json.Number("10.1")}, + AB: []*url.URL{murl}, + AC: []decimal.Decimal128{decimal128}, + AD: []*time.Time{&now, &now}, + AE: []*testValueUnmarshaler{ + {t: bson.TypeString, val: llbson.AppendString(nil, "hello")}, + {t: bson.TypeString, val: llbson.AppendString(nil, "world")}, + }, + }, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("a", bson.VC.Boolean(true)), + bson.EC.ArrayFromElements("b", bson.VC.Int32(123)), + bson.EC.ArrayFromElements("c", bson.VC.Int64(456)), + bson.EC.ArrayFromElements("d", bson.VC.Int32(789)), + bson.EC.ArrayFromElements("e", bson.VC.Int64(101112)), + bson.EC.ArrayFromElements("f", bson.VC.Double(3.14159)), + bson.EC.ArrayFromElements("g", bson.VC.String("Hello, world")), + bson.EC.ArrayFromElements("h", bson.VC.DocumentFromElements(bson.EC.String("foo", "bar"))), + bson.EC.ArrayFromElements("i", bson.VC.Binary([]byte{0x01, 0x02, 0x03})), + bson.EC.ArrayFromElements("k", bson.VC.ArrayFromValues(bson.VC.String("baz"), bson.VC.String("qux"))), + bson.EC.ArrayFromElements("l", bson.VC.DocumentFromElements(bson.EC.String("m", "foobar"))), + bson.EC.ArrayFromElements("n", bson.VC.ArrayFromValues(bson.VC.String("foo"), bson.VC.String("bar"))), + bson.EC.SubDocumentFromElements("o", bson.EC.Null("N")), + bson.EC.ArrayFromElements("p", bson.VC.DocumentFromElements(bson.EC.Int64("countdown", 9876543210))), + bson.EC.ArrayFromElements("q", bson.VC.DocumentFromElements()), + bson.EC.ArrayFromElements("r", bson.VC.ObjectID(oids[0]), bson.VC.ObjectID(oids[1]), bson.VC.ObjectID(oids[2])), + bson.EC.Null("t"), + bson.EC.Null("w"), + bson.EC.Array("x", bson.NewArray()), + bson.EC.ArrayFromElements("y", bson.VC.Document(bson.NewDocument())), + bson.EC.ArrayFromElements("z", bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond)), bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond))), + bson.EC.ArrayFromElements("aa", bson.VC.Int64(5), bson.VC.Double(10.10)), + bson.EC.ArrayFromElements("ab", bson.VC.String(murl.String())), + bson.EC.ArrayFromElements("ac", bson.VC.Decimal128(decimal128)), + bson.EC.ArrayFromElements("ad", bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond)), bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond))), + bson.EC.ArrayFromElements("ae", bson.VC.String("hello"), bson.VC.String("world")), + )), + nil, + }, + } + + t.Run("Decode", func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + vr := newValueReader(tc.b) + dec, err := NewDecoder(NewRegistryBuilder().Build(), vr) + noerr(t, err) + gotVal := reflect.New(reflect.TypeOf(tc.value)) + err = dec.Decode(gotVal.Interface()) + noerr(t, err) + got := gotVal.Elem().Interface() + want := tc.value + if diff := cmp.Diff( + got, want, + cmp.Comparer(compareElements), + cmp.Comparer(compareValues), + cmp.Comparer(compareDecimal128), + cmp.Comparer(compareNoPrivateFields), + cmp.Comparer(compareZeroTest), + ); diff != "" { + t.Errorf("difference:\n%s", diff) + t.Errorf("Values are not equal.\ngot: %#v\nwant:%#v", got, want) + } + }) + } + }) + }) + + t.Run("EmptyInterfaceDecodeValue", func(t *testing.T) { + t.Run("DecodeValue", func(t *testing.T) { + testCases := []struct { + name string + val interface{} + bsontype bson.Type + }{ + { + "Double - float64", + float64(3.14159), + bson.TypeDouble, + }, + { + "String - string", + string("foo bar baz"), + bson.TypeString, + }, + { + "Embedded Document - *Document", + bson.NewDocument(bson.EC.Null("foo")), + bson.TypeEmbeddedDocument, + }, + { + "Array - *Array", + bson.NewArray(bson.VC.Double(3.14159)), + bson.TypeArray, + }, + { + "Binary - Binary", + bson.Binary{Subtype: 0xFF, Data: []byte{0x01, 0x02, 0x03}}, + bson.TypeBinary, + }, + { + "Undefined - Undefined", + bson.Undefinedv2{}, + bson.TypeUndefined, + }, + { + "ObjectID - objectid.ObjectID", + objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + bson.TypeObjectID, + }, + { + "Boolean - bool", + bool(true), + bson.TypeBoolean, + }, + { + "DateTime - DateTime", + bson.DateTime(1234567890), + bson.TypeDateTime, + }, + { + "Null - Null", + bson.Nullv2{}, + bson.TypeNull, + }, + { + "Regex - Regex", + bson.Regex{Pattern: "foo", Options: "bar"}, + bson.TypeRegex, + }, + { + "DBPointer - DBPointer", + bson.DBPointer{ + DB: "foobar", + Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + }, + bson.TypeDBPointer, + }, + { + "JavaScript - JavaScriptCode", + bson.JavaScriptCode("var foo = 'bar';"), + bson.TypeJavaScript, + }, + { + "Symbol - Symbol", + bson.Symbol("foobarbazlolz"), + bson.TypeSymbol, + }, + { + "CodeWithScope - CodeWithScope", + bson.CodeWithScope{ + Code: "var foo = 'bar';", + Scope: bson.NewDocument(bson.EC.Double("foo", 3.14159)), + }, + bson.TypeCodeWithScope, + }, + { + "Int32 - int32", + int32(123456), + bson.TypeInt32, + }, + { + "Int64 - int64", + int64(1234567890), + bson.TypeInt64, + }, + { + "Timestamp - Timestamp", + bson.Timestamp{T: 12345, I: 67890}, + bson.TypeTimestamp, + }, + { + "Decimal128 - decimal.Decimal128", + decimal.NewDecimal128(12345, 67890), + bson.TypeDecimal128, + }, + { + "MinKey - MinKey", + bson.MinKeyv2{}, + bson.TypeMinKey, + }, + { + "MaxKey - MaxKey", + bson.MaxKeyv2{}, + bson.TypeMaxKey, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + llvr := &llValueReaderWriter{bsontype: tc.bsontype} + + t.Run("Lookup failure", func(t *testing.T) { + val := new(interface{}) + dc := DecodeContext{Registry: NewEmptyRegistryBuilder().Build()} + want := ErrNoDecoder{Type: reflect.TypeOf(tc.val)} + got := dvd.EmptyInterfaceDecodeValue(dc, llvr, val) + if !compareErrors(got, want) { + t.Errorf("Errors are not equal. got %v; want %v", got, want) + } + }) + + t.Run("DecodeValue failure", func(t *testing.T) { + want := errors.New("DecodeValue failure error") + llc := &llCodec{t: t, err: want} + dc := DecodeContext{ + Registry: NewEmptyRegistryBuilder().RegisterDecoder(reflect.TypeOf(tc.val), llc).Build(), + } + got := dvd.EmptyInterfaceDecodeValue(dc, llvr, new(interface{})) + if !compareErrors(got, want) { + t.Errorf("Errors are not equal. got %v; want %v", got, want) + } + }) + + t.Run("Success", func(t *testing.T) { + want := tc.val + llc := &llCodec{t: t, decodeval: tc.val} + dc := DecodeContext{ + Registry: NewEmptyRegistryBuilder().RegisterDecoder(reflect.TypeOf(tc.val), llc).Build(), + } + got := new(interface{}) + err := dvd.EmptyInterfaceDecodeValue(dc, llvr, got) + noerr(t, err) + if !cmp.Equal(*got, want, cmp.Comparer(compareDecimal128)) { + t.Errorf("Did not receive expected value. got %v; want %v", *got, want) + } + }) + }) + } + }) + + t.Run("non-*interface{}", func(t *testing.T) { + val := uint64(1234567890) + want := fmt.Errorf("EmptyInterfaceDecodeValue can only be used to decode non-nil *interface{} values, provided type if %T", &val) + got := dvd.EmptyInterfaceDecodeValue(DecodeContext{}, nil, &val) + if !compareErrors(got, want) { + t.Errorf("Errors are not equal. got %v; want %v", got, want) + } + }) + + t.Run("nil *interface{}", func(t *testing.T) { + var val *interface{} + want := fmt.Errorf("EmptyInterfaceDecodeValue can only be used to decode non-nil *interface{} values, provided type if %T", val) + got := dvd.EmptyInterfaceDecodeValue(DecodeContext{}, nil, val) + if !compareErrors(got, want) { + t.Errorf("Errors are not equal. got %v; want %v", got, want) + } + }) + + t.Run("unknown BSON type", func(t *testing.T) { + llvr := &llValueReaderWriter{bsontype: bson.Type(0)} + want := fmt.Errorf("Type %s is not a valid BSON type and has no default Go type to decode into", bson.Type(0)) + got := dvd.EmptyInterfaceDecodeValue(DecodeContext{}, llvr, new(interface{})) + if !compareErrors(got, want) { + t.Errorf("Errors are not equal. got %v; want %v", got, want) + } + }) + }) + +} + +type testValueUnmarshaler struct { + t bson.Type + val []byte + err error +} + +func (tvu *testValueUnmarshaler) UnmarshalBSONValue(t bson.Type, val []byte) error { + tvu.t, tvu.val = t, val + return tvu.err +} +func (tvu testValueUnmarshaler) Equal(tvu2 testValueUnmarshaler) bool { + return tvu.t == tvu2.t && bytes.Equal(tvu.val, tvu2.val) +} diff --git a/bson/bsoncodec/default_value_encoders.go b/bson/bsoncodec/default_value_encoders.go new file mode 100644 index 0000000000..981c8249bc --- /dev/null +++ b/bson/bsoncodec/default_value_encoders.go @@ -0,0 +1,822 @@ +package bsoncodec + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "net/url" + "reflect" + "time" + + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/objectid" +) + +var defaultValueEncoders DefaultValueEncoders + +// DefaultValueEncoders is a namespace type for the default ValueEncoders used +// when creating a registry. +type DefaultValueEncoders struct{} + +// BooleanEncodeValue is the ValueEncoderFunc for bool types. +func (dve DefaultValueEncoders) BooleanEncodeValue(ectx EncodeContext, vw ValueWriter, i interface{}) error { + b, ok := i.(bool) + if !ok { + if reflect.TypeOf(i).Kind() != reflect.Bool { + return ValueEncoderError{Name: "BooleanEncodeValue", Types: []interface{}{bool(true)}, Received: i} + } + + b = reflect.ValueOf(i).Bool() + } + + return vw.WriteBoolean(b) +} + +// IntEncodeValue is the ValueEncoderFunc for int types. +func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + switch t := i.(type) { + case int8: + return vw.WriteInt32(int32(t)) + case int16: + return vw.WriteInt32(int32(t)) + case int32: + return vw.WriteInt32(t) + case int64: + if ec.MinSize && t <= math.MaxInt32 { + return vw.WriteInt32(int32(t)) + } + return vw.WriteInt64(t) + case int: + if ec.MinSize && t <= math.MaxInt32 { + return vw.WriteInt32(int32(t)) + } + return vw.WriteInt64(int64(t)) + } + + val := reflect.ValueOf(i) + switch val.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32: + return vw.WriteInt32(int32(val.Int())) + case reflect.Int, reflect.Int64: + i64 := val.Int() + if ec.MinSize && i64 <= math.MaxInt32 { + return vw.WriteInt32(int32(i64)) + } + return vw.WriteInt64(i64) + } + + return ValueEncoderError{ + Name: "IntEncodeValue", + Types: []interface{}{int8(0), int16(0), int32(0), int64(0), int(0)}, + Received: i, + } +} + +// UintEncodeValue is the ValueEncoderFunc for uint types. +func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + switch t := i.(type) { + case uint8: + return vw.WriteInt32(int32(t)) + case uint16: + return vw.WriteInt32(int32(t)) + case uint: + if ec.MinSize && t <= math.MaxInt32 { + return vw.WriteInt32(int32(t)) + } + if uint64(t) > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", t) + } + return vw.WriteInt64(int64(t)) + case uint32: + if ec.MinSize && t <= math.MaxInt32 { + return vw.WriteInt32(int32(t)) + } + return vw.WriteInt64(int64(t)) + case uint64: + if ec.MinSize && t <= math.MaxInt32 { + return vw.WriteInt32(int32(t)) + } + if t > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", t) + } + return vw.WriteInt64(int64(t)) + } + + val := reflect.ValueOf(i) + switch val.Type().Kind() { + case reflect.Uint8, reflect.Uint16: + return vw.WriteInt32(int32(val.Uint())) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + u64 := val.Uint() + if ec.MinSize && u64 <= math.MaxInt32 { + return vw.WriteInt32(int32(u64)) + } + if u64 > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", u64) + } + return vw.WriteInt64(int64(u64)) + } + + return ValueEncoderError{ + Name: "UintEncodeValue", + Types: []interface{}{uint8(0), uint16(0), uint32(0), uint64(0), uint(0)}, + Received: i, + } +} + +// FloatEncodeValue is the ValueEncoderFunc for float types. +func (dve DefaultValueEncoders) FloatEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + switch t := i.(type) { + case float32: + return vw.WriteDouble(float64(t)) + case float64: + return vw.WriteDouble(t) + } + + val := reflect.ValueOf(i) + switch val.Type().Kind() { + case reflect.Float32, reflect.Float64: + return vw.WriteDouble(val.Float()) + } + + return ValueEncoderError{Name: "FloatEncodeValue", Types: []interface{}{float32(0), float64(0)}, Received: i} +} + +// StringEncodeValue is the ValueEncoderFunc for string types. +func (dve DefaultValueEncoders) StringEncodeValue(ectx EncodeContext, vw ValueWriter, i interface{}) error { + switch t := i.(type) { + case string: + return vw.WriteString(t) + case bson.JavaScriptCode: + return vw.WriteJavascript(string(t)) + case bson.Symbol: + return vw.WriteSymbol(string(t)) + } + + val := reflect.ValueOf(i) + if val.Type().Kind() != reflect.String { + return ValueEncoderError{ + Name: "StringEncodeValue", + Types: []interface{}{string(""), bson.JavaScriptCode(""), bson.Symbol("")}, + Received: i, + } + } + + return vw.WriteString(val.String()) +} + +// DocumentEncodeValue is the ValueEncoderFunc for *bson.Document. +func (dve DefaultValueEncoders) DocumentEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + doc, ok := i.(*bson.Document) + if !ok { + return ValueEncoderError{Name: "DocumentEncodeValue", Types: []interface{}{(*bson.Document)(nil)}, Received: i} + } + + dw, err := vw.WriteDocument() + if err != nil { + return err + } + + return dve.encodeDocument(ec, dw, doc) +} + +// ArrayEncodeValue is the ValueEncoderFunc for *bson.Array. +func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + arr, ok := i.(*bson.Array) + if !ok { + return ValueEncoderError{Name: "ArrayEncodeValue", Types: []interface{}{(*bson.Array)(nil)}, Received: i} + } + + aw, err := vw.WriteArray() + if err != nil { + return err + } + + itr, err := arr.Iterator() + if err != nil { + return err + } + + for itr.Next() { + val := itr.Value() + dvw, err := aw.WriteArrayElement() + if err != nil { + return err + } + + err = dve.encodeValue(ec, dvw, val) + + if err != nil { + return err + } + } + + if err := itr.Err(); err != nil { + return err + } + + return aw.WriteArrayEnd() +} + +// encodeDocument is a separate function that we use because CodeWithScope +// returns us a DocumentWriter and we need to do the same logic that we would do +// for a document but cannot use a Codec. +func (dve DefaultValueEncoders) encodeDocument(ec EncodeContext, dw DocumentWriter, doc *bson.Document) error { + itr := doc.Iterator() + + for itr.Next() { + elem := itr.Element() + dvw, err := dw.WriteDocumentElement(elem.Key()) + if err != nil { + return err + } + + val := elem.Value() + err = dve.encodeValue(ec, dvw, val) + + if err != nil { + return err + } + } + + if err := itr.Err(); err != nil { + return err + } + + return dw.WriteDocumentEnd() +} + +// BinaryEncodeValue is the ValueEncoderFunc for bson.Binary. +func (dve DefaultValueEncoders) BinaryEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var b bson.Binary + switch t := i.(type) { + case bson.Binary: + b = t + case *bson.Binary: + b = *t + default: + return ValueEncoderError{ + Name: "BinaryEncodeValue", + Types: []interface{}{bson.Binary{}, (*bson.Binary)(nil)}, + Received: i, + } + } + + return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +} + +// UndefinedEncodeValue is the ValueEncoderFunc for bson.Undefined. +func (dve DefaultValueEncoders) UndefinedEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + switch i.(type) { + case bson.Undefinedv2, *bson.Undefinedv2: + default: + return ValueEncoderError{ + Name: "UndefinedEncodeValue", + Types: []interface{}{bson.Undefinedv2{}, (*bson.Undefinedv2)(nil)}, + Received: i, + } + } + + return vw.WriteUndefined() +} + +// ObjectIDEncodeValue is the ValueEncoderFunc for objectid.ObjectID. +func (dve DefaultValueEncoders) ObjectIDEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var oid objectid.ObjectID + switch t := i.(type) { + case objectid.ObjectID: + oid = t + case *objectid.ObjectID: + oid = *t + default: + return ValueEncoderError{ + Name: "ObjectIDEncodeValue", + Types: []interface{}{objectid.ObjectID{}, (*objectid.ObjectID)(nil)}, + Received: i, + } + } + + return vw.WriteObjectID(oid) +} + +// DateTimeEncodeValue is the ValueEncoderFunc for bson.DateTime. +func (dve DefaultValueEncoders) DateTimeEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var dt bson.DateTime + switch t := i.(type) { + case bson.DateTime: + dt = t + case *bson.DateTime: + dt = *t + default: + return ValueEncoderError{ + Name: "DateTimeEncodeValue", + Types: []interface{}{bson.DateTime(0), (*bson.DateTime)(nil)}, + Received: i, + } + } + + return vw.WriteDateTime(int64(dt)) +} + +// NullEncodeValue is the ValueEncoderFunc for bson.Null. +func (dve DefaultValueEncoders) NullEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + switch i.(type) { + case bson.Nullv2, *bson.Nullv2: + default: + return ValueEncoderError{ + Name: "NullEncodeValue", + Types: []interface{}{bson.Nullv2{}, (*bson.Nullv2)(nil)}, + Received: i, + } + } + + return vw.WriteNull() +} + +// RegexEncodeValue is the ValueEncoderFunc for bson.Regex. +func (dve DefaultValueEncoders) RegexEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var regex bson.Regex + switch t := i.(type) { + case bson.Regex: + regex = t + case *bson.Regex: + regex = *t + default: + return ValueEncoderError{ + Name: "RegexEncodeValue", + Types: []interface{}{bson.Regex{}, (*bson.Regex)(nil)}, + Received: i, + } + } + + return vw.WriteRegex(regex.Pattern, regex.Options) +} + +// DBPointerEncodeValue is the ValueEncoderFunc for bson.DBPointer. +func (dve DefaultValueEncoders) DBPointerEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var dbp bson.DBPointer + switch t := i.(type) { + case bson.DBPointer: + dbp = t + case *bson.DBPointer: + dbp = *t + default: + return ValueEncoderError{ + Name: "DBPointerEncodeValue", + Types: []interface{}{bson.DBPointer{}, (*bson.DBPointer)(nil)}, + Received: i, + } + } + + return vw.WriteDBPointer(dbp.DB, dbp.Pointer) +} + +// CodeWithScopeEncodeValue is the ValueEncoderFunc for bson.CodeWithScope. +func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var cws bson.CodeWithScope + switch t := i.(type) { + case bson.CodeWithScope: + cws = t + case *bson.CodeWithScope: + cws = *t + default: + return ValueEncoderError{ + Name: "CodeWithScopeEncodeValue", + Types: []interface{}{bson.CodeWithScope{}, (*bson.CodeWithScope)(nil)}, + Received: i, + } + } + + dw, err := vw.WriteCodeWithScope(cws.Code) + if err != nil { + return err + } + + return dve.encodeDocument(ec, dw, cws.Scope) +} + +// TimestampEncodeValue is the ValueEncoderFunc for bson.Timestamp. +func (dve DefaultValueEncoders) TimestampEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var ts bson.Timestamp + switch t := i.(type) { + case bson.Timestamp: + ts = t + case *bson.Timestamp: + ts = *t + default: + return ValueEncoderError{ + Name: "TimestampEncodeValue", + Types: []interface{}{bson.Timestamp{}, (*bson.Timestamp)(nil)}, + Received: i, + } + } + + return vw.WriteTimestamp(ts.T, ts.I) +} + +// Decimal128EncodeValue is the ValueEncoderFunc for decimal.Decimal128. +func (dve DefaultValueEncoders) Decimal128EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var d128 decimal.Decimal128 + switch t := i.(type) { + case decimal.Decimal128: + d128 = t + case *decimal.Decimal128: + d128 = *t + default: + return ValueEncoderError{ + Name: "Decimal128EncodeValue", + Types: []interface{}{decimal.Decimal128{}, (*decimal.Decimal128)(nil)}, + Received: i, + } + } + + return vw.WriteDecimal128(d128) +} + +// MinKeyEncodeValue is the ValueEncoderFunc for bson.MinKey. +func (dve DefaultValueEncoders) MinKeyEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + switch i.(type) { + case bson.MinKeyv2, *bson.MinKeyv2: + default: + return ValueEncoderError{ + Name: "MinKeyEncodeValue", + Types: []interface{}{bson.MinKeyv2{}, (*bson.MinKeyv2)(nil)}, + Received: i, + } + } + + return vw.WriteMinKey() +} + +// MaxKeyEncodeValue is the ValueEncoderFunc for bson.MaxKey. +func (dve DefaultValueEncoders) MaxKeyEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + switch i.(type) { + case bson.MaxKeyv2, *bson.MaxKeyv2: + default: + return ValueEncoderError{ + Name: "MaxKeyEncodeValue", + Types: []interface{}{bson.MaxKeyv2{}, (*bson.MaxKeyv2)(nil)}, + Received: i, + } + } + + return vw.WriteMaxKey() +} + +// elementEncodeValue is used internally to encode to values +func (dve DefaultValueEncoders) elementEncodeValue(ectx EncodeContext, vw ValueWriter, i interface{}) error { + elem, ok := i.(*bson.Element) + if !ok { + return ValueEncoderError{ + Name: "elementEncodeValue", + Types: []interface{}{(*bson.Element)(nil)}, + Received: i, + } + } + + if _, err := elem.Validate(); err != nil { + return err + } + + return dve.encodeValue(ectx, vw, elem.Value()) +} + +// ValueEncodeValue is the ValueEncoderFunc for *bson.Value. +func (dve DefaultValueEncoders) ValueEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + val, ok := i.(*bson.Value) + if !ok { + return ValueEncoderError{ + Name: "ValueEncodeValue", + Types: []interface{}{(*bson.Value)(nil)}, + Received: i, + } + } + + if err := val.Validate(); err != nil { + return err + } + + return dve.encodeValue(ec, vw, val) +} + +// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number. +func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var jsnum json.Number + switch t := i.(type) { + case json.Number: + jsnum = t + case *json.Number: + jsnum = *t + default: + return ValueEncoderError{ + Name: "JSONNumberEncodeValue", + Types: []interface{}{json.Number(""), (*json.Number)(nil)}, + Received: i, + } + } + + // Attempt int first, then float64 + if i64, err := jsnum.Int64(); err == nil { + return dve.IntEncodeValue(ec, vw, i64) + } + + f64, err := jsnum.Float64() + if err != nil { + return err + } + + return dve.FloatEncodeValue(ec, vw, f64) +} + +// URLEncodeValue is the ValueEncoderFunc for url.URL. +func (dve DefaultValueEncoders) URLEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var u *url.URL + switch t := i.(type) { + case url.URL: + u = &t + case *url.URL: + u = t + default: + return ValueEncoderError{ + Name: "URLEncodeValue", + Types: []interface{}{url.URL{}, (*url.URL)(nil)}, + Received: i, + } + } + + return vw.WriteString(u.String()) +} + +// TimeEncodeValue is the ValueEncoderFunc for time.TIme. +func (dve DefaultValueEncoders) TimeEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var tt time.Time + switch t := i.(type) { + case time.Time: + tt = t + case *time.Time: + tt = *t + default: + return ValueEncoderError{ + Name: "TimeEncodeValue", + Types: []interface{}{time.Time{}, (*time.Time)(nil)}, + Received: i, + } + } + + return vw.WriteDateTime(tt.Unix()*1000 + int64(tt.Nanosecond()/1e6)) +} + +// ReaderEncodeValue is the ValueEncoderFunc for bson.Reader. +func (dve DefaultValueEncoders) ReaderEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + rdr, ok := i.(bson.Reader) + if !ok { + return ValueEncoderError{ + Name: "ReaderEncodeValue", + Types: []interface{}{bson.Reader{}}, + Received: i, + } + } + + return (Copier{r: ec.Registry}).CopyDocumentFromBytes(vw, rdr) +} + +// ByteSliceEncodeValue is the ValueEncoderFunc for []byte. +func (dve DefaultValueEncoders) ByteSliceEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var slcb []byte + switch t := i.(type) { + case []byte: + slcb = t + case *[]byte: + slcb = *t + default: + return ValueEncoderError{ + Name: "ByteSliceEncodeValue", + Types: []interface{}{[]byte{}, (*[]byte)(nil)}, + Received: i, + } + } + + return vw.WriteBinary(slcb) +} + +// ElementSliceEncodeValue is the ValueEncoderFunc for []*bson.Element. +func (dve DefaultValueEncoders) ElementSliceEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + var slce []*bson.Element + switch t := i.(type) { + case []*bson.Element: + slce = t + case *[]*bson.Element: + slce = *t + default: + return ValueEncoderError{ + Name: "ElementSliceEncodeValue", + Types: []interface{}{[]*bson.Element{}, (*[]*bson.Element)(nil)}, + Received: i, + } + } + + return dve.DocumentEncodeValue(ec, vw, (&bson.Document{}).Append(slce...)) +} + +// MapEncodeValue is the ValueEncoderFunc for map[string]* types. +func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + val := reflect.ValueOf(i) + if val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { + return errors.New("MapEncodeValue can only encode maps with string keys") + } + + dw, err := vw.WriteDocument() + if err != nil { + return err + } + + return dve.mapEncodeValue(ec, dw, val, nil) +} + +// mapEncodeValue handles encoding of the values of a map. The collisionFn returns +// true if the provided key exists, this is mainly used for inline maps in the +// struct codec. +func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { + + var err error + var encoder ValueEncoder + switch val.Type().Elem() { + case tElement: + encoder = ValueEncoderFunc(dve.elementEncodeValue) + default: + encoder, err = ec.LookupEncoder(val.Type().Elem()) + if err != nil { + return err + } + } + + keys := val.MapKeys() + for _, key := range keys { + if collisionFn != nil && collisionFn(key.String()) { + return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) + } + vw, err := dw.WriteDocumentElement(key.String()) + if err != nil { + return err + } + + err = encoder.EncodeValue(ec, vw, val.MapIndex(key).Interface()) + if err != nil { + return err + } + } + + return dw.WriteDocumentEnd() +} + +// SliceEncodeValue is the ValueEncoderFunc for []* types. +func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + val := reflect.ValueOf(i) + switch val.Kind() { + case reflect.Array: + case reflect.Slice: + if val.IsNil() { // When nil, special case to null + return vw.WriteNull() + } + default: + return errors.New("SliceEncodeValue can only encode arrays and slices") + } + + length := val.Len() + + aw, err := vw.WriteArray() + if err != nil { + return err + } + + // We do this outside of the loop because an array or a slice can only have + // one element type. If it's the empty interface, we'll use the empty + // interface codec. + var encoder ValueEncoder + switch val.Type().Elem() { + case tElement: + encoder = ValueEncoderFunc(dve.elementEncodeValue) + default: + encoder, err = ec.LookupEncoder(val.Type().Elem()) + if err != nil { + return err + } + } + for idx := 0; idx < length; idx++ { + vw, err := aw.WriteArrayElement() + if err != nil { + return err + } + + err = encoder.EncodeValue(ec, vw, val.Index(idx).Interface()) + if err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}. +func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + encoder, err := ec.LookupEncoder(reflect.TypeOf(i)) + if err != nil { + return err + } + + return encoder.EncodeValue(ec, vw, i) +} + +// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. +func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { + vm, ok := i.(ValueMarshaler) + if !ok { + return ValueEncoderError{ + Name: "ValueMarshalerEncodeValue", + Types: []interface{}{(ValueMarshaler)(nil)}, + Received: i, + } + } + + t, val, err := vm.MarshalBSONValue() + if err != nil { + return err + } + return Copier{r: ec.Registry}.CopyValueFromBytes(vw, t, val) +} + +// encodeValue does not validation, and the callers must perform validation on val before calling +// this method. +func (dve DefaultValueEncoders) encodeValue(ec EncodeContext, vw ValueWriter, val *bson.Value) error { + var err error + switch val.Type() { + case bson.TypeDouble: + err = vw.WriteDouble(val.Double()) + case bson.TypeString: + err = vw.WriteString(val.StringValue()) + case bson.TypeEmbeddedDocument: + var encoder ValueEncoder + encoder, err = ec.LookupEncoder(tDocument) + if err != nil { + break + } + err = encoder.EncodeValue(ec, vw, val.MutableDocument()) + case bson.TypeArray: + var encoder ValueEncoder + encoder, err = ec.LookupEncoder(tArray) + if err != nil { + break + } + err = encoder.EncodeValue(ec, vw, val.MutableArray()) + case bson.TypeBinary: + // TODO: FIX THIS (╯°□°)╯︵ ┻━┻ + subtype, data := val.Binary() + err = vw.WriteBinaryWithSubtype(data, subtype) + case bson.TypeUndefined: + err = vw.WriteUndefined() + case bson.TypeObjectID: + err = vw.WriteObjectID(val.ObjectID()) + case bson.TypeBoolean: + err = vw.WriteBoolean(val.Boolean()) + case bson.TypeDateTime: + err = vw.WriteDateTime(val.DateTime()) + case bson.TypeNull: + err = vw.WriteNull() + case bson.TypeRegex: + err = vw.WriteRegex(val.Regex()) + case bson.TypeDBPointer: + err = vw.WriteDBPointer(val.DBPointer()) + case bson.TypeJavaScript: + err = vw.WriteJavascript(val.JavaScript()) + case bson.TypeSymbol: + err = vw.WriteSymbol(val.Symbol()) + case bson.TypeCodeWithScope: + code, scope := val.MutableJavaScriptWithScope() + + var cwsw DocumentWriter + cwsw, err = vw.WriteCodeWithScope(code) + if err != nil { + break + } + + err = dve.encodeDocument(ec, cwsw, scope) + case bson.TypeInt32: + err = vw.WriteInt32(val.Int32()) + case bson.TypeTimestamp: + err = vw.WriteTimestamp(val.Timestamp()) + case bson.TypeInt64: + err = vw.WriteInt64(val.Int64()) + case bson.TypeDecimal128: + err = vw.WriteDecimal128(val.Decimal128()) + case bson.TypeMinKey: + err = vw.WriteMinKey() + case bson.TypeMaxKey: + err = vw.WriteMaxKey() + default: + err = fmt.Errorf("%T is not a valid BSON type to encode", val.Type()) + } + + return err +} diff --git a/bson/bsoncodec/default_value_encoders_test.go b/bson/bsoncodec/default_value_encoders_test.go new file mode 100644 index 0000000000..7ada55ae7e --- /dev/null +++ b/bson/bsoncodec/default_value_encoders_test.go @@ -0,0 +1,1675 @@ +package bsoncodec + +import ( + "encoding/json" + "errors" + "fmt" + "net/url" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/internal/llbson" + "github.com/mongodb/mongo-go-driver/bson/objectid" +) + +func TestDefaultValueEncoders(t *testing.T) { + var dve DefaultValueEncoders + var wrong = func(string, string) string { return "wrong" } + + type mybool bool + type myint8 int8 + type myint16 int16 + type myint32 int32 + type myint64 int64 + type myint int + type myuint8 uint8 + type myuint16 uint16 + type myuint32 uint32 + type myuint64 uint64 + type myuint uint + type myfloat32 float32 + type myfloat64 float64 + type mystring string + + intAllowedTypes := []interface{}{int8(0), int16(0), int32(0), int64(0), int(0)} + uintAllowedEncodeTypes := []interface{}{uint8(0), uint16(0), uint32(0), uint64(0), uint(0)} + + now := time.Now().Truncate(time.Millisecond) + pdatetime := new(bson.DateTime) + *pdatetime = bson.DateTime(1234567890) + pjsnum := new(json.Number) + *pjsnum = json.Number("3.14159") + d128 := decimal.NewDecimal128(12345, 67890) + + type subtest struct { + name string + val interface{} + ectx *EncodeContext + llvrw *llValueReaderWriter + invoke llvrwInvoked + err error + } + + testCases := []struct { + name string + ve ValueEncoder + subtests []subtest + }{ + { + "BooleanEncodeValue", + ValueEncoderFunc(dve.BooleanEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{Name: "BooleanEncodeValue", Types: []interface{}{bool(true)}, Received: wrong}, + }, + {"fast path", bool(true), nil, nil, llvrwWriteBoolean, nil}, + {"reflection path", mybool(true), nil, nil, llvrwWriteBoolean, nil}, + }, + }, + { + "IntEncodeValue", + ValueEncoderFunc(dve.IntEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{Name: "IntEncodeValue", Types: intAllowedTypes, Received: wrong}, + }, + {"int8/fast path", int8(127), nil, nil, llvrwWriteInt32, nil}, + {"int16/fast path", int16(32767), nil, nil, llvrwWriteInt32, nil}, + {"int32/fast path", int32(2147483647), nil, nil, llvrwWriteInt32, nil}, + {"int64/fast path", int64(1234567890987), nil, nil, llvrwWriteInt64, nil}, + {"int/fast path", int(1234567), nil, nil, llvrwWriteInt64, nil}, + {"int64/fast path - minsize", int64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"int/fast path - minsize", int(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"int64/fast path - minsize too large", int64(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"int/fast path - minsize too large", int(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"int8/reflection path", myint8(127), nil, nil, llvrwWriteInt32, nil}, + {"int16/reflection path", myint16(32767), nil, nil, llvrwWriteInt32, nil}, + {"int32/reflection path", myint32(2147483647), nil, nil, llvrwWriteInt32, nil}, + {"int64/reflection path", myint64(1234567890987), nil, nil, llvrwWriteInt64, nil}, + {"int/reflection path", myint(1234567890987), nil, nil, llvrwWriteInt64, nil}, + {"int64/reflection path - minsize", myint64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"int/reflection path - minsize", myint(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"int64/reflection path - minsize too large", myint64(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"int/reflection path - minsize too large", myint(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + }, + }, + { + "UintEncodeValue", + ValueEncoderFunc(dve.UintEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{Name: "UintEncodeValue", Types: uintAllowedEncodeTypes, Received: wrong}, + }, + {"uint8/fast path", uint8(127), nil, nil, llvrwWriteInt32, nil}, + {"uint16/fast path", uint16(32767), nil, nil, llvrwWriteInt32, nil}, + {"uint32/fast path", uint32(2147483647), nil, nil, llvrwWriteInt64, nil}, + {"uint64/fast path", uint64(1234567890987), nil, nil, llvrwWriteInt64, nil}, + {"uint/fast path", uint(1234567), nil, nil, llvrwWriteInt64, nil}, + {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"uint/fast path - minsize", uint(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"uint64/fast path - overflow", uint64(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, + {"uint/fast path - overflow", uint(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, + {"uint8/reflection path", myuint8(127), nil, nil, llvrwWriteInt32, nil}, + {"uint16/reflection path", myuint16(32767), nil, nil, llvrwWriteInt32, nil}, + {"uint32/reflection path", myuint32(2147483647), nil, nil, llvrwWriteInt64, nil}, + {"uint64/reflection path", myuint64(1234567890987), nil, nil, llvrwWriteInt64, nil}, + {"uint/reflection path", myuint(1234567890987), nil, nil, llvrwWriteInt64, nil}, + {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, + {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, + {"uint64/reflection path - overflow", myuint64(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, + {"uint/reflection path - overflow", myuint(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, + }, + }, + { + "FloatEncodeValue", + ValueEncoderFunc(dve.FloatEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{Name: "FloatEncodeValue", Types: []interface{}{float32(0), float64(0)}, Received: wrong}, + }, + {"float32/fast path", float32(3.14159), nil, nil, llvrwWriteDouble, nil}, + {"float64/fast path", float64(3.14159), nil, nil, llvrwWriteDouble, nil}, + {"float32/reflection path", myfloat32(3.14159), nil, nil, llvrwWriteDouble, nil}, + {"float64/reflection path", myfloat64(3.14159), nil, nil, llvrwWriteDouble, nil}, + }, + }, + { + "StringEncodeValue", + ValueEncoderFunc(dve.StringEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "StringEncodeValue", + Types: []interface{}{string(""), bson.JavaScriptCode(""), bson.Symbol("")}, + Received: wrong, + }, + }, + {"string/fast path", string("foobar"), nil, nil, llvrwWriteString, nil}, + {"JavaScript/fast path", bson.JavaScriptCode("foobar"), nil, nil, llvrwWriteJavascript, nil}, + {"Symbol/fast path", bson.Symbol("foobar"), nil, nil, llvrwWriteSymbol, nil}, + {"reflection path", mystring("foobarbaz"), nil, nil, llvrwWriteString, nil}, + }, + }, + { + "TimeEncodeValue", + ValueEncoderFunc(dve.TimeEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "TimeEncodeValue", + Types: []interface{}{time.Time{}, (*time.Time)(nil)}, + Received: wrong, + }, + }, + {"time.Time", now, nil, nil, llvrwWriteDateTime, nil}, + {"*time.Time", &now, nil, nil, llvrwWriteDateTime, nil}, + }, + }, + { + "MapEncodeValue", + ValueEncoderFunc(dve.MapEncodeValue), + []subtest{ + { + "wrong kind", + wrong, + nil, + nil, + llvrwNothing, + errors.New("MapEncodeValue can only encode maps with string keys"), + }, + { + "wrong kind (non-string key)", + map[int]interface{}{}, + nil, + nil, + llvrwNothing, + errors.New("MapEncodeValue can only encode maps with string keys"), + }, + { + "WriteDocument Error", + map[string]interface{}{}, + nil, + &llValueReaderWriter{err: errors.New("wd error"), errAfter: llvrwWriteDocument}, + llvrwWriteDocument, + errors.New("wd error"), + }, + { + "Lookup Error", + map[string]interface{}{}, + &EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{}, + llvrwWriteDocument, + ErrNoEncoder{Type: reflect.TypeOf((*interface{})(nil)).Elem()}, + }, + { + "WriteDocumentElement Error", + map[string]interface{}{"foo": "bar"}, + &EncodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{err: errors.New("wde error"), errAfter: llvrwWriteDocumentElement}, + llvrwWriteDocumentElement, + errors.New("wde error"), + }, + { + "EncodeValue Error", + map[string]interface{}{"foo": "bar"}, + &EncodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{err: errors.New("ev error"), errAfter: llvrwWriteString}, + llvrwWriteString, + errors.New("ev error"), + }, + }, + }, + { + "SliceEncodeValue", + ValueEncoderFunc(dve.SliceEncodeValue), + []subtest{ + { + "wrong kind", + wrong, + nil, + nil, + llvrwNothing, + errors.New("SliceEncodeValue can only encode arrays and slices"), + }, + { + "WriteArray Error", + []string{}, + nil, + &llValueReaderWriter{err: errors.New("wa error"), errAfter: llvrwWriteArray}, + llvrwWriteArray, + errors.New("wa error"), + }, + { + "Lookup Error", + []interface{}{}, + &EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{}, + llvrwWriteArray, + ErrNoEncoder{Type: reflect.TypeOf((*interface{})(nil)).Elem()}, + }, + { + "WriteArrayElement Error", + []string{"foo"}, + &EncodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{err: errors.New("wae error"), errAfter: llvrwWriteArrayElement}, + llvrwWriteArrayElement, + errors.New("wae error"), + }, + { + "EncodeValue Error", + []string{"foo"}, + &EncodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{err: errors.New("ev error"), errAfter: llvrwWriteString}, + llvrwWriteString, + errors.New("ev error"), + }, + }, + }, + { + "BinaryEncodeValue", + ValueEncoderFunc(dve.BinaryEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "BinaryEncodeValue", + Types: []interface{}{bson.Binary{}, (*bson.Binary)(nil)}, + Received: wrong, + }, + }, + {"Binary/success", bson.Binary{Data: []byte{0x01, 0x02}, Subtype: 0xFF}, nil, nil, llvrwWriteBinaryWithSubtype, nil}, + {"*Binary/success", &bson.Binary{Data: []byte{0x01, 0x02}, Subtype: 0xFF}, nil, nil, llvrwWriteBinaryWithSubtype, nil}, + }, + }, + { + "UndefinedEncodeValue", + ValueEncoderFunc(dve.UndefinedEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "UndefinedEncodeValue", + Types: []interface{}{bson.Undefinedv2{}, (*bson.Undefinedv2)(nil)}, + Received: wrong, + }, + }, + {"Undefined/success", bson.Undefinedv2{}, nil, nil, llvrwWriteUndefined, nil}, + {"*Undefined/success", &bson.Undefinedv2{}, nil, nil, llvrwWriteUndefined, nil}, + }, + }, + { + "ObjectIDEncodeValue", + ValueEncoderFunc(dve.ObjectIDEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "ObjectIDEncodeValue", + Types: []interface{}{objectid.ObjectID{}, (*objectid.ObjectID)(nil)}, + Received: wrong, + }, + }, + { + "objectid.ObjectID/success", + objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + nil, nil, llvrwWriteObjectID, nil, + }, + { + "*objectid.ObjectID/success", + &objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + nil, nil, llvrwWriteObjectID, nil, + }, + }, + }, + { + "DateTimeEncodeValue", + ValueEncoderFunc(dve.DateTimeEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "DateTimeEncodeValue", + Types: []interface{}{bson.DateTime(0), (*bson.DateTime)(nil)}, + Received: wrong, + }, + }, + {"DateTime/success", bson.DateTime(1234567890), nil, nil, llvrwWriteDateTime, nil}, + {"*DateTime/success", pdatetime, nil, nil, llvrwWriteDateTime, nil}, + }, + }, + { + "NullEncodeValue", + ValueEncoderFunc(dve.NullEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "NullEncodeValue", + Types: []interface{}{bson.Nullv2{}, (*bson.Nullv2)(nil)}, + Received: wrong, + }, + }, + {"Null/success", bson.Nullv2{}, nil, nil, llvrwWriteNull, nil}, + {"*Null/success", &bson.Nullv2{}, nil, nil, llvrwWriteNull, nil}, + }, + }, + { + "RegexEncodeValue", + ValueEncoderFunc(dve.RegexEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "RegexEncodeValue", + Types: []interface{}{bson.Regex{}, (*bson.Regex)(nil)}, + Received: wrong, + }, + }, + {"Regex/success", bson.Regex{Pattern: "foo", Options: "bar"}, nil, nil, llvrwWriteRegex, nil}, + {"*Regex/success", &bson.Regex{Pattern: "foo", Options: "bar"}, nil, nil, llvrwWriteRegex, nil}, + }, + }, + { + "DBPointerEncodeValue", + ValueEncoderFunc(dve.DBPointerEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "DBPointerEncodeValue", + Types: []interface{}{bson.DBPointer{}, (*bson.DBPointer)(nil)}, + Received: wrong, + }, + }, + { + "DBPointer/success", + bson.DBPointer{ + DB: "foobar", + Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + }, + nil, nil, llvrwWriteDBPointer, nil, + }, + { + "*DBPointer/success", + &bson.DBPointer{ + DB: "foobar", + Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, + }, + nil, nil, llvrwWriteDBPointer, nil, + }, + }, + }, + { + "CodeWithScopeEncodeValue", + ValueEncoderFunc(dve.CodeWithScopeEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "CodeWithScopeEncodeValue", + Types: []interface{}{bson.CodeWithScope{}, (*bson.CodeWithScope)(nil)}, + Received: wrong, + }, + }, + { + "WriteCodeWithScope error", + bson.CodeWithScope{}, + nil, + &llValueReaderWriter{err: errors.New("wcws error"), errAfter: llvrwWriteCodeWithScope}, + llvrwWriteCodeWithScope, + errors.New("wcws error"), + }, + { + "CodeWithScope/success", + bson.CodeWithScope{ + Code: "var hello = 'world';", + Scope: bson.NewDocument(), + }, + nil, nil, llvrwWriteDocumentEnd, nil, + }, + { + "*CodeWithScope/success", + &bson.CodeWithScope{ + Code: "var hello = 'world';", + Scope: bson.NewDocument(), + }, + nil, nil, llvrwWriteDocumentEnd, nil, + }, + }, + }, + { + "TimestampEncodeValue", + ValueEncoderFunc(dve.TimestampEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "TimestampEncodeValue", + Types: []interface{}{bson.Timestamp{}, (*bson.Timestamp)(nil)}, + Received: wrong, + }, + }, + {"Timestamp/success", bson.Timestamp{T: 12345, I: 67890}, nil, nil, llvrwWriteTimestamp, nil}, + {"*Timestamp/success", &bson.Timestamp{T: 12345, I: 67890}, nil, nil, llvrwWriteTimestamp, nil}, + }, + }, + { + "Decimal128EncodeValue", + ValueEncoderFunc(dve.Decimal128EncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "Decimal128EncodeValue", + Types: []interface{}{decimal.Decimal128{}, (*decimal.Decimal128)(nil)}, + Received: wrong, + }, + }, + {"Decimal128/success", d128, nil, nil, llvrwWriteDecimal128, nil}, + {"*Decimal128/success", &d128, nil, nil, llvrwWriteDecimal128, nil}, + }, + }, + { + "MinKeyEncodeValue", + ValueEncoderFunc(dve.MinKeyEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "MinKeyEncodeValue", + Types: []interface{}{bson.MinKeyv2{}, (*bson.MinKeyv2)(nil)}, + Received: wrong, + }, + }, + {"MinKey/success", bson.MinKeyv2{}, nil, nil, llvrwWriteMinKey, nil}, + {"*MinKey/success", &bson.MinKeyv2{}, nil, nil, llvrwWriteMinKey, nil}, + }, + }, + { + "MaxKeyEncodeValue", + ValueEncoderFunc(dve.MaxKeyEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "MaxKeyEncodeValue", + Types: []interface{}{bson.MaxKeyv2{}, (*bson.MaxKeyv2)(nil)}, + Received: wrong, + }, + }, + {"MaxKey/success", bson.MaxKeyv2{}, nil, nil, llvrwWriteMaxKey, nil}, + {"*MaxKey/success", &bson.MaxKeyv2{}, nil, nil, llvrwWriteMaxKey, nil}, + }, + }, + { + "elementEncodeValue", + ValueEncoderFunc(dve.elementEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{Name: "elementEncodeValue", Types: []interface{}{(*bson.Element)(nil)}, Received: wrong}, + }, + {"invalid element", (*bson.Element)(nil), nil, nil, llvrwNothing, bson.ErrNilElement}, + { + "success", + bson.EC.Null("foo"), + &EncodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{}, + llvrwWriteNull, + nil, + }, + }, + }, + { + "ValueEncodeValue", + ValueEncoderFunc(dve.ValueEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{Name: "ValueEncodeValue", Types: []interface{}{(*bson.Value)(nil)}, Received: wrong}, + }, + {"invalid value", &bson.Value{}, nil, nil, llvrwNothing, bson.ErrUninitializedElement}, + { + "success", + bson.VC.Null(), + &EncodeContext{Registry: NewRegistryBuilder().Build()}, + &llValueReaderWriter{}, + llvrwWriteNull, + nil, + }, + }, + }, + { + "JSONNumberEncodeValue", + ValueEncoderFunc(dve.JSONNumberEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "JSONNumberEncodeValue", + Types: []interface{}{json.Number(""), (*json.Number)(nil)}, + Received: wrong, + }, + }, + { + "json.Number/invalid", + json.Number("hello world"), + nil, nil, llvrwNothing, errors.New(`strconv.ParseFloat: parsing "hello world": invalid syntax`), + }, + { + "json.Number/int64/success", + json.Number("1234567890"), + nil, nil, llvrwWriteInt64, nil, + }, + { + "json.Number/float64/success", + json.Number("3.14159"), + nil, nil, llvrwWriteDouble, nil, + }, + { + "*json.Number/int64/success", + pjsnum, + nil, nil, llvrwWriteDouble, nil, + }, + }, + }, + { + "URLEncodeValue", + ValueEncoderFunc(dve.URLEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "URLEncodeValue", + Types: []interface{}{url.URL{}, (*url.URL)(nil)}, + Received: wrong, + }, + }, + {"url.URL", url.URL{Scheme: "http", Host: "example.com"}, nil, nil, llvrwWriteString, nil}, + {"*url.URL", &url.URL{Scheme: "http", Host: "example.com"}, nil, nil, llvrwWriteString, nil}, + }, + }, + { + "ReaderEncodeValue", + ValueEncoderFunc(dve.ReaderEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{Name: "ReaderEncodeValue", Types: []interface{}{bson.Reader{}}, Received: wrong}, + }, + { + "WriteDocument Error", + bson.Reader{}, + nil, + &llValueReaderWriter{err: errors.New("wd error"), errAfter: llvrwWriteDocument}, + llvrwWriteDocument, + errors.New("wd error"), + }, + { + "Reader.Iterator Error", + bson.Reader{0xFF, 0x00, 0x00, 0x00, 0x00}, + nil, + &llValueReaderWriter{}, + llvrwWriteDocument, + bson.ErrInvalidLength, + }, + { + "WriteDocumentElement Error", + bson.Reader(bytesFromDoc(bson.NewDocument(bson.EC.Null("foo")))), + nil, + &llValueReaderWriter{err: errors.New("wde error"), errAfter: llvrwWriteDocumentElement}, + llvrwWriteDocumentElement, + errors.New("wde error"), + }, + { + "encodeValue error", + bson.Reader(bytesFromDoc(bson.NewDocument(bson.EC.Null("foo")))), + nil, + &llValueReaderWriter{err: errors.New("ev error"), errAfter: llvrwWriteNull}, + llvrwWriteNull, + errors.New("ev error"), + }, + { + "iterator error", + bson.Reader{0x0C, 0x00, 0x00, 0x00, 0x01, 'f', 'o', 'o', 0x00, 0x01, 0x02, 0x03}, + nil, + &llValueReaderWriter{}, + llvrwWriteDocument, + bson.NewErrTooSmall(), + }, + }, + }, + { + "ByteSliceEncodeValue", + ValueEncoderFunc(dve.ByteSliceEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "ByteSliceEncodeValue", + Types: []interface{}{[]byte{}, (*[]byte)(nil)}, + Received: wrong, + }, + }, + {"[]byte", []byte{0x01, 0x02, 0x03}, nil, nil, llvrwWriteBinary, nil}, + {"*[]byte", &([]byte{0x01, 0x02, 0x03}), nil, nil, llvrwWriteBinary, nil}, + }, + }, + { + "ValueMarshalerEncodeValue", + ValueEncoderFunc(dve.ValueMarshalerEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + llvrwNothing, + ValueEncoderError{ + Name: "ValueMarshalerEncodeValue", + Types: []interface{}{(ValueMarshaler)(nil)}, + Received: wrong, + }, + }, + { + "MarshalBSONValue error", + testValueMarshaler{err: errors.New("mbsonv error")}, + nil, + nil, + llvrwNothing, + errors.New("mbsonv error"), + }, + { + "Copy error", + testValueMarshaler{}, + nil, + nil, + llvrwNothing, + fmt.Errorf("Cannot copy unknown BSON type %s", bson.Type(0)), + }, + { + "success", + testValueMarshaler{t: bson.TypeString, buf: []byte{0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00}}, + nil, + nil, + llvrwWriteString, + nil, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, subtest := range tc.subtests { + t.Run(subtest.name, func(t *testing.T) { + var ec EncodeContext + if subtest.ectx != nil { + ec = *subtest.ectx + } + llvrw := new(llValueReaderWriter) + if subtest.llvrw != nil { + llvrw = subtest.llvrw + } + llvrw.t = t + err := tc.ve.EncodeValue(ec, llvrw, subtest.val) + if !compareErrors(err, subtest.err) { + t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, subtest.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + } + }) + } + }) + } + + t.Run("DocumentEncodeValue", func(t *testing.T) { + t.Run("ValueEncoderError", func(t *testing.T) { + val := bool(true) + want := ValueEncoderError{Name: "DocumentEncodeValue", Types: []interface{}{(*bson.Document)(nil)}, Received: val} + got := (DefaultValueEncoders{}).DocumentEncodeValue(EncodeContext{}, nil, val) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("WriteDocument Error", func(t *testing.T) { + want := errors.New("WriteDocument Error") + llvrw := &llValueReaderWriter{ + t: t, + err: want, + errAfter: llvrwWriteDocument, + } + got := (DefaultValueEncoders{}).DocumentEncodeValue(EncodeContext{}, llvrw, bson.NewDocument()) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("encodeDocument errors", func(t *testing.T) { + ec := EncodeContext{} + err := errors.New("encodeDocument error") + oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} + badelem := &bson.Element{} + testCases := []struct { + name string + ec EncodeContext + llvrw *llValueReaderWriter + doc *bson.Document + err error + }{ + { + "WriteDocumentElement", + ec, + &llValueReaderWriter{t: t, err: errors.New("wde error"), errAfter: llvrwWriteDocumentElement}, + bson.NewDocument(bson.EC.Null("foo")), + errors.New("wde error"), + }, + { + "WriteDouble", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDouble}, + bson.NewDocument(bson.EC.Double("foo", 3.14159)), err, + }, + { + "WriteString", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteString}, + bson.NewDocument(bson.EC.String("foo", "bar")), err, + }, + { + "WriteDocument (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t}, + bson.NewDocument(bson.EC.SubDocument("foo", bson.NewDocument(bson.EC.Null("bar")))), + ErrNoEncoder{Type: tDocument}, + }, + { + "WriteArray (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t}, + bson.NewDocument(bson.EC.Array("foo", bson.NewArray(bson.VC.Null()))), + ErrNoEncoder{Type: tArray}, + }, + { + "WriteBinary", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBinaryWithSubtype}, + bson.NewDocument(bson.EC.BinaryWithSubtype("foo", []byte{0x01, 0x02, 0x03}, 0xFF)), err, + }, + { + "WriteUndefined", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteUndefined}, + bson.NewDocument(bson.EC.Undefined("foo")), err, + }, + { + "WriteObjectID", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteObjectID}, + bson.NewDocument(bson.EC.ObjectID("foo", oid)), err, + }, + { + "WriteBoolean", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBoolean}, + bson.NewDocument(bson.EC.Boolean("foo", true)), err, + }, + { + "WriteDateTime", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDateTime}, + bson.NewDocument(bson.EC.DateTime("foo", 1234567890)), err, + }, + { + "WriteNull", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteNull}, + bson.NewDocument(bson.EC.Null("foo")), err, + }, + { + "WriteRegex", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteRegex}, + bson.NewDocument(bson.EC.Regex("foo", "bar", "baz")), err, + }, + { + "WriteDBPointer", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDBPointer}, + bson.NewDocument(bson.EC.DBPointer("foo", "bar", oid)), err, + }, + { + "WriteJavascript", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteJavascript}, + bson.NewDocument(bson.EC.JavaScript("foo", "var hello = 'world';")), err, + }, + { + "WriteSymbol", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteSymbol}, + bson.NewDocument(bson.EC.Symbol("foo", "symbolbaz")), err, + }, + { + "WriteCodeWithScope (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteCodeWithScope}, + bson.NewDocument(bson.EC.CodeWithScope("foo", "var hello = 'world';", bson.NewDocument(bson.EC.Null("bar")))), + err, + }, + { + "WriteInt32", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt32}, + bson.NewDocument(bson.EC.Int32("foo", 12345)), err, + }, + { + "WriteInt64", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt64}, + bson.NewDocument(bson.EC.Int64("foo", 1234567890)), err, + }, + { + "WriteTimestamp", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteTimestamp}, + bson.NewDocument(bson.EC.Timestamp("foo", 10, 20)), err, + }, + { + "WriteDecimal128", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDecimal128}, + bson.NewDocument(bson.EC.Decimal128("foo", decimal.NewDecimal128(10, 20))), err, + }, + { + "WriteMinKey", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMinKey}, + bson.NewDocument(bson.EC.MinKey("foo")), err, + }, + { + "WriteMaxKey", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMaxKey}, + bson.NewDocument(bson.EC.MaxKey("foo")), err, + }, + { + "Invalid Type", ec, + &llValueReaderWriter{t: t, bsontype: bson.Type(0)}, + bson.NewDocument(badelem), + bson.ErrUninitializedElement, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := (DefaultValueEncoders{}).DocumentEncodeValue(tc.ec, tc.llvrw, tc.doc) + if !compareErrors(err, tc.err) { + t.Errorf("Errors do not match. got %v; want %v", err, tc.err) + } + }) + } + }) + + t.Run("success", func(t *testing.T) { + oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} + d128 := decimal.NewDecimal128(10, 20) + want := bson.NewDocument( + bson.EC.Double("a", 3.14159), bson.EC.String("b", "foo"), bson.EC.SubDocumentFromElements("c", bson.EC.Null("aa")), + bson.EC.ArrayFromElements("d", bson.VC.Null()), + bson.EC.BinaryWithSubtype("e", []byte{0x01, 0x02, 0x03}, 0xFF), bson.EC.Undefined("f"), + bson.EC.ObjectID("g", oid), bson.EC.Boolean("h", true), bson.EC.DateTime("i", 1234567890), bson.EC.Null("j"), bson.EC.Regex("k", "foo", "bar"), + bson.EC.DBPointer("l", "foobar", oid), bson.EC.JavaScript("m", "var hello = 'world';"), bson.EC.Symbol("n", "bazqux"), + bson.EC.CodeWithScope("o", "var hello = 'world';", bson.NewDocument(bson.EC.Null("ab"))), bson.EC.Int32("p", 12345), + bson.EC.Timestamp("q", 10, 20), bson.EC.Int64("r", 1234567890), bson.EC.Decimal128("s", d128), bson.EC.MinKey("t"), bson.EC.MaxKey("u"), + ) + got := bson.NewDocument() + ec := EncodeContext{Registry: NewRegistryBuilder().Build()} + err := (DefaultValueEncoders{}).DocumentEncodeValue(ec, newDocumentValueWriter(got), want) + noerr(t, err) + if !got.Equal(want) { + t.Error("Documents do not match") + t.Errorf("\ngot :%v\nwant:%v", got, want) + } + }) + }) + + t.Run("ArrayEncodeValue", func(t *testing.T) { + t.Run("CodecEncodeError", func(t *testing.T) { + val := bool(true) + want := ValueEncoderError{Name: "ArrayEncodeValue", Types: []interface{}{(*bson.Array)(nil)}, Received: val} + got := (DefaultValueEncoders{}).ArrayEncodeValue(EncodeContext{}, nil, val) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("WriteArray Error", func(t *testing.T) { + want := errors.New("WriteArray Error") + llvrw := &llValueReaderWriter{ + t: t, + err: want, + errAfter: llvrwWriteArray, + } + got := (DefaultValueEncoders{}).ArrayEncodeValue(EncodeContext{}, llvrw, bson.NewArray()) + if !compareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) + t.Run("encode array errors", func(t *testing.T) { + ec := EncodeContext{} + err := errors.New("encode array error") + oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} + badval := &bson.Value{} + testCases := []struct { + name string + ec EncodeContext + llvrw *llValueReaderWriter + arr *bson.Array + err error + }{ + { + "WriteDocumentElement", + ec, + &llValueReaderWriter{t: t, err: errors.New("wde error"), errAfter: llvrwWriteArrayElement}, + bson.NewArray(bson.VC.Null()), + errors.New("wde error"), + }, + { + "WriteDouble", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDouble}, + bson.NewArray(bson.VC.Double(3.14159)), err, + }, + { + "WriteString", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteString}, + bson.NewArray(bson.VC.String("bar")), err, + }, + { + "WriteDocument (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t}, + bson.NewArray(bson.VC.Document(bson.NewDocument(bson.EC.Null("bar")))), + ErrNoEncoder{Type: tDocument}, + }, + { + "WriteArray (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t}, + bson.NewArray(bson.VC.Array(bson.NewArray(bson.VC.Null()))), + ErrNoEncoder{Type: tArray}, + }, + { + "WriteBinary", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBinaryWithSubtype}, + bson.NewArray(bson.VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xFF)), err, + }, + { + "WriteUndefined", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteUndefined}, + bson.NewArray(bson.VC.Undefined()), err, + }, + { + "WriteObjectID", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteObjectID}, + bson.NewArray(bson.VC.ObjectID(oid)), err, + }, + { + "WriteBoolean", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBoolean}, + bson.NewArray(bson.VC.Boolean(true)), err, + }, + { + "WriteDateTime", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDateTime}, + bson.NewArray(bson.VC.DateTime(1234567890)), err, + }, + { + "WriteNull", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteNull}, + bson.NewArray(bson.VC.Null()), err, + }, + { + "WriteRegex", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteRegex}, + bson.NewArray(bson.VC.Regex("bar", "baz")), err, + }, + { + "WriteDBPointer", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDBPointer}, + bson.NewArray(bson.VC.DBPointer("bar", oid)), err, + }, + { + "WriteJavascript", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteJavascript}, + bson.NewArray(bson.VC.JavaScript("var hello = 'world';")), err, + }, + { + "WriteSymbol", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteSymbol}, + bson.NewArray(bson.VC.Symbol("symbolbaz")), err, + }, + { + "WriteCodeWithScope (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteCodeWithScope}, + bson.NewArray(bson.VC.CodeWithScope("var hello = 'world';", bson.NewDocument(bson.EC.Null("bar")))), + err, + }, + { + "WriteInt32", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt32}, + bson.NewArray(bson.VC.Int32(12345)), err, + }, + { + "WriteInt64", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt64}, + bson.NewArray(bson.VC.Int64(1234567890)), err, + }, + { + "WriteTimestamp", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteTimestamp}, + bson.NewArray(bson.VC.Timestamp(10, 20)), err, + }, + { + "WriteDecimal128", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDecimal128}, + bson.NewArray(bson.VC.Decimal128(decimal.NewDecimal128(10, 20))), err, + }, + { + "WriteMinKey", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMinKey}, + bson.NewArray(bson.VC.MinKey()), err, + }, + { + "WriteMaxKey", ec, + &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMaxKey}, + bson.NewArray(bson.VC.MaxKey()), err, + }, + { + "Invalid Type", ec, + &llValueReaderWriter{t: t, bsontype: bson.Type(0)}, + bson.NewArray(badval), + bson.ErrUninitializedElement, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := (DefaultValueEncoders{}).ArrayEncodeValue(tc.ec, tc.llvrw, tc.arr) + if !compareErrors(err, tc.err) { + t.Errorf("Errors do not match. got %v; want %v", err, tc.err) + } + }) + } + }) + + t.Run("success", func(t *testing.T) { + oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} + d128 := decimal.NewDecimal128(10, 20) + want := bson.NewArray( + bson.VC.Double(3.14159), bson.VC.String("foo"), bson.VC.DocumentFromElements(bson.EC.Null("aa")), + bson.VC.ArrayFromValues(bson.VC.Null()), + bson.VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xFF), bson.VC.Undefined(), + bson.VC.ObjectID(oid), bson.VC.Boolean(true), bson.VC.DateTime(1234567890), bson.VC.Null(), bson.VC.Regex("foo", "bar"), + bson.VC.DBPointer("foobar", oid), bson.VC.JavaScript("var hello = 'world';"), bson.VC.Symbol("bazqux"), + bson.VC.CodeWithScope("var hello = 'world';", bson.NewDocument(bson.EC.Null("ab"))), bson.VC.Int32(12345), + bson.VC.Timestamp(10, 20), bson.VC.Int64(1234567890), bson.VC.Decimal128(d128), bson.VC.MinKey(), bson.VC.MaxKey(), + ) + + ec := EncodeContext{Registry: NewRegistryBuilder().Build()} + + doc := bson.NewDocument() + dvw := newDocumentValueWriter(doc) + dr, err := dvw.WriteDocument() + noerr(t, err) + vr, err := dr.WriteDocumentElement("foo") + noerr(t, err) + + err = (DefaultValueEncoders{}).ArrayEncodeValue(ec, vr, want) + noerr(t, err) + + got := doc.Lookup("foo").MutableArray() + if !got.Equal(want) { + t.Error("Documents do not match") + t.Errorf("\ngot :%v\nwant:%v", got, want) + } + }) + }) + + t.Run("success path", func(t *testing.T) { + oid := objectid.New() + oids := []objectid.ObjectID{objectid.New(), objectid.New(), objectid.New()} + var str = new(string) + *str = "bar" + now := time.Now().Truncate(time.Millisecond) + murl, err := url.Parse("https://mongodb.com/random-url?hello=world") + if err != nil { + t.Errorf("Error parsing URL: %v", err) + t.FailNow() + } + decimal128, err := decimal.ParseDecimal128("1.5e10") + if err != nil { + t.Errorf("Error parsing decimal128: %v", err) + t.FailNow() + } + + testCases := []struct { + name string + value interface{} + b []byte + err error + }{ + { + "map[string]int", + map[string]int32{"foo": 1}, + []byte{ + 0x0E, 0x00, 0x00, 0x00, + 0x10, 'f', 'o', 'o', 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x00, + }, + nil, + }, + { + "map[string]objectid.ObjectID", + map[string]objectid.ObjectID{"foo": oid}, + docToBytes(bson.NewDocument(bson.EC.ObjectID("foo", oid))), + nil, + }, + { + "map[string][]*Element", + map[string][]*bson.Element{"Z": {bson.EC.Int32("A", 1), bson.EC.Int32("B", 2), bson.EC.Int32("EC", 3)}}, + docToBytes(bson.NewDocument( + bson.EC.SubDocumentFromElements("Z", bson.EC.Int32("A", 1), bson.EC.Int32("B", 2), bson.EC.Int32("EC", 3)), + )), + nil, + }, + { + "map[string][]*Value", + map[string][]*bson.Value{"Z": {bson.VC.Int32(1), bson.VC.Int32(2), bson.VC.Int32(3)}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int32(1), bson.VC.Int32(2), bson.VC.Int32(3)), + )), + nil, + }, + { + "map[string]*Element", + map[string]*bson.Element{"Z": bson.EC.Int32("Z", 12345)}, + docToBytes(bson.NewDocument( + bson.EC.Int32("Z", 12345), + )), + nil, + }, + { + "map[string]*Document", + map[string]*bson.Document{"Z": bson.NewDocument(bson.EC.Null("foo"))}, + docToBytes(bson.NewDocument( + bson.EC.SubDocumentFromElements("Z", bson.EC.Null("foo")), + )), + nil, + }, + { + "map[string]Reader", + map[string]bson.Reader{"Z": {0x05, 0x00, 0x00, 0x00, 0x00}}, + docToBytes(bson.NewDocument( + bson.EC.SubDocumentFromReader("Z", bson.Reader{0x05, 0x00, 0x00, 0x00, 0x00}), + )), + nil, + }, + { + "map[string][]int32", + map[string][]int32{"Z": {1, 2, 3}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int32(1), bson.VC.Int32(2), bson.VC.Int32(3)), + )), + nil, + }, + { + "map[string][]objectid.ObjectID", + map[string][]objectid.ObjectID{"Z": oids}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.ObjectID(oids[0]), bson.VC.ObjectID(oids[1]), bson.VC.ObjectID(oids[2])), + )), + nil, + }, + { + "map[string][]json.Number(int64)", + map[string][]json.Number{"Z": {json.Number("5"), json.Number("10")}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int64(5), bson.VC.Int64(10)), + )), + nil, + }, + { + "map[string][]json.Number(float64)", + map[string][]json.Number{"Z": {json.Number("5"), json.Number("10.1")}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Int64(5), bson.VC.Double(10.1)), + )), + nil, + }, + { + "map[string][]*url.URL", + map[string][]*url.URL{"Z": {murl}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.String(murl.String())), + )), + nil, + }, + { + "map[string][]decimal.Decimal128", + map[string][]decimal.Decimal128{"Z": {decimal128}}, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("Z", bson.VC.Decimal128(decimal128)), + )), + nil, + }, + { + "-", + struct { + A string `bson:"-"` + }{ + A: "", + }, + docToBytes(bson.NewDocument()), + nil, + }, + { + "omitempty", + struct { + A string `bson:",omitempty"` + }{ + A: "", + }, + docToBytes(bson.NewDocument()), + nil, + }, + { + "omitempty, empty time", + struct { + A time.Time `bson:",omitempty"` + }{ + A: time.Time{}, + }, + docToBytes(bson.NewDocument()), + nil, + }, + { + "no private fields", + noPrivateFields{a: "should be empty"}, + docToBytes(bson.NewDocument()), + nil, + }, + { + "minsize", + struct { + A int64 `bson:",minsize"` + }{ + A: 12345, + }, + docToBytes(bson.NewDocument(bson.EC.Int32("a", 12345))), + nil, + }, + { + "inline", + struct { + Foo struct { + A int64 `bson:",minsize"` + } `bson:",inline"` + }{ + Foo: struct { + A int64 `bson:",minsize"` + }{ + A: 12345, + }, + }, + docToBytes(bson.NewDocument(bson.EC.Int32("a", 12345))), + nil, + }, + { + "inline map", + struct { + Foo map[string]string `bson:",inline"` + }{ + Foo: map[string]string{"foo": "bar"}, + }, + docToBytes(bson.NewDocument(bson.EC.String("foo", "bar"))), + nil, + }, + { + "alternate name bson:name", + struct { + A string `bson:"foo"` + }{ + A: "bar", + }, + docToBytes(bson.NewDocument(bson.EC.String("foo", "bar"))), + nil, + }, + { + "alternate name", + struct { + A string `bson:"foo"` + }{ + A: "bar", + }, + docToBytes(bson.NewDocument(bson.EC.String("foo", "bar"))), + nil, + }, + { + "inline, omitempty", + struct { + A string + Foo zeroTest `bson:"omitempty,inline"` + }{ + A: "bar", + Foo: zeroTest{true}, + }, + docToBytes(bson.NewDocument(bson.EC.String("a", "bar"))), + nil, + }, + { + "struct{}", + struct { + A bool + B int32 + C int64 + D uint16 + E uint64 + F float64 + G string + H map[string]string + I []byte + K [2]string + L struct { + M string + } + N *bson.Element + O *bson.Document + P bson.Reader + Q objectid.ObjectID + T []struct{} + Y json.Number + Z time.Time + AA json.Number + AB *url.URL + AC decimal.Decimal128 + AD *time.Time + AE testValueMarshaler + }{ + A: true, + B: 123, + C: 456, + D: 789, + E: 101112, + F: 3.14159, + G: "Hello, world", + H: map[string]string{"foo": "bar"}, + I: []byte{0x01, 0x02, 0x03}, + K: [2]string{"baz", "qux"}, + L: struct { + M string + }{ + M: "foobar", + }, + N: bson.EC.Null("n"), + O: bson.NewDocument(bson.EC.Int64("countdown", 9876543210)), + P: bson.Reader{0x05, 0x00, 0x00, 0x00, 0x00}, + Q: oid, + T: nil, + Y: json.Number("5"), + Z: now, + AA: json.Number("10.1"), + AB: murl, + AC: decimal128, + AD: &now, + AE: testValueMarshaler{t: bson.TypeString, buf: llbson.AppendString(nil, "hello, world")}, + }, + docToBytes(bson.NewDocument( + bson.EC.Boolean("a", true), + bson.EC.Int32("b", 123), + bson.EC.Int64("c", 456), + bson.EC.Int32("d", 789), + bson.EC.Int64("e", 101112), + bson.EC.Double("f", 3.14159), + bson.EC.String("g", "Hello, world"), + bson.EC.SubDocumentFromElements("h", bson.EC.String("foo", "bar")), + bson.EC.Binary("i", []byte{0x01, 0x02, 0x03}), + bson.EC.ArrayFromElements("k", bson.VC.String("baz"), bson.VC.String("qux")), + bson.EC.SubDocumentFromElements("l", bson.EC.String("m", "foobar")), + bson.EC.Null("n"), + bson.EC.SubDocumentFromElements("o", bson.EC.Int64("countdown", 9876543210)), + bson.EC.SubDocumentFromElements("p"), + bson.EC.ObjectID("q", oid), + bson.EC.Null("t"), + bson.EC.Int64("y", 5), + bson.EC.DateTime("z", now.UnixNano()/int64(time.Millisecond)), + bson.EC.Double("aa", 10.1), + bson.EC.String("ab", murl.String()), + bson.EC.Decimal128("ac", decimal128), + bson.EC.DateTime("ad", now.UnixNano()/int64(time.Millisecond)), + bson.EC.String("ae", "hello, world"), + )), + nil, + }, + { + "struct{[]interface{}}", + struct { + A []bool + B []int32 + C []int64 + D []uint16 + E []uint64 + F []float64 + G []string + H []map[string]string + I [][]byte + K [1][2]string + L []struct { + M string + } + N [][]string + O []*bson.Element + P []*bson.Document + Q []bson.Reader + R []objectid.ObjectID + T []struct{} + W []map[string]struct{} + X []map[string]struct{} + Y []map[string]struct{} + Z []time.Time + AA []json.Number + AB []*url.URL + AC []decimal.Decimal128 + AD []*time.Time + AE []testValueMarshaler + }{ + A: []bool{true}, + B: []int32{123}, + C: []int64{456}, + D: []uint16{789}, + E: []uint64{101112}, + F: []float64{3.14159}, + G: []string{"Hello, world"}, + H: []map[string]string{{"foo": "bar"}}, + I: [][]byte{{0x01, 0x02, 0x03}}, + K: [1][2]string{{"baz", "qux"}}, + L: []struct { + M string + }{ + { + M: "foobar", + }, + }, + N: [][]string{{"foo", "bar"}}, + O: []*bson.Element{bson.EC.Null("N")}, + P: []*bson.Document{bson.NewDocument(bson.EC.Int64("countdown", 9876543210))}, + Q: []bson.Reader{{0x05, 0x00, 0x00, 0x00, 0x00}}, + R: oids, + T: nil, + W: nil, + X: []map[string]struct{}{}, // Should be empty BSON Array + Y: []map[string]struct{}{{}}, // Should be BSON array with one element, an empty BSON SubDocument + Z: []time.Time{now, now}, + AA: []json.Number{json.Number("5"), json.Number("10.1")}, + AB: []*url.URL{murl}, + AC: []decimal.Decimal128{decimal128}, + AD: []*time.Time{&now, &now}, + AE: []testValueMarshaler{ + {t: bson.TypeString, buf: llbson.AppendString(nil, "hello")}, + {t: bson.TypeString, buf: llbson.AppendString(nil, "world")}, + }, + }, + docToBytes(bson.NewDocument( + bson.EC.ArrayFromElements("a", bson.VC.Boolean(true)), + bson.EC.ArrayFromElements("b", bson.VC.Int32(123)), + bson.EC.ArrayFromElements("c", bson.VC.Int64(456)), + bson.EC.ArrayFromElements("d", bson.VC.Int32(789)), + bson.EC.ArrayFromElements("e", bson.VC.Int64(101112)), + bson.EC.ArrayFromElements("f", bson.VC.Double(3.14159)), + bson.EC.ArrayFromElements("g", bson.VC.String("Hello, world")), + bson.EC.ArrayFromElements("h", bson.VC.DocumentFromElements(bson.EC.String("foo", "bar"))), + bson.EC.ArrayFromElements("i", bson.VC.Binary([]byte{0x01, 0x02, 0x03})), + bson.EC.ArrayFromElements("k", bson.VC.ArrayFromValues(bson.VC.String("baz"), bson.VC.String("qux"))), + bson.EC.ArrayFromElements("l", bson.VC.DocumentFromElements(bson.EC.String("m", "foobar"))), + bson.EC.ArrayFromElements("n", bson.VC.ArrayFromValues(bson.VC.String("foo"), bson.VC.String("bar"))), + bson.EC.SubDocumentFromElements("o", bson.EC.Null("N")), + bson.EC.ArrayFromElements("p", bson.VC.DocumentFromElements(bson.EC.Int64("countdown", 9876543210))), + bson.EC.ArrayFromElements("q", bson.VC.DocumentFromElements()), + bson.EC.ArrayFromElements("r", bson.VC.ObjectID(oids[0]), bson.VC.ObjectID(oids[1]), bson.VC.ObjectID(oids[2])), + bson.EC.Null("t"), + bson.EC.Null("w"), + bson.EC.Array("x", bson.NewArray()), + bson.EC.ArrayFromElements("y", bson.VC.Document(bson.NewDocument())), + bson.EC.ArrayFromElements("z", bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond)), bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond))), + bson.EC.ArrayFromElements("aa", bson.VC.Int64(5), bson.VC.Double(10.10)), + bson.EC.ArrayFromElements("ab", bson.VC.String(murl.String())), + bson.EC.ArrayFromElements("ac", bson.VC.Decimal128(decimal128)), + bson.EC.ArrayFromElements("ad", bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond)), bson.VC.DateTime(now.UnixNano()/int64(time.Millisecond))), + bson.EC.ArrayFromElements("ae", bson.VC.String("hello"), bson.VC.String("world")), + )), + nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b := make([]byte, 0, 512) + vw := newValueWriterFromSlice(b) + enc, err := NewEncoder(NewRegistryBuilder().Build(), vw) + noerr(t, err) + err = enc.Encode(tc.value) + if err != tc.err { + t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) + } + b = vw.buf + if diff := cmp.Diff(b, tc.b); diff != "" { + t.Errorf("Bytes written differ: (-got +want)\n%s", diff) + t.Errorf("Bytes\ngot: %v\nwant:%v\n", b, tc.b) + t.Errorf("Readers\ngot: %v\nwant:%v\n", bson.Reader(b), bson.Reader(tc.b)) + } + }) + } + }) +} + +type testValueMarshaler struct { + t bson.Type + buf []byte + err error +} + +func (tvm testValueMarshaler) MarshalBSONValue() (bson.Type, []byte, error) { + return tvm.t, tvm.buf, tvm.err +} diff --git a/bson/document_value_reader.go b/bson/bsoncodec/document_value_reader.go similarity index 78% rename from bson/document_value_reader.go rename to bson/bsoncodec/document_value_reader.go index 05b3635a34..394fa32743 100644 --- a/bson/document_value_reader.go +++ b/bson/bsoncodec/document_value_reader.go @@ -1,18 +1,19 @@ -package bson +package bsoncodec import ( "errors" "fmt" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) type dvrState struct { mode mode - v *Value - d *Document - a *Array + v *bson.Value + d *bson.Document + a *bson.Array idx uint } @@ -23,14 +24,14 @@ type documentValueReader struct { } // NewDocumentValueReader creates a new ValueReader from the given *Document. -func NewDocumentValueReader(d *Document) (ValueReader, error) { +func NewDocumentValueReader(d *bson.Document) (ValueReader, error) { if d == nil { - return nil, ErrNilDocument + return nil, bson.ErrNilDocument } return newDocumentValueReader(d), nil } -func newDocumentValueReader(d *Document) *documentValueReader { +func newDocumentValueReader(d *bson.Document) *documentValueReader { stack := make([]dvrState, 1, 5) stack[0] = dvrState{ mode: mTopLevel, @@ -42,7 +43,7 @@ func newDocumentValueReader(d *Document) *documentValueReader { } func (dvr *documentValueReader) invalidTransitionErr(destination mode) error { - te := transitionError{ + te := TransitionError{ current: dvr.stack[dvr.frame].mode, destination: destination, } @@ -52,10 +53,10 @@ func (dvr *documentValueReader) invalidTransitionErr(destination mode) error { return te } -func (dvr *documentValueReader) typeError(t Type) error { +func (dvr *documentValueReader) typeError(t bson.Type) error { val := dvr.stack[dvr.frame].v if val == nil { - return fmt.Errorf("positioned on %s, but attempted to read %s", Type(0), t) + return fmt.Errorf("positioned on %s, but attempted to read %s", bson.Type(0), t) } return fmt.Errorf("positioned on %s, but attempted to read %s", val.Type(), t) } @@ -95,7 +96,7 @@ func (dvr *documentValueReader) pop() { } } -func (dvr *documentValueReader) ensureElementValue(t Type, destination mode) error { +func (dvr *documentValueReader) ensureElementValue(t bson.Type, destination mode) error { switch dvr.stack[dvr.frame].mode { case mElement, mValue: if dvr.stack[dvr.frame].v == nil || dvr.stack[dvr.frame].v.Type() != t { @@ -108,15 +109,15 @@ func (dvr *documentValueReader) ensureElementValue(t Type, destination mode) err return nil } -func (dvr *documentValueReader) Type() Type { +func (dvr *documentValueReader) Type() bson.Type { switch dvr.stack[dvr.frame].mode { case mElement, mValue: default: - return Type(0) + return bson.Type(0) } if dvr.stack[dvr.frame].v == nil { - return Type(0) + return bson.Type(0) } return dvr.stack[dvr.frame].v.Type() @@ -133,7 +134,7 @@ func (dvr *documentValueReader) Skip() error { } func (dvr *documentValueReader) ReadArray() (ArrayReader, error) { - if err := dvr.ensureElementValue(TypeArray, mArray); err != nil { + if err := dvr.ensureElementValue(bson.TypeArray, mArray); err != nil { return nil, err } val := dvr.stack[dvr.frame].v @@ -144,7 +145,7 @@ func (dvr *documentValueReader) ReadArray() (ArrayReader, error) { } func (dvr *documentValueReader) ReadBinary() (b []byte, btype byte, err error) { - if err := dvr.ensureElementValue(TypeBinary, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeBinary, 0); err != nil { return nil, 0, err } defer dvr.pop() @@ -154,7 +155,7 @@ func (dvr *documentValueReader) ReadBinary() (b []byte, btype byte, err error) { } func (dvr *documentValueReader) ReadBoolean() (bool, error) { - if err := dvr.ensureElementValue(TypeBoolean, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeBoolean, 0); err != nil { return false, err } defer dvr.pop() @@ -167,8 +168,8 @@ func (dvr *documentValueReader) ReadDocument() (DocumentReader, error) { case mTopLevel: return dvr, nil case mElement, mValue: - if dvr.stack[dvr.frame].v == nil || dvr.stack[dvr.frame].v.Type() != TypeEmbeddedDocument { - return nil, dvr.typeError(TypeEmbeddedDocument) + if dvr.stack[dvr.frame].v == nil || dvr.stack[dvr.frame].v.Type() != bson.TypeEmbeddedDocument { + return nil, dvr.typeError(bson.TypeEmbeddedDocument) } default: return nil, dvr.invalidTransitionErr(mDocument) @@ -182,7 +183,7 @@ func (dvr *documentValueReader) ReadDocument() (DocumentReader, error) { } func (dvr *documentValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) { - if err := dvr.ensureElementValue(TypeCodeWithScope, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeCodeWithScope, 0); err != nil { return "", nil, err } @@ -193,7 +194,7 @@ func (dvr *documentValueReader) ReadCodeWithScope() (code string, dr DocumentRea } func (dvr *documentValueReader) ReadDBPointer() (ns string, oid objectid.ObjectID, err error) { - if err := dvr.ensureElementValue(TypeDBPointer, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeDBPointer, 0); err != nil { return "", objectid.ObjectID{}, err } defer dvr.pop() @@ -203,7 +204,7 @@ func (dvr *documentValueReader) ReadDBPointer() (ns string, oid objectid.ObjectI } func (dvr *documentValueReader) ReadDateTime() (int64, error) { - if err := dvr.ensureElementValue(TypeDateTime, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeDateTime, 0); err != nil { return 0, err } defer dvr.pop() @@ -212,7 +213,7 @@ func (dvr *documentValueReader) ReadDateTime() (int64, error) { } func (dvr *documentValueReader) ReadDecimal128() (decimal.Decimal128, error) { - if err := dvr.ensureElementValue(TypeDecimal128, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeDecimal128, 0); err != nil { return decimal.Decimal128{}, err } defer dvr.pop() @@ -221,7 +222,7 @@ func (dvr *documentValueReader) ReadDecimal128() (decimal.Decimal128, error) { } func (dvr *documentValueReader) ReadDouble() (float64, error) { - if err := dvr.ensureElementValue(TypeDouble, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeDouble, 0); err != nil { return 0, err } defer dvr.pop() @@ -230,7 +231,7 @@ func (dvr *documentValueReader) ReadDouble() (float64, error) { } func (dvr *documentValueReader) ReadInt32() (int32, error) { - if err := dvr.ensureElementValue(TypeInt32, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeInt32, 0); err != nil { return 0, err } defer dvr.pop() @@ -239,7 +240,7 @@ func (dvr *documentValueReader) ReadInt32() (int32, error) { } func (dvr *documentValueReader) ReadInt64() (int64, error) { - if err := dvr.ensureElementValue(TypeInt64, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeInt64, 0); err != nil { return 0, err } defer dvr.pop() @@ -248,7 +249,7 @@ func (dvr *documentValueReader) ReadInt64() (int64, error) { } func (dvr *documentValueReader) ReadJavascript() (code string, err error) { - if err := dvr.ensureElementValue(TypeJavaScript, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeJavaScript, 0); err != nil { return "", err } defer dvr.pop() @@ -257,7 +258,7 @@ func (dvr *documentValueReader) ReadJavascript() (code string, err error) { } func (dvr *documentValueReader) ReadMaxKey() error { - if err := dvr.ensureElementValue(TypeMaxKey, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeMaxKey, 0); err != nil { return err } defer dvr.pop() @@ -266,7 +267,7 @@ func (dvr *documentValueReader) ReadMaxKey() error { } func (dvr *documentValueReader) ReadMinKey() error { - if err := dvr.ensureElementValue(TypeMinKey, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeMinKey, 0); err != nil { return err } defer dvr.pop() @@ -275,7 +276,7 @@ func (dvr *documentValueReader) ReadMinKey() error { } func (dvr *documentValueReader) ReadNull() error { - if err := dvr.ensureElementValue(TypeNull, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeNull, 0); err != nil { return err } defer dvr.pop() @@ -284,7 +285,7 @@ func (dvr *documentValueReader) ReadNull() error { } func (dvr *documentValueReader) ReadObjectID() (objectid.ObjectID, error) { - if err := dvr.ensureElementValue(TypeObjectID, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeObjectID, 0); err != nil { return objectid.ObjectID{}, err } defer dvr.pop() @@ -293,7 +294,7 @@ func (dvr *documentValueReader) ReadObjectID() (objectid.ObjectID, error) { } func (dvr *documentValueReader) ReadRegex() (pattern string, options string, err error) { - if err := dvr.ensureElementValue(TypeRegex, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeRegex, 0); err != nil { return "", "", err } defer dvr.pop() @@ -303,7 +304,7 @@ func (dvr *documentValueReader) ReadRegex() (pattern string, options string, err } func (dvr *documentValueReader) ReadString() (string, error) { - if err := dvr.ensureElementValue(TypeString, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeString, 0); err != nil { return "", err } defer dvr.pop() @@ -312,7 +313,7 @@ func (dvr *documentValueReader) ReadString() (string, error) { } func (dvr *documentValueReader) ReadSymbol() (symbol string, err error) { - if err := dvr.ensureElementValue(TypeSymbol, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeSymbol, 0); err != nil { return "", err } defer dvr.pop() @@ -321,7 +322,7 @@ func (dvr *documentValueReader) ReadSymbol() (symbol string, err error) { } func (dvr *documentValueReader) ReadTimestamp() (t uint32, i uint32, err error) { - if err := dvr.ensureElementValue(TypeTimestamp, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeTimestamp, 0); err != nil { return 0, 0, err } defer dvr.pop() @@ -331,7 +332,7 @@ func (dvr *documentValueReader) ReadTimestamp() (t uint32, i uint32, err error) } func (dvr *documentValueReader) ReadUndefined() error { - if err := dvr.ensureElementValue(TypeUndefined, 0); err != nil { + if err := dvr.ensureElementValue(bson.TypeUndefined, 0); err != nil { return err } defer dvr.pop() diff --git a/bson/document_value_reader_test.go b/bson/bsoncodec/document_value_reader_test.go similarity index 84% rename from bson/document_value_reader_test.go rename to bson/bsoncodec/document_value_reader_test.go index 11fce95cb5..ebe8e202e4 100644 --- a/bson/document_value_reader_test.go +++ b/bson/bsoncodec/document_value_reader_test.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "reflect" @@ -6,6 +6,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -14,14 +15,14 @@ func TestBasicDecodeDocumentReader(t *testing.T) { for _, tc := range unmarshalingTestCases { t.Run(tc.name, func(t *testing.T) { got := reflect.New(tc.sType).Interface() - doc, err := ReadDocument(tc.data) + doc, err := bson.ReadDocument(tc.data) noerr(t, err) vr, err := NewDocumentValueReader(doc) noerr(t, err) reg := NewRegistryBuilder().Build() - codec, err := reg.Lookup(reflect.TypeOf(got)) + decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) noerr(t, err) - err = codec.DecodeValue(DecodeContext{Registry: reg}, vr, got) + err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) noerr(t, err) if !reflect.DeepEqual(got, tc.want) { @@ -37,7 +38,7 @@ func TestDocumentValueReader(t *testing.T) { d128 := decimal.NewDecimal128(1, 2) testCases := []struct { name string - v *Value + v *bson.Value fn reflect.Value results []interface{} }{ @@ -45,11 +46,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadBinary/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadBinary), - []interface{}{[]byte(nil), byte(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeBinary)}, + []interface{}{[]byte(nil), byte(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeBinary)}, }, { "ReadBinary/success", - VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xEA), + bson.VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xEA), reflect.ValueOf((*documentValueReader).ReadBinary), []interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xEA), nil}, }, @@ -57,11 +58,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadBoolean/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadBoolean), - []interface{}{false, (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeBoolean)}, + []interface{}{false, (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeBoolean)}, }, { "ReadBoolean/success", - VC.Boolean(true), + bson.VC.Boolean(true), reflect.ValueOf((*documentValueReader).ReadBoolean), []interface{}{true, nil}, }, @@ -69,11 +70,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadDBPointer/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadDBPointer), - []interface{}{"", objectid.ObjectID{}, (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeDBPointer)}, + []interface{}{"", objectid.ObjectID{}, (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeDBPointer)}, }, { "ReadDBPointer/success", - VC.DBPointer("foobar", oid), + bson.VC.DBPointer("foobar", oid), reflect.ValueOf((*documentValueReader).ReadDBPointer), []interface{}{"foobar", oid, nil}, }, @@ -81,11 +82,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadDateTime/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadDateTime), - []interface{}{int64(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeDateTime)}, + []interface{}{int64(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeDateTime)}, }, { "ReadDateTime/success", - VC.DateTime(now), + bson.VC.DateTime(now), reflect.ValueOf((*documentValueReader).ReadDateTime), []interface{}{now, nil}, }, @@ -93,11 +94,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadDecimal128/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadDecimal128), - []interface{}{decimal.Decimal128{}, (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeDecimal128)}, + []interface{}{decimal.Decimal128{}, (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeDecimal128)}, }, { "ReadDecimal128/success", - VC.Decimal128(d128), + bson.VC.Decimal128(d128), reflect.ValueOf((*documentValueReader).ReadDecimal128), []interface{}{d128, nil}, }, @@ -105,11 +106,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadDouble/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadDouble), - []interface{}{float64(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeDouble)}, + []interface{}{float64(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeDouble)}, }, { "ReadDouble/success", - VC.Double(1.2345), + bson.VC.Double(1.2345), reflect.ValueOf((*documentValueReader).ReadDouble), []interface{}{1.2345, nil}, }, @@ -117,11 +118,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadInt32/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadInt32), - []interface{}{int32(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeInt32)}, + []interface{}{int32(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeInt32)}, }, { "ReadInt32/success", - VC.Int32(12345), + bson.VC.Int32(12345), reflect.ValueOf((*documentValueReader).ReadInt32), []interface{}{int32(12345), nil}, }, @@ -129,11 +130,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadInt64/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadInt64), - []interface{}{int64(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeInt64)}, + []interface{}{int64(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeInt64)}, }, { "ReadInt64/success", - VC.Int64(1234567890), + bson.VC.Int64(1234567890), reflect.ValueOf((*documentValueReader).ReadInt64), []interface{}{int64(1234567890), nil}, }, @@ -141,11 +142,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadJavascript/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadJavascript), - []interface{}{"", (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeJavaScript)}, + []interface{}{"", (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeJavaScript)}, }, { "ReadJavascript/success", - VC.JavaScript("var foo = bar;"), + bson.VC.JavaScript("var foo = bar;"), reflect.ValueOf((*documentValueReader).ReadJavascript), []interface{}{"var foo = bar;", nil}, }, @@ -153,11 +154,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadMaxKey/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadMaxKey), - []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(TypeMaxKey)}, + []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeMaxKey)}, }, { "ReadMaxKey/success", - VC.MaxKey(), + bson.VC.MaxKey(), reflect.ValueOf((*documentValueReader).ReadMaxKey), []interface{}{nil}, }, @@ -165,11 +166,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadMinKey/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadMinKey), - []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(TypeMinKey)}, + []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeMinKey)}, }, { "ReadMinKey/success", - VC.MinKey(), + bson.VC.MinKey(), reflect.ValueOf((*documentValueReader).ReadMinKey), []interface{}{nil}, }, @@ -177,11 +178,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadNull/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadNull), - []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(TypeNull)}, + []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeNull)}, }, { "ReadNull/success", - VC.Null(), + bson.VC.Null(), reflect.ValueOf((*documentValueReader).ReadNull), []interface{}{nil}, }, @@ -189,11 +190,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadObjectID/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadObjectID), - []interface{}{objectid.ObjectID{}, (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeObjectID)}, + []interface{}{objectid.ObjectID{}, (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeObjectID)}, }, { "ReadObjectID/success", - VC.ObjectID(oid), + bson.VC.ObjectID(oid), reflect.ValueOf((*documentValueReader).ReadObjectID), []interface{}{oid, nil}, }, @@ -201,11 +202,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadRegex/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadRegex), - []interface{}{"", "", (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeRegex)}, + []interface{}{"", "", (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeRegex)}, }, { "ReadRegex/success", - VC.Regex("foo", "bar"), + bson.VC.Regex("foo", "bar"), reflect.ValueOf((*documentValueReader).ReadRegex), []interface{}{"foo", "bar", nil}, }, @@ -213,11 +214,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadString/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadString), - []interface{}{"", (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeString)}, + []interface{}{"", (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeString)}, }, { "ReadString/success", - VC.String("hello, world!"), + bson.VC.String("hello, world!"), reflect.ValueOf((*documentValueReader).ReadString), []interface{}{"hello, world!", nil}, }, @@ -225,11 +226,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadSymbol/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadSymbol), - []interface{}{"", (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeSymbol)}, + []interface{}{"", (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeSymbol)}, }, { "ReadSymbol/success", - VC.Symbol("hello, world!"), + bson.VC.Symbol("hello, world!"), reflect.ValueOf((*documentValueReader).ReadSymbol), []interface{}{"hello, world!", nil}, }, @@ -237,11 +238,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadTimestamp/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadTimestamp), - []interface{}{uint32(0), uint32(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeTimestamp)}, + []interface{}{uint32(0), uint32(0), (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeTimestamp)}, }, { "ReadTimestamp/success", - VC.Timestamp(10, 20), + bson.VC.Timestamp(10, 20), reflect.ValueOf((*documentValueReader).ReadTimestamp), []interface{}{uint32(10), uint32(20), nil}, }, @@ -249,11 +250,11 @@ func TestDocumentValueReader(t *testing.T) { "ReadUndefined/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadUndefined), - []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(TypeUndefined)}, + []interface{}{(&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeUndefined)}, }, { "ReadUndefined/success", - VC.Undefined(), + bson.VC.Undefined(), reflect.ValueOf((*documentValueReader).ReadUndefined), []interface{}{nil}, }, @@ -261,7 +262,7 @@ func TestDocumentValueReader(t *testing.T) { "ReadCodeWithScope/incorrect type", nil, reflect.ValueOf((*documentValueReader).ReadCodeWithScope), - []interface{}{"", nil, (&documentValueReader{stack: []dvrState{{}}}).typeError(TypeCodeWithScope)}, + []interface{}{"", nil, (&documentValueReader{stack: []dvrState{{}}}).typeError(bson.TypeCodeWithScope)}, }, } @@ -294,9 +295,9 @@ func TestDocumentValueReader(t *testing.T) { t.Run("ReadCodeWithScope/success", func(t *testing.T) { wantkey, wantcode := "foo", "var hello = world;" - doc := NewDocument( - EC.CodeWithScope(wantkey, wantcode, - NewDocument(EC.Undefined("bar")), + doc := bson.NewDocument( + bson.EC.CodeWithScope(wantkey, wantcode, + bson.NewDocument(bson.EC.Undefined("bar")), )) dvr := &documentValueReader{ stack: []dvrState{ @@ -333,7 +334,7 @@ func TestDocumentValueReader(t *testing.T) { t.Run("Skip/success", func(t *testing.T) { firstkey, secondkey := "foo", "baz" - doc := NewDocument(EC.String(firstkey, "bar"), EC.Null(secondkey)) + doc := bson.NewDocument(bson.EC.String(firstkey, "bar"), bson.EC.Null(secondkey)) dvr, err := NewDocumentValueReader(doc) noerr(t, err) dr, err := dvr.ReadDocument() diff --git a/bson/document_value_writer.go b/bson/bsoncodec/document_value_writer.go similarity index 77% rename from bson/document_value_writer.go rename to bson/bsoncodec/document_value_writer.go index b2471a95ce..98f0b353fa 100644 --- a/bson/document_value_writer.go +++ b/bson/bsoncodec/document_value_writer.go @@ -1,10 +1,11 @@ -package bson +package bsoncodec import ( "errors" "fmt" "sync" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -25,19 +26,19 @@ type documentValueWriter struct { type dvwState struct { mode mode - d *Document + d *bson.Document key string } // NewDocumentValueWriter creates a ValueWriter that adds elements to d. -func NewDocumentValueWriter(d *Document) (ValueWriter, error) { +func NewDocumentValueWriter(d *bson.Document) (ValueWriter, error) { if d == nil { return nil, errors.New("cannot create ValueWriter with nil *Document") } return newDocumentValueWriter(d), nil } -func newDocumentValueWriter(d *Document) *documentValueWriter { +func newDocumentValueWriter(d *bson.Document) *documentValueWriter { // We start the document value writer in the mode of writing a document dvw := new(documentValueWriter) stack := make([]dvwState, 1, 5) @@ -47,7 +48,7 @@ func newDocumentValueWriter(d *Document) *documentValueWriter { return dvw } -func (dvw *documentValueWriter) reset(d *Document) { +func (dvw *documentValueWriter) reset(d *bson.Document) { dvw.stack = dvw.stack[:1] dvw.frame = 0 dvw.stack[0] = dvwState{mode: mTopLevel, d: d} @@ -87,7 +88,7 @@ func (dvw *documentValueWriter) pop() { } func (dvw *documentValueWriter) invalidTransitionError(destination mode) error { - te := transitionError{ + te := TransitionError{ current: dvw.stack[dvw.frame].mode, destination: destination, } @@ -113,9 +114,9 @@ func (dvw *documentValueWriter) WriteArray() (ArrayWriter, error) { return nil, dvw.invalidTransitionError(mArray) } - d := NewDocument() - arr := ArrayFromDocument(d) - dvw.stack[dvw.frame].d.Append(EC.Array(dvw.stack[dvw.frame].key, arr)) + d := bson.NewDocument() + arr := bson.ArrayFromDocument(d) + dvw.stack[dvw.frame].d.Append(bson.EC.Array(dvw.stack[dvw.frame].key, arr)) dvw.push(mArray) dvw.stack[dvw.frame].d = d return dvw, nil @@ -127,7 +128,7 @@ func (dvw *documentValueWriter) WriteBinary(b []byte) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Binary(dvw.stack[dvw.frame].key, b)) + dvw.stack[dvw.frame].d.Append(bson.EC.Binary(dvw.stack[dvw.frame].key, b)) return nil } @@ -137,7 +138,7 @@ func (dvw *documentValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) err } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.BinaryWithSubtype(dvw.stack[dvw.frame].key, b, btype)) + dvw.stack[dvw.frame].d.Append(bson.EC.BinaryWithSubtype(dvw.stack[dvw.frame].key, b, btype)) return nil } @@ -147,7 +148,7 @@ func (dvw *documentValueWriter) WriteBoolean(b bool) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Boolean(dvw.stack[dvw.frame].key, b)) + dvw.stack[dvw.frame].d.Append(bson.EC.Boolean(dvw.stack[dvw.frame].key, b)) return nil } @@ -158,8 +159,8 @@ func (dvw *documentValueWriter) WriteCodeWithScope(code string) (DocumentWriter, return nil, dvw.invalidTransitionError(mCodeWithScope) } - scope := NewDocument() - dvw.stack[dvw.frame].d.Append(EC.CodeWithScope(dvw.stack[dvw.frame].key, code, scope)) + scope := bson.NewDocument() + dvw.stack[dvw.frame].d.Append(bson.EC.CodeWithScope(dvw.stack[dvw.frame].key, code, scope)) dvw.push(mDocument) dvw.stack[dvw.frame].d = scope return dvw, nil @@ -171,7 +172,7 @@ func (dvw *documentValueWriter) WriteDBPointer(ns string, oid objectid.ObjectID) } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.DBPointer(dvw.stack[dvw.frame].key, ns, oid)) + dvw.stack[dvw.frame].d.Append(bson.EC.DBPointer(dvw.stack[dvw.frame].key, ns, oid)) return nil } @@ -181,7 +182,7 @@ func (dvw *documentValueWriter) WriteDateTime(dt int64) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.DateTime(dvw.stack[dvw.frame].key, dt)) + dvw.stack[dvw.frame].d.Append(bson.EC.DateTime(dvw.stack[dvw.frame].key, dt)) return nil } @@ -191,7 +192,7 @@ func (dvw *documentValueWriter) WriteDecimal128(d128 decimal.Decimal128) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Decimal128(dvw.stack[dvw.frame].key, d128)) + dvw.stack[dvw.frame].d.Append(bson.EC.Decimal128(dvw.stack[dvw.frame].key, d128)) return nil } @@ -201,7 +202,7 @@ func (dvw *documentValueWriter) WriteDouble(f64 float64) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Double(dvw.stack[dvw.frame].key, f64)) + dvw.stack[dvw.frame].d.Append(bson.EC.Double(dvw.stack[dvw.frame].key, f64)) return nil } @@ -211,7 +212,7 @@ func (dvw *documentValueWriter) WriteInt32(i32 int32) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Int32(dvw.stack[dvw.frame].key, i32)) + dvw.stack[dvw.frame].d.Append(bson.EC.Int32(dvw.stack[dvw.frame].key, i32)) return nil } @@ -221,7 +222,7 @@ func (dvw *documentValueWriter) WriteInt64(i64 int64) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Int64(dvw.stack[dvw.frame].key, i64)) + dvw.stack[dvw.frame].d.Append(bson.EC.Int64(dvw.stack[dvw.frame].key, i64)) return nil } @@ -231,7 +232,7 @@ func (dvw *documentValueWriter) WriteJavascript(code string) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.JavaScript(dvw.stack[dvw.frame].key, code)) + dvw.stack[dvw.frame].d.Append(bson.EC.JavaScript(dvw.stack[dvw.frame].key, code)) return nil } @@ -241,7 +242,7 @@ func (dvw *documentValueWriter) WriteMaxKey() error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.MaxKey(dvw.stack[dvw.frame].key)) + dvw.stack[dvw.frame].d.Append(bson.EC.MaxKey(dvw.stack[dvw.frame].key)) return nil } @@ -251,7 +252,7 @@ func (dvw *documentValueWriter) WriteMinKey() error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.MinKey(dvw.stack[dvw.frame].key)) + dvw.stack[dvw.frame].d.Append(bson.EC.MinKey(dvw.stack[dvw.frame].key)) return nil } @@ -261,7 +262,7 @@ func (dvw *documentValueWriter) WriteNull() error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Null(dvw.stack[dvw.frame].key)) + dvw.stack[dvw.frame].d.Append(bson.EC.Null(dvw.stack[dvw.frame].key)) return nil } @@ -271,7 +272,7 @@ func (dvw *documentValueWriter) WriteObjectID(oid objectid.ObjectID) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.ObjectID(dvw.stack[dvw.frame].key, oid)) + dvw.stack[dvw.frame].d.Append(bson.EC.ObjectID(dvw.stack[dvw.frame].key, oid)) return nil } @@ -281,7 +282,7 @@ func (dvw *documentValueWriter) WriteRegex(pattern string, options string) error } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Regex(dvw.stack[dvw.frame].key, pattern, options)) + dvw.stack[dvw.frame].d.Append(bson.EC.Regex(dvw.stack[dvw.frame].key, pattern, options)) return nil } @@ -291,7 +292,7 @@ func (dvw *documentValueWriter) WriteString(str string) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.String(dvw.stack[dvw.frame].key, str)) + dvw.stack[dvw.frame].d.Append(bson.EC.String(dvw.stack[dvw.frame].key, str)) return nil } @@ -306,8 +307,8 @@ func (dvw *documentValueWriter) WriteDocument() (DocumentWriter, error) { return nil, dvw.invalidTransitionError(mDocument) } - d := NewDocument() - dvw.stack[dvw.frame].d.Append(EC.SubDocument(dvw.stack[dvw.frame].key, d)) + d := bson.NewDocument() + dvw.stack[dvw.frame].d.Append(bson.EC.SubDocument(dvw.stack[dvw.frame].key, d)) dvw.push(mDocument) dvw.stack[dvw.frame].d = d return dvw, nil @@ -319,7 +320,7 @@ func (dvw *documentValueWriter) WriteSymbol(symbol string) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Symbol(dvw.stack[dvw.frame].key, symbol)) + dvw.stack[dvw.frame].d.Append(bson.EC.Symbol(dvw.stack[dvw.frame].key, symbol)) return nil } @@ -329,7 +330,7 @@ func (dvw *documentValueWriter) WriteTimestamp(t uint32, i uint32) error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Timestamp(dvw.stack[dvw.frame].key, t, i)) + dvw.stack[dvw.frame].d.Append(bson.EC.Timestamp(dvw.stack[dvw.frame].key, t, i)) return nil } @@ -339,7 +340,7 @@ func (dvw *documentValueWriter) WriteUndefined() error { } defer dvw.pop() - dvw.stack[dvw.frame].d.Append(EC.Undefined(dvw.stack[dvw.frame].key)) + dvw.stack[dvw.frame].d.Append(bson.EC.Undefined(dvw.stack[dvw.frame].key)) return nil } diff --git a/bson/document_value_writer_test.go b/bson/bsoncodec/document_value_writer_test.go similarity index 75% rename from bson/document_value_writer_test.go rename to bson/bsoncodec/document_value_writer_test.go index 7c5d56741b..5e50a2d4fc 100644 --- a/bson/document_value_writer_test.go +++ b/bson/bsoncodec/document_value_writer_test.go @@ -1,10 +1,11 @@ -package bson +package bsoncodec import ( "fmt" "reflect" "testing" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -12,7 +13,7 @@ import ( func noerr(t *testing.T, err error) { if err != nil { t.Helper() - t.Errorf("Unepexted error: %v", err) + t.Errorf("Unexpected error: (%T)%v", err, err) t.FailNow() } } @@ -23,127 +24,127 @@ func TestDocumentValueWriter(t *testing.T) { name string fn interface{} params []interface{} - want *Document + want *bson.Document }{ { "WriteBinary", (*documentValueWriter).WriteBinary, []interface{}{[]byte{0x01, 0x02, 0x03}}, - NewDocument(EC.Binary("foo", []byte{0x01, 0x02, 0x03})), + bson.NewDocument(bson.EC.Binary("foo", []byte{0x01, 0x02, 0x03})), }, { "WriteBinaryWithSubtype (not 0x02)", (*documentValueWriter).WriteBinaryWithSubtype, []interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)}, - NewDocument(EC.BinaryWithSubtype("foo", []byte{0x01, 0x02, 0x03}, 0xFF)), + bson.NewDocument(bson.EC.BinaryWithSubtype("foo", []byte{0x01, 0x02, 0x03}, 0xFF)), }, { "WriteBinaryWithSubtype (0x02)", (*documentValueWriter).WriteBinaryWithSubtype, []interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)}, - NewDocument(EC.BinaryWithSubtype("foo", []byte{0x01, 0x02, 0x03}, 0x02)), + bson.NewDocument(bson.EC.BinaryWithSubtype("foo", []byte{0x01, 0x02, 0x03}, 0x02)), }, { "WriteBoolean", (*documentValueWriter).WriteBoolean, []interface{}{true}, - NewDocument(EC.Boolean("foo", true)), + bson.NewDocument(bson.EC.Boolean("foo", true)), }, { "WriteDBPointer", (*documentValueWriter).WriteDBPointer, []interface{}{"bar", oid}, - NewDocument(EC.DBPointer("foo", "bar", oid)), + bson.NewDocument(bson.EC.DBPointer("foo", "bar", oid)), }, { "WriteDateTime", (*documentValueWriter).WriteDateTime, []interface{}{int64(12345678)}, - NewDocument(EC.DateTime("foo", 12345678)), + bson.NewDocument(bson.EC.DateTime("foo", 12345678)), }, { "WriteDecimal128", (*documentValueWriter).WriteDecimal128, []interface{}{decimal.NewDecimal128(10, 20)}, - NewDocument(EC.Decimal128("foo", decimal.NewDecimal128(10, 20))), + bson.NewDocument(bson.EC.Decimal128("foo", decimal.NewDecimal128(10, 20))), }, { "WriteDouble", (*documentValueWriter).WriteDouble, []interface{}{float64(3.14159)}, - NewDocument(EC.Double("foo", 3.14159)), + bson.NewDocument(bson.EC.Double("foo", 3.14159)), }, { "WriteInt32", (*documentValueWriter).WriteInt32, []interface{}{int32(123456)}, - NewDocument(EC.Int32("foo", 123456)), + bson.NewDocument(bson.EC.Int32("foo", 123456)), }, { "WriteInt64", (*documentValueWriter).WriteInt64, []interface{}{int64(1234567890)}, - NewDocument(EC.Int64("foo", 1234567890)), + bson.NewDocument(bson.EC.Int64("foo", 1234567890)), }, { "WriteJavascript", (*documentValueWriter).WriteJavascript, []interface{}{"var foo = 'bar';"}, - NewDocument(EC.JavaScript("foo", "var foo = 'bar';")), + bson.NewDocument(bson.EC.JavaScript("foo", "var foo = 'bar';")), }, { "WriteMaxKey", (*documentValueWriter).WriteMaxKey, []interface{}{}, - NewDocument(EC.MaxKey("foo")), + bson.NewDocument(bson.EC.MaxKey("foo")), }, { "WriteMinKey", (*documentValueWriter).WriteMinKey, []interface{}{}, - NewDocument(EC.MinKey("foo")), + bson.NewDocument(bson.EC.MinKey("foo")), }, { "WriteNull", (*documentValueWriter).WriteNull, []interface{}{}, - NewDocument(EC.Null("foo")), + bson.NewDocument(bson.EC.Null("foo")), }, { "WriteObjectID", (*documentValueWriter).WriteObjectID, []interface{}{oid}, - NewDocument(EC.ObjectID("foo", oid)), + bson.NewDocument(bson.EC.ObjectID("foo", oid)), }, { "WriteRegex", (*documentValueWriter).WriteRegex, []interface{}{"bar", "baz"}, - NewDocument(EC.Regex("foo", "bar", "baz")), + bson.NewDocument(bson.EC.Regex("foo", "bar", "baz")), }, { "WriteString", (*documentValueWriter).WriteString, []interface{}{"hello, world!"}, - NewDocument(EC.String("foo", "hello, world!")), + bson.NewDocument(bson.EC.String("foo", "hello, world!")), }, { "WriteSymbol", (*documentValueWriter).WriteSymbol, []interface{}{"symbollolz"}, - NewDocument(EC.Symbol("foo", "symbollolz")), + bson.NewDocument(bson.EC.Symbol("foo", "symbollolz")), }, { "WriteTimestamp", (*documentValueWriter).WriteTimestamp, []interface{}{uint32(10), uint32(20)}, - NewDocument(EC.Timestamp("foo", 10, 20)), + bson.NewDocument(bson.EC.Timestamp("foo", 10, 20)), }, { "WriteUndefined", (*documentValueWriter).WriteUndefined, []interface{}{}, - NewDocument(EC.Undefined("foo")), + bson.NewDocument(bson.EC.Undefined("foo")), }, } @@ -160,7 +161,7 @@ func TestDocumentValueWriter(t *testing.T) { t.Fatalf("fn must have one return value and it must be an error.") } params := make([]reflect.Value, 1, len(tc.params)+1) - got := NewDocument() + got := bson.NewDocument() dvw := newDocumentValueWriter(got) params[0] = reflect.ValueOf(dvw) for _, param := range tc.params { @@ -184,10 +185,10 @@ func TestDocumentValueWriter(t *testing.T) { } t.Run("incorrect transition", func(t *testing.T) { - dvw = newDocumentValueWriter(NewDocument()) + dvw = newDocumentValueWriter(bson.NewDocument()) results := fn.Call(params) got := results[0].Interface().(error) - want := transitionError{current: mTopLevel} + want := TransitionError{current: mTopLevel} if !compareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -196,43 +197,43 @@ func TestDocumentValueWriter(t *testing.T) { } t.Run("WriteArray", func(t *testing.T) { - dvw := newDocumentValueWriter(NewDocument()) + dvw := newDocumentValueWriter(bson.NewDocument()) dvw.push(mArray) - want := transitionError{current: mArray, destination: mArray, parent: mTopLevel} + want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel} _, got := dvw.WriteArray() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteCodeWithScope", func(t *testing.T) { - dvw := newDocumentValueWriter(NewDocument()) + dvw := newDocumentValueWriter(bson.NewDocument()) dvw.push(mArray) - want := transitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel} + want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel} _, got := dvw.WriteCodeWithScope("") if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteDocument", func(t *testing.T) { - dvw := newDocumentValueWriter(NewDocument()) + dvw := newDocumentValueWriter(bson.NewDocument()) dvw.push(mArray) - want := transitionError{current: mArray, destination: mDocument, parent: mTopLevel} + want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel} _, got := dvw.WriteDocument() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteDocumentElement", func(t *testing.T) { - dvw := newDocumentValueWriter(NewDocument()) + dvw := newDocumentValueWriter(bson.NewDocument()) dvw.push(mElement) - want := transitionError{current: mElement, destination: mElement, parent: mTopLevel} + want := TransitionError{current: mElement, destination: mElement, parent: mTopLevel} _, got := dvw.WriteDocumentElement("") if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteDocumentEnd", func(t *testing.T) { - dvw := newDocumentValueWriter(NewDocument()) + dvw := newDocumentValueWriter(bson.NewDocument()) dvw.push(mElement) want := fmt.Errorf("incorrect mode to end document: %s", mElement) got := dvw.WriteDocumentEnd() @@ -241,16 +242,16 @@ func TestDocumentValueWriter(t *testing.T) { } }) t.Run("WriteArrayElement", func(t *testing.T) { - dvw := newDocumentValueWriter(NewDocument()) + dvw := newDocumentValueWriter(bson.NewDocument()) dvw.push(mElement) - want := transitionError{current: mElement, destination: mValue, parent: mTopLevel} + want := TransitionError{current: mElement, destination: mValue, parent: mTopLevel} _, got := dvw.WriteArrayElement() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteArrayEnd", func(t *testing.T) { - dvw := newDocumentValueWriter(NewDocument()) + dvw := newDocumentValueWriter(bson.NewDocument()) dvw.push(mElement) want := fmt.Errorf("incorrect mode to end array: %s", mElement) got := dvw.WriteArrayEnd() @@ -265,33 +266,33 @@ func TestDocumentValueWriterPublicAPI(t *testing.T) { testCases := []struct { name string fn func(*testing.T, *documentValueWriter) - want *Document + want *bson.Document }{ { "simple document", dvwBasicDoc, - NewDocument(EC.Boolean("foo", true)), + bson.NewDocument(bson.EC.Boolean("foo", true)), }, { "nested document", dvwNestedDoc, - NewDocument(EC.SubDocumentFromElements("foo", EC.Boolean("bar", true)), EC.Boolean("baz", true)), + bson.NewDocument(bson.EC.SubDocumentFromElements("foo", bson.EC.Boolean("bar", true)), bson.EC.Boolean("baz", true)), }, { "simple array", dvwBasicArray, - NewDocument(EC.ArrayFromElements("foo", VC.Boolean(true))), + bson.NewDocument(bson.EC.ArrayFromElements("foo", bson.VC.Boolean(true))), }, { "code with scope", dvwCodeWithScopeNoNested, - NewDocument(EC.CodeWithScope("foo", "var hello = world;", NewDocument(EC.Boolean("bar", false)))), + bson.NewDocument(bson.EC.CodeWithScope("foo", "var hello = world;", bson.NewDocument(bson.EC.Boolean("bar", false)))), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got := NewDocument() + got := bson.NewDocument() dvw := newDocumentValueWriter(got) tc.fn(t, dvw) if !got.Equal(tc.want) { diff --git a/bson/encoder.go b/bson/bsoncodec/encoder.go similarity index 55% rename from bson/encoder.go rename to bson/bsoncodec/encoder.go index 90e8be17f4..89292c324b 100644 --- a/bson/encoder.go +++ b/bson/bsoncodec/encoder.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "errors" @@ -11,18 +11,18 @@ import ( // must have both Reset and SetRegistry called on them. var encPool = sync.Pool{ New: func() interface{} { - return new(Encoderv2) + return new(Encoder) }, } -// An Encoderv2 writes a serialization format to an output stream. -type Encoderv2 struct { +// An Encoder writes a serialization format to an output stream. +type Encoder struct { r *Registry vw ValueWriter } -// NewEncoderv2 returns a new encoder that uses Registry r to write to w. -func NewEncoderv2(r *Registry, vw ValueWriter) (*Encoderv2, error) { +// NewEncoder returns a new encoder that uses Registry r to write to w. +func NewEncoder(r *Registry, vw ValueWriter) (*Encoder, error) { if r == nil { return nil, errors.New("cannot create a new Encoder with a nil Registry") } @@ -30,7 +30,7 @@ func NewEncoderv2(r *Registry, vw ValueWriter) (*Encoderv2, error) { return nil, errors.New("cannot create a new Encoder with a nil ValueWriter") } - return &Encoderv2{ + return &Encoder{ r: r, vw: vw, }, nil @@ -40,24 +40,32 @@ func NewEncoderv2(r *Registry, vw ValueWriter) (*Encoderv2, error) { // // The documentation for Marshal contains details about the conversion of Go // values to BSON. -func (e *Encoderv2) Encode(val interface{}) error { - // TODO: Add checking to see if val is an allowable type - codec, err := e.r.Lookup(reflect.TypeOf(val)) +func (e *Encoder) Encode(val interface{}) error { + if marshaler, ok := val.(Marshaler); ok { + // TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse? + buf, err := marshaler.MarshalBSON() + if err != nil { + return err + } + return Copier{r: e.r}.CopyDocumentFromBytes(e.vw, buf) + } + + encoder, err := e.r.LookupEncoder(reflect.TypeOf(val)) if err != nil { return err } - return codec.EncodeValue(EncodeContext{Registry: e.r}, e.vw, val) + return encoder.EncodeValue(EncodeContext{Registry: e.r}, e.vw, val) } // Reset will reset the state of the encoder, using the same *Registry used in // the original construction but using vw. -func (e *Encoderv2) Reset(vw ValueWriter) error { +func (e *Encoder) Reset(vw ValueWriter) error { e.vw = vw return nil } // SetRegistry replaces the current registry of the encoder with r. -func (e *Encoderv2) SetRegistry(r *Registry) error { +func (e *Encoder) SetRegistry(r *Registry) error { e.r = r return nil } diff --git a/bson/bsoncodec/encoder_test.go b/bson/bsoncodec/encoder_test.go new file mode 100644 index 0000000000..49add6b596 --- /dev/null +++ b/bson/bsoncodec/encoder_test.go @@ -0,0 +1,96 @@ +package bsoncodec + +import ( + "bytes" + "errors" + "testing" + + "github.com/mongodb/mongo-go-driver/bson" +) + +func TestEncoderEncode(t *testing.T) { + for _, tc := range marshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + got := make(writer, 0, 1024) + vw := newValueWriter(&got) + reg := NewRegistryBuilder().Build() + enc, err := NewEncoder(reg, vw) + noerr(t, err) + err = enc.Encode(tc.val) + noerr(t, err) + + if !bytes.Equal(got, tc.want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(tc.want)) + t.Errorf("Bytes:\n%v\n%v", got, tc.want) + } + }) + } + + t.Run("Marshaler", func(t *testing.T) { + testCases := []struct { + name string + buf []byte + err error + wanterr error + vw ValueWriter + }{ + { + "error", + nil, + errors.New("Marshaler error"), + errors.New("Marshaler error"), + &llValueReaderWriter{}, + }, + { + "copy error", + []byte{0x05, 0x00, 0x00, 0x00, 0x00}, + nil, + errors.New("copy error"), + &llValueReaderWriter{err: errors.New("copy error"), errAfter: llvrwWriteDocument}, + }, + { + "success", + []byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00}, + nil, + nil, + nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + marshaler := testMarshaler{buf: tc.buf, err: tc.err} + + var vw ValueWriter + compareVW := false + if tc.vw != nil { + vw = tc.vw + } else { + compareVW = true + vw = newValueWriterFromSlice([]byte{}) + } + + enc, err := NewEncoder(defaultRegistry, vw) + noerr(t, err) + got := enc.Encode(marshaler) + want := tc.wanterr + if !compareErrors(got, want) { + t.Errorf("Did not receive expected error. got %v; want %v", got, want) + } + if compareVW { + buf := vw.(*valueWriter).buf + if !bytes.Equal(buf, tc.buf) { + t.Errorf("Copied bytes do not match. got %v; want %v", buf, tc.buf) + } + } + }) + } + }) +} + +type testMarshaler struct { + buf []byte + err error +} + +func (tm testMarshaler) MarshalBSON() ([]byte, error) { return tm.buf, tm.err } diff --git a/bson/extjson_reader.go b/bson/bsoncodec/extjson_reader.go similarity index 95% rename from bson/extjson_reader.go rename to bson/bsoncodec/extjson_reader.go index 5e5b99b686..aaa003e2c2 100644 --- a/bson/extjson_reader.go +++ b/bson/bsoncodec/extjson_reader.go @@ -1,8 +1,9 @@ -package bson +package bsoncodec import ( "io" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -12,7 +13,7 @@ type extJSONValueReader struct{} func newExtJSONValueReader(io.Reader) *extJSONValueReader { return nil } -func (ejvr *extJSONValueReader) Type() Type { +func (ejvr *extJSONValueReader) Type() bson.Type { panic("not implemented") } diff --git a/bson/extjson_writer.go b/bson/bsoncodec/extjson_writer.go similarity index 99% rename from bson/extjson_writer.go rename to bson/bsoncodec/extjson_writer.go index 9b9141108c..37f292c71d 100644 --- a/bson/extjson_writer.go +++ b/bson/bsoncodec/extjson_writer.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "errors" diff --git a/bson/internal_reader.go b/bson/bsoncodec/internal_reader.go similarity index 77% rename from bson/internal_reader.go rename to bson/bsoncodec/internal_reader.go index 413e298810..adb3421e33 100644 --- a/bson/internal_reader.go +++ b/bson/bsoncodec/internal_reader.go @@ -1,6 +1,7 @@ -package bson +package bsoncodec import ( + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -21,7 +22,7 @@ type DocumentReader interface { // is implemented by several types with different underlying representations of // BSON, such as a bson.Document, raw BSON bytes, or extended JSON. type ValueReader interface { - Type() Type + Type() bson.Type Skip() error ReadArray() (ArrayReader, error) @@ -47,21 +48,11 @@ type ValueReader interface { ReadUndefined() error } -// ElementAsBSON will retrieve the next value from vr and return it as a kind -// byte and the value as a slice of bytes. -func ElementAsBSON(vr ValueReader) (kind byte, data []byte, err error) { - return -} - -type reader struct { - b []byte - idx int64 -} - -func (r *reader) Read(p []byte) (int, error) { - return 0, nil -} - -func (r *reader) ReadAt(p []byte, off int64) (int, error) { - return 0, nil +// BytesReader is a generic interface used to read BSON bytes from a +// ValueReader. This imterface is meant to be a superset of ValueReader, so that +// types that implement ValueReader may also implement this interface. +// +// The bytes of the value will be appended to dst. +type BytesReader interface { + ReadValueBytes(dst []byte) (bson.Type, []byte, error) } diff --git a/bson/llvalue_reader_writer_test.go b/bson/bsoncodec/llvalue_reader_writer_test.go similarity index 97% rename from bson/llvalue_reader_writer_test.go rename to bson/bsoncodec/llvalue_reader_writer_test.go index d25003d022..03a0041663 100644 --- a/bson/llvalue_reader_writer_test.go +++ b/bson/bsoncodec/llvalue_reader_writer_test.go @@ -1,8 +1,9 @@ -package bson +package bsoncodec import ( "testing" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -66,12 +67,12 @@ type llValueReaderWriter struct { t *testing.T invoked llvrwInvoked readval interface{} - bsontype Type + bsontype bson.Type err error errAfter llvrwInvoked // error after this method is called } -func (llvrw *llValueReaderWriter) Type() Type { +func (llvrw *llValueReaderWriter) Type() bson.Type { return llvrw.bsontype } @@ -94,7 +95,7 @@ func (llvrw *llValueReaderWriter) ReadBinary() (b []byte, btype byte, err error) return nil, 0x00, llvrw.err } - bin, ok := llvrw.readval.(Binary) + bin, ok := llvrw.readval.(bson.Binary) if !ok { llvrw.t.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.readval) return nil, 0x00, nil @@ -141,7 +142,7 @@ func (llvrw *llValueReaderWriter) ReadDBPointer() (ns string, oid objectid.Objec return "", objectid.ObjectID{}, llvrw.err } - db, ok := llvrw.readval.(DBPointer) + db, ok := llvrw.readval.(bson.DBPointer) if !ok { llvrw.t.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.readval) return "", objectid.ObjectID{}, nil @@ -284,7 +285,7 @@ func (llvrw *llValueReaderWriter) ReadRegex() (pattern string, options string, e if llvrw.errAfter == llvrw.invoked { return "", "", llvrw.err } - rgx, ok := llvrw.readval.(Regex) + rgx, ok := llvrw.readval.(bson.Regex) if !ok { llvrw.t.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.readval) return "", "", nil @@ -312,7 +313,7 @@ func (llvrw *llValueReaderWriter) ReadSymbol() (symbol string, err error) { if llvrw.errAfter == llvrw.invoked { return "", llvrw.err } - symb, ok := llvrw.readval.(Symbol) + symb, ok := llvrw.readval.(bson.Symbol) if !ok { llvrw.t.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.readval) return "", nil @@ -326,7 +327,7 @@ func (llvrw *llValueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error if llvrw.errAfter == llvrw.invoked { return 0, 0, llvrw.err } - ts, ok := llvrw.readval.(Timestamp) + ts, ok := llvrw.readval.(bson.Timestamp) if !ok { llvrw.t.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.readval) return 0, 0, nil diff --git a/bson/bsoncodec/marshal.go b/bson/bsoncodec/marshal.go new file mode 100644 index 0000000000..a012cd744e --- /dev/null +++ b/bson/bsoncodec/marshal.go @@ -0,0 +1,92 @@ +package bsoncodec + +import ( + "reflect" + + "github.com/mongodb/mongo-go-driver/bson" +) + +// Marshaler is an interface implemented by types that can marshal themselves +// into a BSON document represented as bytes. The bytes returned must be a valid +// BSON document if the error is nil. +type Marshaler interface { + MarshalBSON() ([]byte, error) +} + +// ValueMarshaler is an interface implemented by types that can marshal +// themselves into a BSON value as bytes. The type must be the valid type for +// the bytes returned. The bytes and byte type together must be valid if the +// error is nil. +type ValueMarshaler interface { + MarshalBSONValue() (bson.Type, []byte, error) +} + +// Marshal returns the BSON encoding of val. +// +// Marshal will use the default registry created by NewRegistry to recursively +// marshal val into a []byte. Marshal will inspect struct tags and alter the +// marshaling process accordingly. +func Marshal(val interface{}) ([]byte, error) { + return MarshalWithRegistry(defaultRegistry, val) +} + +// MarshalAppend will append the BSON encoding of val to dst. If dst is not +// large enough to hold the BSON encoding of val, dst will be grown. +func MarshalAppend(dst []byte, val interface{}) ([]byte, error) { + return MarshalAppendWithRegistry(defaultRegistry, dst, val) +} + +// MarshalWithRegistry returns the BSON encoding of val using Registry r. +func MarshalWithRegistry(r *Registry, val interface{}) ([]byte, error) { + dst := make([]byte, 0, 256) // TODO: make the default cap a constant + return MarshalAppendWithRegistry(r, dst, val) +} + +// MarshalAppendWithRegistry will append the BSON encoding of val to dst using +// Registry r. If dst is not large enough to hold the BSON encoding of val, dst +// will be grown. +func MarshalAppendWithRegistry(r *Registry, dst []byte, val interface{}) ([]byte, error) { + // w := writer(dst) + // vw := newValueWriter(&w) + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + vw.reset(dst) + + enc := encPool.Get().(*Encoder) + defer encPool.Put(enc) + + err := enc.Reset(vw) + if err != nil { + return nil, err + } + err = enc.SetRegistry(r) + if err != nil { + return nil, err + } + + err = enc.Encode(val) + if err != nil { + return nil, err + } + + return vw.buf, nil +} + +func marshalElement(r *Registry, dst []byte, key string, val interface{}) ([]byte, error) { + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + vw.reset(dst) + _, err := vw.WriteDocumentElement(key) + if err != nil { + return dst, err + } + t := reflect.TypeOf(val) + enc, err := r.LookupEncoder(t) + if err != nil { + return dst, err + } + err = enc.EncodeValue(EncodeContext{Registry: r}, vw, val) + return dst, err +} diff --git a/bson/bsoncodec/marshal_test.go b/bson/bsoncodec/marshal_test.go new file mode 100644 index 0000000000..c5c77948c3 --- /dev/null +++ b/bson/bsoncodec/marshal_test.go @@ -0,0 +1,125 @@ +package bsoncodec + +import ( + "bytes" + "testing" + + "github.com/mongodb/mongo-go-driver/bson" + "github.com/stretchr/testify/require" +) + +func TestMarshalAppendWithRegistry(t *testing.T) { + for _, tc := range marshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + dst := make([]byte, 0, 1024) + var reg *Registry + if tc.reg != nil { + reg = tc.reg + } else { + reg = NewRegistryBuilder().Build() + } + got, err := MarshalAppendWithRegistry(reg, dst, tc.val) + noerr(t, err) + + if !bytes.Equal(got, tc.want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(tc.want)) + t.Errorf("Bytes:\n%v\n%v", got, tc.want) + } + }) + } +} + +func TestMarshalWithRegistry(t *testing.T) { + for _, tc := range marshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + var reg *Registry + if tc.reg != nil { + reg = tc.reg + } else { + reg = NewRegistryBuilder().Build() + } + got, err := MarshalWithRegistry(reg, tc.val) + noerr(t, err) + + if !bytes.Equal(got, tc.want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(tc.want)) + t.Errorf("Bytes:\n%v\n%v", got, tc.want) + } + }) + } +} + +func TestMarshalAppend(t *testing.T) { + for _, tc := range marshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + if tc.reg != nil { + t.Skip() // test requires custom registry + } + dst := make([]byte, 0, 1024) + got, err := MarshalAppend(dst, tc.val) + noerr(t, err) + + if !bytes.Equal(got, tc.want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(tc.want)) + t.Errorf("Bytes:\n%v\n%v", got, tc.want) + } + }) + } +} + +func TestMarshal_roundtripFromBytes(t *testing.T) { + before := []byte{ + // length + 0x1c, 0x0, 0x0, 0x0, + + // --- begin array --- + + // type - document + 0x3, + // key - "foo" + 0x66, 0x6f, 0x6f, 0x0, + + // length + 0x12, 0x0, 0x0, 0x0, + // type - string + 0x2, + // key - "bar" + 0x62, 0x61, 0x72, 0x0, + // value - string length + 0x4, 0x0, 0x0, 0x0, + // value - "baz" + 0x62, 0x61, 0x7a, 0x0, + + // null terminator + 0x0, + + // --- end array --- + + // null terminator + 0x0, + } + + doc := bson.NewDocument() + require.NoError(t, Unmarshal(before, doc)) + + after, err := Marshal(doc) + require.NoError(t, err) + + require.True(t, bytes.Equal(before, after)) +} + +func TestMarshal_roundtripFromDoc(t *testing.T) { + before := bson.NewDocument( + bson.EC.String("foo", "bar"), + bson.EC.Int32("baz", -27), + bson.EC.ArrayFromElements("bing", bson.VC.Null(), bson.VC.Regex("word", "i")), + ) + + b, err := Marshal(before) + require.NoError(t, err) + + after := bson.NewDocument() + require.NoError(t, Unmarshal(b, &after)) + + require.True(t, before.Equal(after)) +} diff --git a/bson/marshaling_cases_test.go b/bson/bsoncodec/marshaling_cases_test.go similarity index 62% rename from bson/marshaling_cases_test.go rename to bson/bsoncodec/marshaling_cases_test.go index 30a0ee963d..b7bf29e1fc 100644 --- a/bson/marshaling_cases_test.go +++ b/bson/bsoncodec/marshaling_cases_test.go @@ -1,4 +1,6 @@ -package bson +package bsoncodec + +import "github.com/mongodb/mongo-go-driver/bson" type marshalingTestCase struct { name string @@ -14,6 +16,6 @@ var marshalingTestCases = []marshalingTestCase{ struct { Foo bool }{Foo: true}, - bytesFromDoc(NewDocument(EC.Boolean("foo", true))), + bytesFromDoc(bson.NewDocument(bson.EC.Boolean("foo", true))), }, } diff --git a/bson/mode.go b/bson/bsoncodec/mode.go similarity index 81% rename from bson/mode.go rename to bson/bsoncodec/mode.go index 6ae9fb832c..0e967b7e27 100644 --- a/bson/mode.go +++ b/bson/bsoncodec/mode.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import "fmt" @@ -40,13 +40,15 @@ func (m mode) String() string { return str } -type transitionError struct { +// TransitionError is an error returned when an invalid progressing a +// ValueReader or ValueWriter state machine occurs. +type TransitionError struct { parent mode current mode destination mode } -func (te transitionError) Error() string { +func (te TransitionError) Error() string { if te.destination == mode(0) { return fmt.Sprintf("invalid state transition: cannot read/write value while in %s", te.current) } diff --git a/bson/bsoncodec/registry.go b/bson/bsoncodec/registry.go new file mode 100644 index 0000000000..5c2774a5ee --- /dev/null +++ b/bson/bsoncodec/registry.go @@ -0,0 +1,451 @@ +package bsoncodec + +import ( + "errors" + "reflect" + "sync" +) + +// ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder. +var ErrNilType = errors.New("cannot perform an encoder or decoder lookup on ") + +// ErrNoEncoder is returned when there wasn't an encoder available for a type. +type ErrNoEncoder struct { + Type reflect.Type +} + +func (ene ErrNoEncoder) Error() string { + return "no encoder found for " + ene.Type.String() +} + +// ErrNoDecoder is returned when there wasn't a decoder available for a type. +type ErrNoDecoder struct { + Type reflect.Type +} + +func (end ErrNoDecoder) Error() string { + return "no decoder found for " + end.Type.String() +} + +// ErrNotInterface is returned when the provided type is not an interface. +var ErrNotInterface = errors.New("The provided type is not an interface") + +var defaultRegistry *Registry + +func init() { + defaultRegistry = NewRegistryBuilder().Build() +} + +// A RegistryBuilder is used to build a Registry. This type is not goroutine +// safe. +type RegistryBuilder struct { + typeEncoders map[reflect.Type]ValueEncoder + interfaceEncoders []interfaceValueEncoder + kindEncoders map[reflect.Kind]ValueEncoder + + typeDecoders map[reflect.Type]ValueDecoder + interfaceDecoders []interfaceValueDecoder + kindDecoders map[reflect.Kind]ValueDecoder +} + +// A Registry is used to store and retrieve codecs for types and interfaces. This type is the main +// typed passed around and Encoders and Decoders are constructed from it. +type Registry struct { + typeEncoders map[reflect.Type]ValueEncoder + typeDecoders map[reflect.Type]ValueDecoder + + interfaceEncoders []interfaceValueEncoder + interfaceDecoders []interfaceValueDecoder + + kindEncoders map[reflect.Kind]ValueEncoder + kindDecoders map[reflect.Kind]ValueDecoder + + mu sync.RWMutex +} + +// NewEmptyRegistryBuilder creates a new RegistryBuilder with no default kind +// Codecs. +func NewEmptyRegistryBuilder() *RegistryBuilder { + return &RegistryBuilder{ + typeEncoders: make(map[reflect.Type]ValueEncoder), + typeDecoders: make(map[reflect.Type]ValueDecoder), + + interfaceEncoders: make([]interfaceValueEncoder, 0), + interfaceDecoders: make([]interfaceValueDecoder, 0), + + kindEncoders: make(map[reflect.Kind]ValueEncoder), + kindDecoders: make(map[reflect.Kind]ValueDecoder), + } +} + +// NewRegistryBuilder creates a new RegistryBuilder. +func NewRegistryBuilder() *RegistryBuilder { + var dve DefaultValueEncoders + var dvd DefaultValueDecoders + typeEncoders := map[reflect.Type]ValueEncoder{ + tDocument: ValueEncoderFunc(dve.DocumentEncodeValue), + tArray: ValueEncoderFunc(dve.ArrayEncodeValue), + tValue: ValueEncoderFunc(dve.ValueEncodeValue), + reflect.PtrTo(tByteSlice): ValueEncoderFunc(dve.ByteSliceEncodeValue), + reflect.PtrTo(tElementSlice): ValueEncoderFunc(dve.ElementSliceEncodeValue), + reflect.PtrTo(tTime): ValueEncoderFunc(dve.TimeEncodeValue), + reflect.PtrTo(tEmpty): ValueEncoderFunc(dve.EmptyInterfaceEncodeValue), + reflect.PtrTo(tBinary): ValueEncoderFunc(dve.BooleanEncodeValue), + reflect.PtrTo(tUndefined): ValueEncoderFunc(dve.UndefinedEncodeValue), + reflect.PtrTo(tOID): ValueEncoderFunc(dve.ObjectIDEncodeValue), + reflect.PtrTo(tDateTime): ValueEncoderFunc(dve.DateTimeEncodeValue), + reflect.PtrTo(tNull): ValueEncoderFunc(dve.NullEncodeValue), + reflect.PtrTo(tRegex): ValueEncoderFunc(dve.RegexEncodeValue), + reflect.PtrTo(tDBPointer): ValueEncoderFunc(dve.DBPointerEncodeValue), + reflect.PtrTo(tCodeWithScope): ValueEncoderFunc(dve.CodeWithScopeEncodeValue), + reflect.PtrTo(tTimestamp): ValueEncoderFunc(dve.TimestampEncodeValue), + reflect.PtrTo(tDecimal): ValueEncoderFunc(dve.Decimal128EncodeValue), + reflect.PtrTo(tMinKey): ValueEncoderFunc(dve.MinKeyEncodeValue), + reflect.PtrTo(tMaxKey): ValueEncoderFunc(dve.MaxKeyEncodeValue), + reflect.PtrTo(tJSONNumber): ValueEncoderFunc(dve.JSONNumberEncodeValue), + reflect.PtrTo(tURL): ValueEncoderFunc(dve.URLEncodeValue), + reflect.PtrTo(tReader): ValueEncoderFunc(dve.ReaderEncodeValue), + } + + typeDecoders := map[reflect.Type]ValueDecoder{ + tDocument: ValueDecoderFunc(dvd.DocumentDecodeValue), + tArray: ValueDecoderFunc(dvd.ArrayDecodeValue), + tValue: ValueDecoderFunc(dvd.ValueDecodeValue), + reflect.PtrTo(tByteSlice): ValueDecoderFunc(dvd.ByteSliceDecodeValue), + reflect.PtrTo(tElementSlice): ValueDecoderFunc(dvd.ElementSliceDecodeValue), + reflect.PtrTo(tTime): ValueDecoderFunc(dvd.TimeDecodeValue), + reflect.PtrTo(tEmpty): ValueDecoderFunc(dvd.EmptyInterfaceDecodeValue), + reflect.PtrTo(tBinary): ValueDecoderFunc(dvd.BooleanDecodeValue), + reflect.PtrTo(tUndefined): ValueDecoderFunc(dvd.UndefinedDecodeValue), + reflect.PtrTo(tOID): ValueDecoderFunc(dvd.ObjectIDDecodeValue), + reflect.PtrTo(tDateTime): ValueDecoderFunc(dvd.DateTimeDecodeValue), + reflect.PtrTo(tNull): ValueDecoderFunc(dvd.NullDecodeValue), + reflect.PtrTo(tRegex): ValueDecoderFunc(dvd.RegexDecodeValue), + reflect.PtrTo(tDBPointer): ValueDecoderFunc(dvd.DBPointerDecodeValue), + reflect.PtrTo(tCodeWithScope): ValueDecoderFunc(dvd.CodeWithScopeDecodeValue), + reflect.PtrTo(tTimestamp): ValueDecoderFunc(dvd.TimestampDecodeValue), + reflect.PtrTo(tDecimal): ValueDecoderFunc(dvd.Decimal128DecodeValue), + reflect.PtrTo(tMinKey): ValueDecoderFunc(dvd.MinKeyDecodeValue), + reflect.PtrTo(tMaxKey): ValueDecoderFunc(dvd.MaxKeyDecodeValue), + reflect.PtrTo(tJSONNumber): ValueDecoderFunc(dvd.JSONNumberDecodeValue), + reflect.PtrTo(tURL): ValueDecoderFunc(dvd.URLDecodeValue), + reflect.PtrTo(tReader): ValueDecoderFunc(dvd.ReaderDecodeValue), + } + + interfaceEncoders := []interfaceValueEncoder{ + {i: tValueMarshaler, ve: ValueEncoderFunc(dve.ValueMarshalerEncodeValue)}, + } + + interfaceDecoders := []interfaceValueDecoder{ + {i: tValueUnmarshaler, vd: ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue)}, + } + + kindEncoders := map[reflect.Kind]ValueEncoder{ + reflect.Bool: ValueEncoderFunc(dve.BooleanEncodeValue), + reflect.Int: ValueEncoderFunc(dve.IntEncodeValue), + reflect.Int8: ValueEncoderFunc(dve.IntEncodeValue), + reflect.Int16: ValueEncoderFunc(dve.IntEncodeValue), + reflect.Int32: ValueEncoderFunc(dve.IntEncodeValue), + reflect.Int64: ValueEncoderFunc(dve.IntEncodeValue), + reflect.Uint: ValueEncoderFunc(dve.UintEncodeValue), + reflect.Uint8: ValueEncoderFunc(dve.UintEncodeValue), + reflect.Uint16: ValueEncoderFunc(dve.UintEncodeValue), + reflect.Uint32: ValueEncoderFunc(dve.UintEncodeValue), + reflect.Uint64: ValueEncoderFunc(dve.UintEncodeValue), + reflect.Float32: ValueEncoderFunc(dve.FloatEncodeValue), + reflect.Float64: ValueEncoderFunc(dve.FloatEncodeValue), + reflect.Array: ValueEncoderFunc(dve.SliceEncodeValue), + reflect.Map: ValueEncoderFunc(dve.MapEncodeValue), + reflect.Slice: ValueEncoderFunc(dve.SliceEncodeValue), + reflect.String: ValueEncoderFunc(dve.StringEncodeValue), + reflect.Struct: &StructCodec{cache: make(map[reflect.Type]*structDescription), parser: DefaultStructTagParser}, + } + + kindDecoders := map[reflect.Kind]ValueDecoder{ + reflect.Bool: ValueDecoderFunc(dvd.BooleanDecodeValue), + reflect.Int: ValueDecoderFunc(dvd.IntDecodeValue), + reflect.Int8: ValueDecoderFunc(dvd.IntDecodeValue), + reflect.Int16: ValueDecoderFunc(dvd.IntDecodeValue), + reflect.Int32: ValueDecoderFunc(dvd.IntDecodeValue), + reflect.Int64: ValueDecoderFunc(dvd.IntDecodeValue), + reflect.Uint: ValueDecoderFunc(dvd.UintDecodeValue), + reflect.Uint8: ValueDecoderFunc(dvd.UintDecodeValue), + reflect.Uint16: ValueDecoderFunc(dvd.UintDecodeValue), + reflect.Uint32: ValueDecoderFunc(dvd.UintDecodeValue), + reflect.Uint64: ValueDecoderFunc(dvd.UintDecodeValue), + reflect.Float32: ValueDecoderFunc(dvd.FloatDecodeValue), + reflect.Float64: ValueDecoderFunc(dvd.FloatDecodeValue), + reflect.Array: ValueDecoderFunc(dvd.SliceDecodeValue), + reflect.Map: ValueDecoderFunc(dvd.MapDecodeValue), + reflect.Slice: ValueDecoderFunc(dvd.SliceDecodeValue), + reflect.String: ValueDecoderFunc(dvd.StringDecodeValue), + reflect.Struct: &StructCodec{cache: make(map[reflect.Type]*structDescription), parser: DefaultStructTagParser}, + } + + return &RegistryBuilder{ + typeEncoders: typeEncoders, + typeDecoders: typeDecoders, + interfaceEncoders: interfaceEncoders, + interfaceDecoders: interfaceDecoders, + kindEncoders: kindEncoders, + kindDecoders: kindDecoders, + } +} + +// RegisterCodec will register the provided ValueCodec for the provided type. +func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder { + rb.RegisterEncoder(t, codec) + rb.RegisterDecoder(t, codec) + return rb +} + +// RegisterEncoder will register the provided ValueEncoder to the provided type. +func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { + switch t.Kind() { + case reflect.Interface: + for idx, ir := range rb.interfaceEncoders { + if ir.i == t { + rb.interfaceEncoders[idx].ve = enc + return rb + } + } + + rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc}) + default: + if t.Kind() != reflect.Ptr { + t = reflect.PtrTo(t) + } + rb.typeEncoders[t] = enc + } + return rb +} + +// RegisterDecoder will register the provided ValueDecoder to the provided type. +func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { + switch t.Kind() { + case reflect.Interface: + for idx, ir := range rb.interfaceDecoders { + if ir.i == t { + rb.interfaceDecoders[idx].vd = dec + return rb + } + } + + rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec}) + default: + if t.Kind() != reflect.Ptr { + t = reflect.PtrTo(t) + } + rb.typeDecoders[t] = dec + } + return rb +} + +// RegisterDefaultEncoder will registr the provided ValueEncoder to the provided +// kind. +func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder { + rb.kindEncoders[kind] = enc + return rb +} + +// RegisterDefaultDecoder will register the provided ValueDecoder to the +// provided kind. +func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { + rb.kindDecoders[kind] = dec + return rb +} + +// Build creates a Registry from the current state of this RegistryBuilder. +func (rb *RegistryBuilder) Build() *Registry { + registry := new(Registry) + + registry.typeEncoders = make(map[reflect.Type]ValueEncoder) + for t, enc := range rb.typeEncoders { + registry.typeEncoders[t] = enc + } + + registry.typeDecoders = make(map[reflect.Type]ValueDecoder) + for t, dec := range rb.typeDecoders { + registry.typeDecoders[t] = dec + } + + registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.interfaceEncoders)) + copy(registry.interfaceEncoders, rb.interfaceEncoders) + + registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.interfaceDecoders)) + copy(registry.interfaceDecoders, rb.interfaceDecoders) + + registry.kindEncoders = make(map[reflect.Kind]ValueEncoder) + for kind, enc := range rb.kindEncoders { + registry.kindEncoders[kind] = enc + } + + registry.kindDecoders = make(map[reflect.Kind]ValueDecoder) + for kind, dec := range rb.kindDecoders { + registry.kindDecoders[kind] = dec + } + + return registry +} + +// LookupEncoder will inspect the registry for an encoder that satisfies the +// type provided. An encoder registered for a specific type will take +// precedence over an encoder registered for an interface the type satisfies, +// which takes precedence over an encoder for the reflect.Kind of the value. If +// no encoder can be found, an error is returned. +func (r *Registry) LookupEncoder(t reflect.Type) (ValueEncoder, error) { + if t == nil { + return nil, ErrNilType + } + encodererr := ErrNoEncoder{Type: t} + r.mu.RLock() + enc, found := r.lookupTypeEncoder(t) + r.mu.RUnlock() + if found { + if enc == nil { + return nil, ErrNoEncoder{Type: t} + } + return enc, nil + } + + enc, found = r.lookupInterfaceEncoder(t) + if found { + r.mu.Lock() + if t.Kind() != reflect.Ptr { + t = reflect.PtrTo(t) + } + r.typeEncoders[t] = enc + r.mu.Unlock() + return enc, nil + } + + if t.Kind() == reflect.Map && t.Key().Kind() != reflect.String { + r.mu.Lock() + r.typeEncoders[t] = nil + r.mu.Unlock() + return nil, encodererr + } + + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + enc, found = r.kindEncoders[t.Kind()] + if !found { + r.mu.Lock() + r.typeEncoders[t] = nil + r.mu.Unlock() + return nil, encodererr + } + + r.mu.Lock() + r.typeEncoders[t] = enc + r.mu.Unlock() + return enc, nil +} + +func (r *Registry) lookupTypeEncoder(t reflect.Type) (ValueEncoder, bool) { + if t.Kind() != reflect.Ptr { + t = reflect.PtrTo(t) + } + + enc, found := r.typeEncoders[t] + return enc, found +} + +func (r *Registry) lookupInterfaceEncoder(t reflect.Type) (ValueEncoder, bool) { + for _, ienc := range r.interfaceEncoders { + if !t.Implements(ienc.i) { + continue + } + + return ienc.ve, true + } + return nil, false +} + +// LookupDecoder will inspect the registry for a decoder that satisfies the +// type provided. A decoder registered for a specific type will take +// precedence over a decoder registered for an interface the type satisfies, +// which takes precedence over a decoder for the reflect.Kind of the value. If +// no decoder can be found, an error is returned. +func (r *Registry) LookupDecoder(t reflect.Type) (ValueDecoder, error) { + if t == nil { + return nil, ErrNilType + } + decodererr := ErrNoDecoder{Type: t} + r.mu.RLock() + dec, found := r.lookupTypeDecoder(t) + r.mu.RUnlock() + if found { + if dec == nil { + return nil, ErrNoDecoder{Type: t} + } + return dec, nil + } + + dec, found = r.lookupInterfaceDecoder(t) + if found { + r.mu.Lock() + if t.Kind() != reflect.Ptr { + t = reflect.PtrTo(t) + } + r.typeDecoders[t] = dec + r.mu.Unlock() + return dec, nil + } + + if t.Kind() == reflect.Map && t.Key().Kind() != reflect.String { + r.mu.Lock() + r.typeDecoders[t] = nil + r.mu.Unlock() + return nil, decodererr + } + + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + dec, found = r.kindDecoders[t.Kind()] + if !found { + r.mu.Lock() + r.typeDecoders[t] = nil + r.mu.Unlock() + return nil, decodererr + } + + r.mu.Lock() + r.typeDecoders[t] = dec + r.mu.Unlock() + return dec, nil +} + +func (r *Registry) lookupTypeDecoder(t reflect.Type) (ValueDecoder, bool) { + if t.Kind() != reflect.Ptr { + t = reflect.PtrTo(t) + } + + dec, found := r.typeDecoders[t] + return dec, found +} + +func (r *Registry) lookupInterfaceDecoder(t reflect.Type) (ValueDecoder, bool) { + for _, idec := range r.interfaceDecoders { + if !t.Implements(idec.i) { + continue + } + + return idec.vd, true + } + return nil, false +} + +type interfaceValueEncoder struct { + i reflect.Type + ve ValueEncoder +} + +type interfaceValueDecoder struct { + i reflect.Type + vd ValueDecoder +} diff --git a/bson/registry_test.go b/bson/bsoncodec/registry_test.go similarity index 57% rename from bson/registry_test.go rename to bson/bsoncodec/registry_test.go index 879bfb4c83..812ad92100 100644 --- a/bson/registry_test.go +++ b/bson/bsoncodec/registry_test.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "reflect" @@ -8,50 +8,48 @@ import ( ) func TestRegistry(t *testing.T) { - trInterface := NewRegistryBuilder() - trInterface.interfaces = append(trInterface.interfaces) t.Run("Register", func(t *testing.T) { fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) t.Run("interface", func(t *testing.T) { var t1f *testInterface1 var t2f *testInterface2 var t4f *testInterface4 - ips := []interfacePair{ - {i: reflect.TypeOf(t1f).Elem(), c: fc1}, - {i: reflect.TypeOf(t2f).Elem(), c: fc2}, - {i: reflect.TypeOf(t1f).Elem(), c: fc3}, - {i: reflect.TypeOf(t4f).Elem(), c: fc4}, + ips := []interfaceValueEncoder{ + {i: reflect.TypeOf(t1f).Elem(), ve: fc1}, + {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, + {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, + {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, } - want := []interfacePair{ - {i: reflect.TypeOf(t1f).Elem(), c: fc3}, - {i: reflect.TypeOf(t2f).Elem(), c: fc2}, - {i: reflect.TypeOf(t4f).Elem(), c: fc4}, + want := []interfaceValueEncoder{ + {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, + {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, + {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, } - rb := NewRegistryBuilder() + rb := NewEmptyRegistryBuilder() for _, ip := range ips { - rb.Register(ip.i, ip.c) + rb.RegisterEncoder(ip.i, ip.ve) } - got := rb.interfaces - if !cmp.Equal(got, want, cmp.AllowUnexported(interfacePair{}, fakeCodec{}), cmp.Comparer(typeComparer)) { + got := rb.interfaceEncoders + if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { t.Errorf("The registered interfaces are not correct. got %v; want %v", got, want) } }) t.Run("type", func(t *testing.T) { ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} rb := NewRegistryBuilder(). - Register(reflect.TypeOf(ft1), fc1). - Register(reflect.TypeOf(ft2), fc2). - Register(reflect.TypeOf(ft1), fc3). - Register(reflect.TypeOf(ft4), fc4) + RegisterEncoder(reflect.TypeOf(ft1), fc1). + RegisterEncoder(reflect.TypeOf(ft2), fc2). + RegisterEncoder(reflect.TypeOf(ft1), fc3). + RegisterEncoder(reflect.TypeOf(ft4), fc4) want := []struct { t reflect.Type - c Codec + c ValueEncoder }{ {reflect.PtrTo(reflect.TypeOf(ft1)), fc3}, {reflect.PtrTo(reflect.TypeOf(ft2)), fc2}, {reflect.PtrTo(reflect.TypeOf(ft4)), fc4}, } - got := rb.types + got := rb.typeEncoders for _, s := range want { wantT, wantC := s.t, s.c gotC, exists := got[wantT] @@ -66,19 +64,19 @@ func TestRegistry(t *testing.T) { t.Run("kind", func(t *testing.T) { k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map rb := NewRegistryBuilder(). - RegisterDefault(k1, fc1). - RegisterDefault(k2, fc2). - RegisterDefault(k1, fc3). - RegisterDefault(k4, fc4) + RegisterDefaultEncoder(k1, fc1). + RegisterDefaultEncoder(k2, fc2). + RegisterDefaultEncoder(k1, fc3). + RegisterDefaultEncoder(k4, fc4) want := []struct { k reflect.Kind - c Codec + c ValueEncoder }{ {k1, fc3}, {k2, fc2}, {k4, fc4}, } - got := rb.kinds + got := rb.kindEncoders for _, s := range want { wantK, wantC := s.k, s.c gotC, exists := got[wantK] @@ -95,56 +93,61 @@ func TestRegistry(t *testing.T) { codec := fakeCodec{num: 1} codec2 := fakeCodec{num: 2} rb := NewRegistryBuilder() - rb.RegisterDefault(reflect.Map, codec) - if rb.kinds[reflect.Map] != codec { - t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kinds[reflect.Map], codec) + rb.RegisterDefaultEncoder(reflect.Map, codec) + if rb.kindEncoders[reflect.Map] != codec { + t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec) } - rb.RegisterDefault(reflect.Map, codec2) - if rb.kinds[reflect.Map] != codec2 { - t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kinds[reflect.Map], codec2) + rb.RegisterDefaultEncoder(reflect.Map, codec2) + if rb.kindEncoders[reflect.Map] != codec2 { + t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec2) } }) t.Run("StructCodec", func(t *testing.T) { codec := fakeCodec{num: 1} codec2 := fakeCodec{num: 2} rb := NewRegistryBuilder() - rb.RegisterDefault(reflect.Struct, codec) - if rb.kinds[reflect.Struct] != codec { - t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kinds[reflect.Struct], codec) + rb.RegisterDefaultEncoder(reflect.Struct, codec) + if rb.kindEncoders[reflect.Struct] != codec { + t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec) } - rb.RegisterDefault(reflect.Struct, codec2) - if rb.kinds[reflect.Struct] != codec2 { - t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kinds[reflect.Struct], codec2) + rb.RegisterDefaultEncoder(reflect.Struct, codec2) + if rb.kindEncoders[reflect.Struct] != codec2 { + t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec2) } }) t.Run("SliceCodec", func(t *testing.T) { codec := fakeCodec{num: 1} codec2 := fakeCodec{num: 2} rb := NewRegistryBuilder() - rb.RegisterDefault(reflect.Slice, codec) - if rb.kinds[reflect.Slice] != codec { - t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kinds[reflect.Slice], codec) + rb.RegisterDefaultEncoder(reflect.Slice, codec) + if rb.kindEncoders[reflect.Slice] != codec { + t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec) } - rb.RegisterDefault(reflect.Slice, codec2) - if rb.kinds[reflect.Slice] != codec2 { - t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kinds[reflect.Slice], codec2) + rb.RegisterDefaultEncoder(reflect.Slice, codec2) + if rb.kindEncoders[reflect.Slice] != codec2 { + t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec2) } }) t.Run("ArrayCodec", func(t *testing.T) { codec := fakeCodec{num: 1} codec2 := fakeCodec{num: 2} rb := NewRegistryBuilder() - rb.RegisterDefault(reflect.Array, codec) - if rb.kinds[reflect.Array] != codec { - t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kinds[reflect.Array], codec) + rb.RegisterDefaultEncoder(reflect.Array, codec) + if rb.kindEncoders[reflect.Array] != codec { + t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec) } - rb.RegisterDefault(reflect.Array, codec2) - if rb.kinds[reflect.Array] != codec2 { - t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kinds[reflect.Array], codec2) + rb.RegisterDefaultEncoder(reflect.Array, codec2) + if rb.kindEncoders[reflect.Array] != codec2 { + t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec2) } }) }) t.Run("Lookup", func(t *testing.T) { + type Codec interface { + ValueEncoder + ValueDecoder + } + var arrinstance [12]int arr := reflect.TypeOf(arrinstance) slc := reflect.TypeOf(make([]int, 12)) @@ -157,13 +160,20 @@ func TestRegistry(t *testing.T) { fc1, fc2, fc4 := fakeCodec{num: 1}, fakeCodec{num: 2}, fakeCodec{num: 4} fsc, fslcc, fmc := new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) reg := NewRegistryBuilder(). - Register(ft1, fc1). - Register(ft2, fc2). - Register(ti2, fc4). - RegisterDefault(reflect.Struct, fsc). - RegisterDefault(reflect.Slice, fslcc). - RegisterDefault(reflect.Array, fslcc). - RegisterDefault(reflect.Map, fmc). + RegisterEncoder(ft1, fc1). + RegisterEncoder(ft2, fc2). + RegisterEncoder(ti2, fc4). + RegisterDefaultEncoder(reflect.Struct, fsc). + RegisterDefaultEncoder(reflect.Slice, fslcc). + RegisterDefaultEncoder(reflect.Array, fslcc). + RegisterDefaultEncoder(reflect.Map, fmc). + RegisterDecoder(ft1, fc1). + RegisterDecoder(ft2, fc2). + RegisterDecoder(ti2, fc4). + RegisterDefaultDecoder(reflect.Struct, fsc). + RegisterDefaultDecoder(reflect.Slice, fslcc). + RegisterDefaultDecoder(reflect.Array, fslcc). + RegisterDefaultDecoder(reflect.Map, fmc). Build() testCases := []struct { @@ -233,28 +243,47 @@ func TestRegistry(t *testing.T) { "map non-string key", reflect.TypeOf(map[int]int{}), nil, - ErrNoCodec{Type: reflect.TypeOf(map[int]int{})}, + ErrNoEncoder{Type: reflect.TypeOf(map[int]int{})}, false, }, { "No Codec Registered", ft3, nil, - ErrNoCodec{Type: ft3}, + ErrNoEncoder{Type: ft3}, false, }, } allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - gotcodec, goterr := reg.Lookup(tc.t) - if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) { - t.Errorf("Errors did not match. got %v; want %v", goterr, tc.wanterr) - } - if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported) { - t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec) - } + t.Run("Encoder", func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { + gotcodec, goterr := reg.LookupEncoder(tc.t) + if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) { + t.Errorf("Errors did not match. got %v; want %v", goterr, tc.wanterr) + } + if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported) { + t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec) + } + }) + }) + t.Run("Decoder", func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { + var wanterr error + if ene, ok := tc.wanterr.(ErrNoEncoder); ok { + wanterr = ErrNoDecoder{Type: ene.Type} + } else { + wanterr = tc.wanterr + } + gotcodec, goterr := reg.LookupDecoder(tc.t) + if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { + t.Errorf("Errors did not match. got %v; want %v", goterr, wanterr) + } + if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported) { + t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec) + } + }) }) } }) diff --git a/bson/struct_codec.go b/bson/bsoncodec/struct_codec.go similarity index 82% rename from bson/struct_codec.go rename to bson/bsoncodec/struct_codec.go index f3a1c34cd6..fe159faccf 100644 --- a/bson/struct_codec.go +++ b/bson/bsoncodec/struct_codec.go @@ -1,10 +1,12 @@ -package bson +package bsoncodec import ( "errors" "fmt" "reflect" "sync" + + "github.com/mongodb/mongo-go-driver/bson" ) var defaultStructCodec = &StructCodec{ @@ -19,7 +21,8 @@ type StructCodec struct { parser StructTagParser } -var _ Codec = &StructCodec{} +var _ ValueEncoder = &StructCodec{} +var _ ValueDecoder = &StructCodec{} // NewStructCodec returns a StructCodec that uses p for struct tag parsing. func NewStructCodec(p StructTagParser) (*StructCodec, error) { @@ -66,11 +69,15 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw ValueWriter, i interface{ rv = val.FieldByIndex(desc.inline) } - codec := desc.codec + if desc.encoder == nil { + return ErrNoEncoder{Type: rv.Type()} + } + + encoder := desc.encoder iszero := sc.isZero - if iz, ok := codec.(CodecZeroer); ok { - iszero = iz.IsZero + if iz, ok := encoder.(CodecZeroer); ok { + iszero = iz.IsTypeZero } if desc.omitEmpty && iszero(rv.Interface()) { @@ -83,7 +90,7 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw ValueWriter, i interface{ } ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize} - err = codec.EncodeValue(ectx, vw2, rv.Interface()) + err = encoder.EncodeValue(ectx, vw2, rv.Interface()) if err != nil { return err } @@ -96,7 +103,7 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw ValueWriter, i interface{ return exists } - return defaultMapCodec.encodeValue(r, dw, rv, collisionFn) + return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn) } return dw.WriteDocumentEnd() @@ -106,7 +113,16 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw ValueWriter, i interface{ func (sc *StructCodec) DecodeValue(r DecodeContext, vr ValueReader, i interface{}) error { val := reflect.ValueOf(i) if val.Kind() == reflect.Ptr { + if val.IsNil() { + val = reflect.New(val.Type().Elem()) + } val = val.Elem() + if val.Kind() == reflect.Ptr { + if val.IsNil() && val.CanSet() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } } if val.Kind() != reflect.Struct || !val.CanAddr() { return fmt.Errorf("%T can only processes addressable structs, but got %T (addressable: %t)", sc, i, val.CanAddr()) @@ -124,7 +140,7 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr ValueReader, i interface{ if inlineMap.IsNil() { inlineMap.Set(reflect.MakeMap(inlineMap.Type())) } - dFn, err = defaultMapCodec.decodeFn(r, inlineMap) + dFn, err = defaultValueDecoders.decodeFn(r, inlineMap) if err != nil { return err } @@ -180,17 +196,21 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr ValueReader, i interface{ field = field.Addr() dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate} - if ec, ok := fd.codec.(*elementCodec); ok { - err = ec.decodeValue(dctx, vr, name, field.Interface().(**Element)) - if err != nil { - return err + if fd.decoder == nil { + if field.Type() == reflect.PtrTo(tElement) { + err = defaultValueDecoders.elementDecodeValue(dctx, vr, name, field.Interface().(**bson.Element)) + if err != nil { + return err + } + + continue } - continue + + return ErrNoDecoder{Type: field.Elem().Type()} } - err = fd.codec.DecodeValue(dctx, vr, field.Interface()) + err = fd.decoder.DecodeValue(dctx, vr, field.Interface()) if err != nil { - fmt.Println(name) return err } } @@ -200,6 +220,10 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr ValueReader, i interface{ func (sc *StructCodec) isZero(i interface{}) bool { v := reflect.ValueOf(i) + if z, ok := v.Interface().(bson.Zeroer); ok { + return z.IsZero() + } + switch v.Kind() { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: return v.Len() == 0 @@ -213,11 +237,6 @@ func (sc *StructCodec) isZero(i interface{}) bool { return v.Float() == 0 case reflect.Interface, reflect.Ptr: return v.IsNil() - case reflect.Struct: - if z, ok := v.Interface().(Zeroer); ok { - return z.IsZero() - } - return false } return false @@ -236,7 +255,8 @@ type fieldDescription struct { minSize bool truncate bool inline []int - codec Codec + encoder ValueEncoder + decoder ValueDecoder } func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) { @@ -263,20 +283,25 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr continue } - var codec Codec + var encoder ValueEncoder + var decoder ValueDecoder var err error switch sf.Type { case tElement: // We handle this as a special case within the struct codec. - codec = defaultElementCodec + encoder = ValueEncoderFunc(defaultValueEncoders.elementEncodeValue) default: - codec, err = r.Lookup(sf.Type) + encoder, err = r.LookupEncoder(sf.Type) + if err != nil { + encoder = nil + } + decoder, err = r.LookupDecoder(sf.Type) if err != nil { - return nil, err + decoder = nil } } - description := fieldDescription{idx: i, codec: codec} + description := fieldDescription{idx: i, encoder: encoder, decoder: decoder} stags, err := sc.parser.ParseStructTags(sf) if err != nil { diff --git a/bson/bsoncodec/struct_codec_test.go b/bson/bsoncodec/struct_codec_test.go new file mode 100644 index 0000000000..489d2c7982 --- /dev/null +++ b/bson/bsoncodec/struct_codec_test.go @@ -0,0 +1,32 @@ +package bsoncodec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestZeoerInterfaceUsedByDecoder(t *testing.T) { + enc := &StructCodec{} + + // cases that are zero, because they are known types or pointers + var st *nonZeroer + assert.True(t, enc.isZero(st)) + assert.True(t, enc.isZero(0)) + assert.True(t, enc.isZero(false)) + + // cases that shouldn't be zero + st = &nonZeroer{value: false} + assert.False(t, enc.isZero(struct{ val bool }{val: true})) + assert.False(t, enc.isZero(struct{ val bool }{val: false})) + assert.False(t, enc.isZero(st)) + st.value = true + assert.False(t, enc.isZero(st)) + + // a test to see if the interface impacts the outcome + z := zeroTest{} + assert.False(t, enc.isZero(z)) + + z.reportZero = true + assert.True(t, enc.isZero(z)) +} diff --git a/bson/struct_tag_parser.go b/bson/bsoncodec/struct_tag_parser.go similarity index 99% rename from bson/struct_tag_parser.go rename to bson/bsoncodec/struct_tag_parser.go index a04f29b72c..e8f45550b3 100644 --- a/bson/struct_tag_parser.go +++ b/bson/bsoncodec/struct_tag_parser.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "reflect" diff --git a/bson/struct_tag_parser_test.go b/bson/bsoncodec/struct_tag_parser_test.go similarity index 99% rename from bson/struct_tag_parser_test.go rename to bson/bsoncodec/struct_tag_parser_test.go index 2853ac7509..6decbec920 100644 --- a/bson/struct_tag_parser_test.go +++ b/bson/bsoncodec/struct_tag_parser_test.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "reflect" diff --git a/bson/bsoncodec/types.go b/bson/bsoncodec/types.go new file mode 100644 index 0000000000..48ae6a83a1 --- /dev/null +++ b/bson/bsoncodec/types.go @@ -0,0 +1,57 @@ +package bsoncodec + +import ( + "encoding/json" + "net/url" + "reflect" + "time" + + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/objectid" +) + +var tDocument = reflect.TypeOf((*bson.Document)(nil)) +var tArray = reflect.TypeOf((*bson.Array)(nil)) +var tBinary = reflect.TypeOf(bson.Binary{}) +var tBool = reflect.TypeOf(false) +var tCodeWithScope = reflect.TypeOf(bson.CodeWithScope{}) +var tDBPointer = reflect.TypeOf(bson.DBPointer{}) +var tDecimal = reflect.TypeOf(decimal.Decimal128{}) +var tDateTime = reflect.TypeOf(bson.DateTime(0)) +var tUndefined = reflect.TypeOf(bson.Undefinedv2{}) +var tNull = reflect.TypeOf(bson.Nullv2{}) +var tValue = reflect.TypeOf((*bson.Value)(nil)) +var tFloat32 = reflect.TypeOf(float32(0)) +var tFloat64 = reflect.TypeOf(float64(0)) +var tInt = reflect.TypeOf(int(0)) +var tInt8 = reflect.TypeOf(int8(0)) +var tInt16 = reflect.TypeOf(int16(0)) +var tInt32 = reflect.TypeOf(int32(0)) +var tInt64 = reflect.TypeOf(int64(0)) +var tJavaScriptCode = reflect.TypeOf(bson.JavaScriptCode("")) +var tOID = reflect.TypeOf(objectid.ObjectID{}) +var tReader = reflect.TypeOf(bson.Reader(nil)) +var tRegex = reflect.TypeOf(bson.Regex{}) +var tString = reflect.TypeOf("") +var tSymbol = reflect.TypeOf(bson.Symbol("")) +var tTime = reflect.TypeOf(time.Time{}) +var tTimestamp = reflect.TypeOf(bson.Timestamp{}) +var tUint = reflect.TypeOf(uint(0)) +var tUint8 = reflect.TypeOf(uint8(0)) +var tUint16 = reflect.TypeOf(uint16(0)) +var tUint32 = reflect.TypeOf(uint32(0)) +var tUint64 = reflect.TypeOf(uint64(0)) +var tMinKey = reflect.TypeOf(bson.MinKeyv2{}) +var tMaxKey = reflect.TypeOf(bson.MaxKeyv2{}) + +var tEmpty = reflect.TypeOf((*interface{})(nil)).Elem() +var tElement = reflect.TypeOf((*bson.Element)(nil)) +var tByteSlice = reflect.TypeOf([]byte(nil)) +var tElementSlice = reflect.TypeOf(([]*bson.Element)(nil)) +var tByte = reflect.TypeOf(byte(0x00)) +var tURL = reflect.TypeOf(url.URL{}) +var tJSONNumber = reflect.TypeOf(json.Number("")) + +var tValueMarshaler = reflect.TypeOf((*ValueMarshaler)(nil)).Elem() +var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem() diff --git a/bson/bsoncodec/unmarshal.go b/bson/bsoncodec/unmarshal.go new file mode 100644 index 0000000000..2e9fd1c459 --- /dev/null +++ b/bson/bsoncodec/unmarshal.go @@ -0,0 +1,47 @@ +package bsoncodec + +import "github.com/mongodb/mongo-go-driver/bson" + +// Unmarshaler is an interface implemented by types that can unmarshal a BSON +// document representation of themselves. The BSON bytes can be assumed to be +// valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data +// after returning. +type Unmarshaler interface { + UnmarshalBSON([]byte) error +} + +// ValueUnmarshaler is an interface implemented by types that can unmarshal a +// BSON value representaiton of themselves. The BSON bytes and type can be +// assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it +// wishes to retain the data after returning. +type ValueUnmarshaler interface { + UnmarshalBSONValue(bson.Type, []byte) error +} + +// Unmarshal parses the BSON-encoded data and stores the result in the value +// pointed to by val. If val is nil or not a pointer, Unmarshal returns +// InvalidUnmarshalError. +func Unmarshal(data []byte, val interface{}) error { + return UnmarshalWithRegistry(defaultRegistry, data, val) +} + +// UnmarshalWithRegistry parses the BSON-encoded data using Registry r and +// stores the result in the value pointed to by val. If val is nil or not +// a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. +func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { + vr := newValueReader(data) + + dec := decPool.Get().(*Decoder) + defer decPool.Put(dec) + + err := dec.Reset(vr) + if err != nil { + return err + } + err = dec.SetRegistry(r) + if err != nil { + return err + } + + return dec.Decode(val) +} diff --git a/bson/bsoncodec/unmarshal_test.go b/bson/bsoncodec/unmarshal_test.go new file mode 100644 index 0000000000..4331cdacc8 --- /dev/null +++ b/bson/bsoncodec/unmarshal_test.go @@ -0,0 +1,43 @@ +package bsoncodec + +import ( + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestUnmarshal(t *testing.T) { + for _, tc := range unmarshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + if tc.reg != nil { + t.Skip() // test requires custom registry + } + got := reflect.New(tc.sType).Interface() + err := Unmarshal(tc.data, got) + noerr(t, err) + if !cmp.Equal(got, tc.want) { + t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want) + } + }) + } +} + +func TestUnmarshalWithRegistry(t *testing.T) { + for _, tc := range unmarshalingTestCases { + t.Run(tc.name, func(t *testing.T) { + var reg *Registry + if tc.reg != nil { + reg = tc.reg + } else { + reg = NewRegistryBuilder().Build() + } + got := reflect.New(tc.sType).Interface() + err := UnmarshalWithRegistry(reg, tc.data, got) + noerr(t, err) + if !cmp.Equal(got, tc.want) { + t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want) + } + }) + } +} diff --git a/bson/unmarshaling_cases_test.go b/bson/bsoncodec/unmarshaling_cases_test.go similarity index 68% rename from bson/unmarshaling_cases_test.go rename to bson/bsoncodec/unmarshaling_cases_test.go index 9bf7b6b3a5..7c22d07566 100644 --- a/bson/unmarshaling_cases_test.go +++ b/bson/bsoncodec/unmarshaling_cases_test.go @@ -1,10 +1,12 @@ -package bson +package bsoncodec import ( "bytes" "fmt" "io" "reflect" + + "github.com/mongodb/mongo-go-driver/bson" ) type unmarshalingTestCase struct { @@ -25,7 +27,7 @@ var unmarshalingTestCases = []unmarshalingTestCase{ &struct { Foo bool }{Foo: true}, - bytesFromDoc(NewDocument(EC.Boolean("foo", true))), + bytesFromDoc(bson.NewDocument(bson.EC.Boolean("foo", true))), }, { "nested document", @@ -44,7 +46,7 @@ var unmarshalingTestCases = []unmarshalingTestCase{ Bar bool }{Bar: true}, }, - bytesFromDoc(NewDocument(EC.SubDocumentFromElements("foo", EC.Boolean("bar", true)))), + bytesFromDoc(bson.NewDocument(bson.EC.SubDocumentFromElements("foo", bson.EC.Boolean("bar", true)))), }, { "simple array", @@ -57,11 +59,11 @@ var unmarshalingTestCases = []unmarshalingTestCase{ }{ Foo: []bool{true}, }, - bytesFromDoc(NewDocument(EC.ArrayFromElements("foo", VC.Boolean(true)))), + bytesFromDoc(bson.NewDocument(bson.EC.ArrayFromElements("foo", bson.VC.Boolean(true)))), }, } -func ioReaderFromDoc(doc *Document) io.Reader { +func ioReaderFromDoc(doc *bson.Document) io.Reader { b, err := doc.MarshalBSON() if err != nil { panic(fmt.Errorf("Couldn't marshal BSON document: %v", err)) diff --git a/bson/value_reader.go b/bson/bsoncodec/value_reader.go similarity index 73% rename from bson/value_reader.go rename to bson/bsoncodec/value_reader.go index 7f529eae5f..8da27bc4ed 100644 --- a/bson/value_reader.go +++ b/bson/bsoncodec/value_reader.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "bytes" @@ -7,11 +7,19 @@ import ( "fmt" "io" "math" + "sync" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) +var vrPool = sync.Pool{ + New: func() interface{} { + return new(valueReader) + }, +} + // ErrEOA is the error returned when the end of a BSON array has been reached. var ErrEOA = errors.New("end of array") @@ -20,7 +28,7 @@ var ErrEOD = errors.New("end of document") type vrState struct { mode mode - vType Type + vType bson.Type end int64 } @@ -50,6 +58,17 @@ func newValueReader(b []byte) *valueReader { } } +func (vr *valueReader) reset(b []byte) { + if vr.stack == nil { + vr.stack = make([]vrState, 1, 5) + } + vr.stack = vr.stack[:1] + vr.stack[0] = vrState{mode: mTopLevel} + vr.d = b + vr.offset = 0 + vr.frame = 0 +} + func (vr *valueReader) advanceFrame() { if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack length := len(vr.stack) @@ -97,14 +116,14 @@ func (vr *valueReader) pushArray() error { return nil } -func (vr *valueReader) pushElement(t Type) { +func (vr *valueReader) pushElement(t bson.Type) { vr.advanceFrame() vr.stack[vr.frame].mode = mElement vr.stack[vr.frame].vType = t } -func (vr *valueReader) pushValue(t Type) { +func (vr *valueReader) pushValue(t bson.Type) { vr.advanceFrame() vr.stack[vr.frame].mode = mValue @@ -135,7 +154,7 @@ func (vr *valueReader) pop() { } func (vr *valueReader) invalidTransitionErr(destination mode) error { - te := transitionError{ + te := TransitionError{ current: vr.stack[vr.frame].mode, destination: destination, } @@ -145,7 +164,7 @@ func (vr *valueReader) invalidTransitionErr(destination mode) error { return te } -func (vr *valueReader) typeError(t Type) error { +func (vr *valueReader) typeError(t bson.Type) error { return fmt.Errorf("positioned on %s, but attempted to read %s", vr.stack[vr.frame].vType, t) } @@ -153,7 +172,7 @@ func (vr *valueReader) invalidDocumentLengthError() error { return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", vr.stack[vr.frame].end, vr.offset) } -func (vr *valueReader) ensureElementValue(t Type, destination mode) error { +func (vr *valueReader) ensureElementValue(t bson.Type, destination mode) error { switch vr.stack[vr.frame].mode { case mElement, mValue: if vr.stack[vr.frame].vType != t { @@ -166,71 +185,93 @@ func (vr *valueReader) ensureElementValue(t Type, destination mode) error { return nil } -func (vr *valueReader) Type() Type { +func (vr *valueReader) Type() bson.Type { return vr.stack[vr.frame].vType } -func (vr *valueReader) Skip() error { +func (vr *valueReader) nextElementLength() (int32, error) { + var length int32 + var err error + switch vr.stack[vr.frame].vType { + case bson.TypeArray, bson.TypeEmbeddedDocument, bson.TypeCodeWithScope: + length, err = vr.peakLength() + case bson.TypeBinary: + length, err = vr.peakLength() + length += 4 + 1 // binary length + subtype byte + case bson.TypeBoolean: + length = 1 + case bson.TypeDBPointer: + length, err = vr.peakLength() + length += 4 + 12 // string length + ObjectID length + case bson.TypeDateTime, bson.TypeDouble, bson.TypeInt64, bson.TypeTimestamp: + length = 8 + case bson.TypeDecimal128: + length = 16 + case bson.TypeInt32: + length = 4 + case bson.TypeJavaScript, bson.TypeString, bson.TypeSymbol: + length, err = vr.peakLength() + length += 4 + case bson.TypeMaxKey, bson.TypeMinKey, bson.TypeNull, bson.TypeUndefined: + length = 0 + case bson.TypeObjectID: + length = 12 + case bson.TypeRegex: + regex := bytes.IndexByte(vr.d[vr.offset:], 0x00) + if regex < 0 { + err = io.EOF + break + } + pattern := bytes.IndexByte(vr.d[regex+1:], 0x00) + if pattern < 0 { + err = io.EOF + break + } + length = int32(int64(regex) + 1 + int64(pattern) + 1 - vr.offset) + default: + return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType) + } + + return length, err +} + +func (vr *valueReader) ReadValueBytes(dst []byte) (bson.Type, []byte, error) { switch vr.stack[vr.frame].mode { case mElement, mValue: default: - return vr.invalidTransitionErr(0) + return bson.Type(0), nil, vr.invalidTransitionErr(0) } - defer vr.pop() + length, err := vr.nextElementLength() + if err != nil { + return bson.Type(0), dst, err + } - switch vr.stack[vr.frame].vType { - case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope: - length, err := vr.readLength() - if err != nil { - return err - } - return vr.skipBytes(length - 4) // skipBytes skips that exact number of bytes, so we need to take off the lenght bytes - case TypeBinary: - length, err := vr.readLength() - if err != nil { - return err - } - // We need to skip length + 1 (for the subtype byte) - return vr.skipBytes(length + 1) - case TypeBoolean: - return vr.skipBytes(1) - case TypeDBPointer: - length, err := vr.readLength() - if err != nil { - return err - } - // skip length + 12 for the ObjectID - return vr.skipBytes(length + 12) - case TypeDateTime, TypeDouble, TypeInt64, TypeTimestamp: - return vr.skipBytes(8) - case TypeDecimal128: - return vr.skipBytes(16) - case TypeInt32: - return vr.skipBytes(4) - case TypeJavaScript, TypeString, TypeSymbol: - length, err := vr.readLength() - if err != nil { - return err - } - return vr.skipBytes(length) - case TypeMaxKey, TypeMinKey, TypeNull, TypeUndefined: - return nil - case TypeObjectID: - return vr.skipBytes(12) - case TypeRegex: - err := vr.skipCString() - if err != nil { - return err - } - return vr.skipCString() + dst, err = vr.appendBytes(dst, length) + t := vr.stack[vr.frame].vType + vr.pop() + return t, dst, err +} + +func (vr *valueReader) Skip() error { + switch vr.stack[vr.frame].mode { + case mElement, mValue: default: - return fmt.Errorf("attempted to skip unknown BSON type %v", vr.stack[vr.frame].vType) + return vr.invalidTransitionErr(0) + } + + length, err := vr.nextElementLength() + if err != nil { + return err } + + err = vr.skipBytes(length) + vr.pop() + return err } func (vr *valueReader) ReadArray() (ArrayReader, error) { - if err := vr.ensureElementValue(TypeArray, mArray); err != nil { + if err := vr.ensureElementValue(bson.TypeArray, mArray); err != nil { return nil, err } @@ -243,7 +284,7 @@ func (vr *valueReader) ReadArray() (ArrayReader, error) { } func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) { - if err := vr.ensureElementValue(TypeBinary, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeBinary, 0); err != nil { return nil, 0, err } @@ -266,7 +307,7 @@ func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) { } func (vr *valueReader) ReadBoolean() (bool, error) { - if err := vr.ensureElementValue(TypeBoolean, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeBoolean, 0); err != nil { return false, err } @@ -294,8 +335,8 @@ func (vr *valueReader) ReadDocument() (DocumentReader, error) { vr.stack[vr.frame].end = int64(size) + vr.offset - 4 return vr, nil case mElement, mValue: - if vr.stack[vr.frame].vType != TypeEmbeddedDocument { - return nil, vr.typeError(TypeEmbeddedDocument) + if vr.stack[vr.frame].vType != bson.TypeEmbeddedDocument { + return nil, vr.typeError(bson.TypeEmbeddedDocument) } default: return nil, vr.invalidTransitionErr(mDocument) @@ -310,7 +351,7 @@ func (vr *valueReader) ReadDocument() (DocumentReader, error) { } func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) { - if err := vr.ensureElementValue(TypeCodeWithScope, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeCodeWithScope, 0); err != nil { return "", nil, err } @@ -346,7 +387,7 @@ func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err } func (vr *valueReader) ReadDBPointer() (ns string, oid objectid.ObjectID, err error) { - if err := vr.ensureElementValue(TypeDBPointer, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeDBPointer, 0); err != nil { return "", oid, err } @@ -374,7 +415,7 @@ func (vr *valueReader) ReadDBPointer() (ns string, oid objectid.ObjectID, err er } func (vr *valueReader) ReadDateTime() (int64, error) { - if err := vr.ensureElementValue(TypeDateTime, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeDateTime, 0); err != nil { return 0, err } @@ -388,7 +429,7 @@ func (vr *valueReader) ReadDateTime() (int64, error) { } func (vr *valueReader) ReadDecimal128() (decimal.Decimal128, error) { - if err := vr.ensureElementValue(TypeDecimal128, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeDecimal128, 0); err != nil { return decimal.Decimal128{}, err } @@ -405,7 +446,7 @@ func (vr *valueReader) ReadDecimal128() (decimal.Decimal128, error) { } func (vr *valueReader) ReadDouble() (float64, error) { - if err := vr.ensureElementValue(TypeDouble, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeDouble, 0); err != nil { return 0, err } @@ -419,7 +460,7 @@ func (vr *valueReader) ReadDouble() (float64, error) { } func (vr *valueReader) ReadInt32() (int32, error) { - if err := vr.ensureElementValue(TypeInt32, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeInt32, 0); err != nil { return 0, err } @@ -428,7 +469,7 @@ func (vr *valueReader) ReadInt32() (int32, error) { } func (vr *valueReader) ReadInt64() (int64, error) { - if err := vr.ensureElementValue(TypeInt64, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeInt64, 0); err != nil { return 0, err } @@ -437,7 +478,7 @@ func (vr *valueReader) ReadInt64() (int64, error) { } func (vr *valueReader) ReadJavascript() (code string, err error) { - if err := vr.ensureElementValue(TypeJavaScript, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeJavaScript, 0); err != nil { return "", err } @@ -446,7 +487,7 @@ func (vr *valueReader) ReadJavascript() (code string, err error) { } func (vr *valueReader) ReadMaxKey() error { - if err := vr.ensureElementValue(TypeMaxKey, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeMaxKey, 0); err != nil { return err } @@ -455,7 +496,7 @@ func (vr *valueReader) ReadMaxKey() error { } func (vr *valueReader) ReadMinKey() error { - if err := vr.ensureElementValue(TypeMinKey, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeMinKey, 0); err != nil { return err } @@ -464,7 +505,7 @@ func (vr *valueReader) ReadMinKey() error { } func (vr *valueReader) ReadNull() error { - if err := vr.ensureElementValue(TypeNull, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeNull, 0); err != nil { return err } @@ -473,7 +514,7 @@ func (vr *valueReader) ReadNull() error { } func (vr *valueReader) ReadObjectID() (objectid.ObjectID, error) { - if err := vr.ensureElementValue(TypeObjectID, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeObjectID, 0); err != nil { return objectid.ObjectID{}, err } @@ -490,7 +531,7 @@ func (vr *valueReader) ReadObjectID() (objectid.ObjectID, error) { } func (vr *valueReader) ReadRegex() (string, string, error) { - if err := vr.ensureElementValue(TypeRegex, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeRegex, 0); err != nil { return "", "", err } @@ -509,7 +550,7 @@ func (vr *valueReader) ReadRegex() (string, string, error) { } func (vr *valueReader) ReadString() (string, error) { - if err := vr.ensureElementValue(TypeString, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeString, 0); err != nil { return "", err } @@ -518,7 +559,7 @@ func (vr *valueReader) ReadString() (string, error) { } func (vr *valueReader) ReadSymbol() (symbol string, err error) { - if err := vr.ensureElementValue(TypeSymbol, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeSymbol, 0); err != nil { return "", err } @@ -527,7 +568,7 @@ func (vr *valueReader) ReadSymbol() (symbol string, err error) { } func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) { - if err := vr.ensureElementValue(TypeTimestamp, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeTimestamp, 0); err != nil { return 0, 0, err } @@ -546,7 +587,7 @@ func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) { } func (vr *valueReader) ReadUndefined() error { - if err := vr.ensureElementValue(TypeUndefined, 0); err != nil { + if err := vr.ensureElementValue(bson.TypeUndefined, 0); err != nil { return err } @@ -580,7 +621,7 @@ func (vr *valueReader) ReadElement() (string, ValueReader, error) { return "", nil, err } - vr.pushElement(Type(t)) + vr.pushElement(bson.Type(t)) return name, vr, nil } @@ -610,7 +651,7 @@ func (vr *valueReader) ReadValue() (ValueReader, error) { return nil, err } - vr.pushValue(Type(t)) + vr.pushValue(bson.Type(t)) return vr, nil } @@ -624,6 +665,16 @@ func (vr *valueReader) readBytes(length int32) ([]byte, error) { return vr.d[start : start+int64(length)], nil } +func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) { + if vr.offset+int64(length) > int64(len(vr.d)) { + return nil, io.EOF + } + + start := vr.offset + vr.offset += int64(length) + return append(dst, vr.d[start:start+int64(length)]...), nil +} + func (vr *valueReader) skipBytes(length int32) error { if vr.offset+int64(length) > int64(len(vr.d)) { return io.EOF @@ -678,6 +729,15 @@ func (vr *valueReader) readString() (string, error) { return string(vr.d[start : start+int64(length)-1]), nil } +func (vr *valueReader) peakLength() (int32, error) { + if vr.offset+4 > int64(len(vr.d)) { + return 0, io.EOF + } + + idx := vr.offset + return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil +} + func (vr *valueReader) readLength() (int32, error) { return vr.readi32() } func (vr *valueReader) readi32() (int32, error) { diff --git a/bson/value_reader_test.go b/bson/bsoncodec/value_reader_test.go similarity index 73% rename from bson/value_reader_test.go rename to bson/bsoncodec/value_reader_test.go index 96e878b7ab..112db15f2b 100644 --- a/bson/value_reader_test.go +++ b/bson/bsoncodec/value_reader_test.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "bytes" @@ -8,6 +8,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/elements" "github.com/mongodb/mongo-go-driver/bson/objectid" @@ -22,7 +23,7 @@ func TestValueReader(t *testing.T) { btype byte b []byte err error - vType Type + vType bson.Type }{ { "incorrect type", @@ -30,8 +31,8 @@ func TestValueReader(t *testing.T) { 0, 0, nil, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeBinary), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeBinary), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -40,7 +41,7 @@ func TestValueReader(t *testing.T) { 0, nil, io.EOF, - TypeBinary, + bson.TypeBinary, }, { "no byte available", @@ -49,7 +50,7 @@ func TestValueReader(t *testing.T) { 0, nil, io.EOF, - TypeBinary, + bson.TypeBinary, }, { "not enough bytes for binary", @@ -58,7 +59,7 @@ func TestValueReader(t *testing.T) { 0, nil, io.EOF, - TypeBinary, + bson.TypeBinary, }, { "success", @@ -67,7 +68,7 @@ func TestValueReader(t *testing.T) { 0xEA, []byte{0x01, 0x02, 0x03}, nil, - TypeBinary, + bson.TypeBinary, }, } @@ -106,15 +107,15 @@ func TestValueReader(t *testing.T) { offset int64 boolean bool err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, false, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeBoolean), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeBoolean), + bson.TypeEmbeddedDocument, }, { "no byte available", @@ -122,7 +123,7 @@ func TestValueReader(t *testing.T) { 0, false, io.EOF, - TypeBoolean, + bson.TypeBoolean, }, { "invalid byte for boolean", @@ -130,7 +131,7 @@ func TestValueReader(t *testing.T) { 0, false, fmt.Errorf("invalid byte for boolean, %b", 0x03), - TypeBoolean, + bson.TypeBoolean, }, { "success", @@ -138,7 +139,7 @@ func TestValueReader(t *testing.T) { 0, true, nil, - TypeBoolean, + bson.TypeBoolean, }, } @@ -195,12 +196,12 @@ func TestValueReader(t *testing.T) { offset: 0, stack: []vrState{ {mode: mTopLevel}, - {mode: mElement, vType: TypeBoolean}, + {mode: mElement, vType: bson.TypeBoolean}, }, frame: 1, } - var wanterr = (&valueReader{stack: []vrState{{mode: mElement, vType: TypeBoolean}}}).typeError(TypeEmbeddedDocument) + var wanterr = (&valueReader{stack: []vrState{{mode: mElement, vType: bson.TypeBoolean}}}).typeError(bson.TypeEmbeddedDocument) _, err := vr.ReadDocument() if err == nil || err.Error() != wanterr.Error() { t.Errorf("Incorrect returned error. got %v; want %v", err, wanterr) @@ -213,7 +214,7 @@ func TestValueReader(t *testing.T) { t.Errorf("Incorrect returned error. got %v; want %v", err, wanterr) } - vr.stack[1].mode, vr.stack[1].vType = mElement, TypeEmbeddedDocument + vr.stack[1].mode, vr.stack[1].vType = mElement, bson.TypeEmbeddedDocument vr.d = []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00} vr.offset = 4 _, err = vr.ReadDocument() @@ -258,49 +259,49 @@ func TestValueReader(t *testing.T) { data []byte offset int64 err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeCodeWithScope), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeCodeWithScope), + bson.TypeEmbeddedDocument, }, { "total length not enough bytes", []byte{}, 0, io.EOF, - TypeCodeWithScope, + bson.TypeCodeWithScope, }, { "string length not enough bytes", codeWithScope[:4], 0, io.EOF, - TypeCodeWithScope, + bson.TypeCodeWithScope, }, { "not enough string bytes", codeWithScope[:8], 0, io.EOF, - TypeCodeWithScope, + bson.TypeCodeWithScope, }, { "document length not enough bytes", codeWithScope[:12], 0, io.EOF, - TypeCodeWithScope, + bson.TypeCodeWithScope, }, { "length mismatch", mismatchCodeWithScope, 0, fmt.Errorf("length of CodeWithScope does not match lengths of components; total: %d; components: %d", 17, 19), - TypeCodeWithScope, + bson.TypeCodeWithScope, }, } @@ -335,7 +336,7 @@ func TestValueReader(t *testing.T) { d: doc, stack: []vrState{ {mode: mTopLevel}, - {mode: mElement, vType: TypeCodeWithScope}, + {mode: mElement, vType: bson.TypeCodeWithScope}, }, frame: 1, } @@ -367,7 +368,7 @@ func TestValueReader(t *testing.T) { ns string oid objectid.ObjectID err error - vType Type + vType bson.Type }{ { "incorrect type", @@ -375,8 +376,8 @@ func TestValueReader(t *testing.T) { 0, "", objectid.ObjectID{}, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeDBPointer), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeDBPointer), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -385,7 +386,7 @@ func TestValueReader(t *testing.T) { "", objectid.ObjectID{}, io.EOF, - TypeDBPointer, + bson.TypeDBPointer, }, { "not enough bytes for namespace", @@ -394,7 +395,7 @@ func TestValueReader(t *testing.T) { "", objectid.ObjectID{}, io.EOF, - TypeDBPointer, + bson.TypeDBPointer, }, { "not enough bytes for objectID", @@ -403,7 +404,7 @@ func TestValueReader(t *testing.T) { "", objectid.ObjectID{}, io.EOF, - TypeDBPointer, + bson.TypeDBPointer, }, { "success", @@ -415,7 +416,7 @@ func TestValueReader(t *testing.T) { "foo", objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, nil, - TypeDBPointer, + bson.TypeDBPointer, }, } @@ -454,15 +455,15 @@ func TestValueReader(t *testing.T) { offset int64 dt int64 err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, 0, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeDateTime), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeDateTime), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -470,7 +471,7 @@ func TestValueReader(t *testing.T) { 0, 0, io.EOF, - TypeDateTime, + bson.TypeDateTime, }, { "success", @@ -478,7 +479,7 @@ func TestValueReader(t *testing.T) { 0, 255, nil, - TypeDateTime, + bson.TypeDateTime, }, } @@ -514,15 +515,15 @@ func TestValueReader(t *testing.T) { offset int64 dc128 decimal.Decimal128 err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, decimal.Decimal128{}, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeDecimal128), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeDecimal128), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -530,7 +531,7 @@ func TestValueReader(t *testing.T) { 0, decimal.Decimal128{}, io.EOF, - TypeDecimal128, + bson.TypeDecimal128, }, { "success", @@ -541,7 +542,7 @@ func TestValueReader(t *testing.T) { 0, decimal.NewDecimal128(65280, 255), nil, - TypeDecimal128, + bson.TypeDecimal128, }, } @@ -582,15 +583,15 @@ func TestValueReader(t *testing.T) { offset int64 double float64 err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, 0, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeDouble), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeDouble), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -598,7 +599,7 @@ func TestValueReader(t *testing.T) { 0, 0, io.EOF, - TypeDouble, + bson.TypeDouble, }, { "success", @@ -606,7 +607,7 @@ func TestValueReader(t *testing.T) { 0, math.Float64frombits(255), nil, - TypeDouble, + bson.TypeDouble, }, } @@ -642,15 +643,15 @@ func TestValueReader(t *testing.T) { offset int64 i32 int32 err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, 0, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeInt32), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeInt32), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -658,7 +659,7 @@ func TestValueReader(t *testing.T) { 0, 0, io.EOF, - TypeInt32, + bson.TypeInt32, }, { "success", @@ -666,7 +667,7 @@ func TestValueReader(t *testing.T) { 0, 255, nil, - TypeInt32, + bson.TypeInt32, }, } @@ -702,15 +703,15 @@ func TestValueReader(t *testing.T) { offset int64 i64 int64 err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, 0, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeInt64), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeInt64), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -718,7 +719,7 @@ func TestValueReader(t *testing.T) { 0, 0, io.EOF, - TypeInt64, + bson.TypeInt64, }, { "success", @@ -726,7 +727,7 @@ func TestValueReader(t *testing.T) { 0, 255, nil, - TypeInt64, + bson.TypeInt64, }, } @@ -763,7 +764,7 @@ func TestValueReader(t *testing.T) { fn func(*valueReader) (string, error) css string // code, string, symbol :P err error - vType Type + vType bson.Type }{ { "ReadJavascript/incorrect type", @@ -771,8 +772,8 @@ func TestValueReader(t *testing.T) { 0, (*valueReader).ReadJavascript, "", - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeJavaScript), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeJavaScript), + bson.TypeEmbeddedDocument, }, { "ReadString/incorrect type", @@ -780,8 +781,8 @@ func TestValueReader(t *testing.T) { 0, (*valueReader).ReadString, "", - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeString), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeString), + bson.TypeEmbeddedDocument, }, { "ReadSymbol/incorrect type", @@ -789,8 +790,8 @@ func TestValueReader(t *testing.T) { 0, (*valueReader).ReadSymbol, "", - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeSymbol), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeSymbol), + bson.TypeEmbeddedDocument, }, { "ReadJavascript/length too short", @@ -799,7 +800,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadJavascript, "", io.EOF, - TypeJavaScript, + bson.TypeJavaScript, }, { "ReadString/length too short", @@ -808,7 +809,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadString, "", io.EOF, - TypeString, + bson.TypeString, }, { "ReadSymbol/length too short", @@ -817,7 +818,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadSymbol, "", io.EOF, - TypeSymbol, + bson.TypeSymbol, }, { "ReadJavascript/incorrect end byte", @@ -826,7 +827,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadJavascript, "", fmt.Errorf("string does not end with null byte, but with %v", 0x05), - TypeJavaScript, + bson.TypeJavaScript, }, { "ReadString/incorrect end byte", @@ -835,7 +836,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadString, "", fmt.Errorf("string does not end with null byte, but with %v", 0x05), - TypeString, + bson.TypeString, }, { "ReadSymbol/incorrect end byte", @@ -844,7 +845,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadSymbol, "", fmt.Errorf("string does not end with null byte, but with %v", 0x05), - TypeSymbol, + bson.TypeSymbol, }, { "ReadJavascript/success", @@ -853,7 +854,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadJavascript, "foo", nil, - TypeJavaScript, + bson.TypeJavaScript, }, { "ReadString/success", @@ -862,7 +863,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadString, "foo", nil, - TypeString, + bson.TypeString, }, { "ReadSymbol/success", @@ -871,7 +872,7 @@ func TestValueReader(t *testing.T) { (*valueReader).ReadSymbol, "foo", nil, - TypeSymbol, + bson.TypeSymbol, }, } @@ -905,55 +906,55 @@ func TestValueReader(t *testing.T) { name string fn func(*valueReader) error err error - vType Type + vType bson.Type }{ { "ReadMaxKey/incorrect type", (*valueReader).ReadMaxKey, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeMaxKey), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeMaxKey), + bson.TypeEmbeddedDocument, }, { "ReadMaxKey/success", (*valueReader).ReadMaxKey, nil, - TypeMaxKey, + bson.TypeMaxKey, }, { "ReadMinKey/incorrect type", (*valueReader).ReadMinKey, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeMinKey), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeMinKey), + bson.TypeEmbeddedDocument, }, { "ReadMinKey/success", (*valueReader).ReadMinKey, nil, - TypeMinKey, + bson.TypeMinKey, }, { "ReadNull/incorrect type", (*valueReader).ReadNull, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeNull), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeNull), + bson.TypeEmbeddedDocument, }, { "ReadNull/success", (*valueReader).ReadNull, nil, - TypeNull, + bson.TypeNull, }, { "ReadUndefined/incorrect type", (*valueReader).ReadUndefined, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeUndefined), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeUndefined), + bson.TypeEmbeddedDocument, }, { "ReadUndefined/success", (*valueReader).ReadUndefined, nil, - TypeUndefined, + bson.TypeUndefined, }, } @@ -984,15 +985,15 @@ func TestValueReader(t *testing.T) { offset int64 oid objectid.ObjectID err error - vType Type + vType bson.Type }{ { "incorrect type", []byte{}, 0, objectid.ObjectID{}, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeObjectID), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeObjectID), + bson.TypeEmbeddedDocument, }, { "not enough bytes for objectID", @@ -1000,7 +1001,7 @@ func TestValueReader(t *testing.T) { 0, objectid.ObjectID{}, io.EOF, - TypeObjectID, + bson.TypeObjectID, }, { "success", @@ -1008,7 +1009,7 @@ func TestValueReader(t *testing.T) { 0, objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, nil, - TypeObjectID, + bson.TypeObjectID, }, } @@ -1045,7 +1046,7 @@ func TestValueReader(t *testing.T) { pattern string options string err error - vType Type + vType bson.Type }{ { "incorrect type", @@ -1053,8 +1054,8 @@ func TestValueReader(t *testing.T) { 0, "", "", - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeRegex), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeRegex), + bson.TypeEmbeddedDocument, }, { "length too short", @@ -1063,7 +1064,7 @@ func TestValueReader(t *testing.T) { "", "", io.EOF, - TypeRegex, + bson.TypeRegex, }, { "not enough bytes for options", @@ -1072,7 +1073,7 @@ func TestValueReader(t *testing.T) { "", "", io.EOF, - TypeRegex, + bson.TypeRegex, }, { "success", @@ -1081,7 +1082,7 @@ func TestValueReader(t *testing.T) { "foo", "bar", nil, - TypeRegex, + bson.TypeRegex, }, } @@ -1121,7 +1122,7 @@ func TestValueReader(t *testing.T) { ts uint32 incr uint32 err error - vType Type + vType bson.Type }{ { "incorrect type", @@ -1129,8 +1130,8 @@ func TestValueReader(t *testing.T) { 0, 0, 0, - (&valueReader{stack: []vrState{{vType: TypeEmbeddedDocument}}, frame: 0}).typeError(TypeTimestamp), - TypeEmbeddedDocument, + (&valueReader{stack: []vrState{{vType: bson.TypeEmbeddedDocument}}, frame: 0}).typeError(bson.TypeTimestamp), + bson.TypeEmbeddedDocument, }, { "not enough bytes for increment", @@ -1139,7 +1140,7 @@ func TestValueReader(t *testing.T) { 0, 0, io.EOF, - TypeTimestamp, + bson.TypeTimestamp, }, { "not enough bytes for timestamp", @@ -1148,7 +1149,7 @@ func TestValueReader(t *testing.T) { 0, 0, io.EOF, - TypeTimestamp, + bson.TypeTimestamp, }, { "success", @@ -1157,7 +1158,7 @@ func TestValueReader(t *testing.T) { 256, 255, nil, - TypeTimestamp, + bson.TypeTimestamp, }, } @@ -1190,8 +1191,8 @@ func TestValueReader(t *testing.T) { } }) - t.Run("Skip", func(t *testing.T) { - docb, err := NewDocument(EC.Null("foobar")).MarshalBSON() + t.Run("ReadBytes & Skip", func(t *testing.T) { + docb, err := bson.NewDocument(bson.EC.Null("foobar")).MarshalBSON() noerr(t, err) cwsbytes := make([]byte, 41) _, err = elements.CodeWithScope.Encode(0, cwsbytes, "var hellow = world;", docb) @@ -1199,215 +1200,248 @@ func TestValueReader(t *testing.T) { strbytes := []byte{0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00} testCases := []struct { name string - t Type + t bson.Type data []byte err error offset int64 }{ { "Array/invalid length", - TypeArray, + bson.TypeArray, []byte{0x01, 0x02, 0x03}, io.EOF, 0, }, { "Array/not enough bytes", - TypeArray, + bson.TypeArray, []byte{0x0F, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, - io.EOF, 4, + io.EOF, 0, }, { "Array/success", - TypeArray, + bson.TypeArray, []byte{0x08, 0x00, 0x00, 0x00, 0x0A, '1', 0x00, 0x00}, nil, 8, }, { "EmbeddedDocument/invalid length", - TypeEmbeddedDocument, + bson.TypeEmbeddedDocument, []byte{0x01, 0x02, 0x03}, io.EOF, 0, }, { "EmbeddedDocument/not enough bytes", - TypeEmbeddedDocument, + bson.TypeEmbeddedDocument, []byte{0x0F, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, - io.EOF, 4, + io.EOF, 0, }, { "EmbeddedDocument/success", - TypeEmbeddedDocument, + bson.TypeEmbeddedDocument, []byte{0x08, 0x00, 0x00, 0x00, 0x0A, 'A', 0x00, 0x00}, nil, 8, }, { "CodeWithScope/invalid length", - TypeCodeWithScope, + bson.TypeCodeWithScope, []byte{0x01, 0x02, 0x03}, io.EOF, 0, }, { "CodeWithScope/not enough bytes", - TypeCodeWithScope, + bson.TypeCodeWithScope, []byte{0x0F, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, - io.EOF, 4, + io.EOF, 0, }, { "CodeWithScope/success", - TypeCodeWithScope, + bson.TypeCodeWithScope, cwsbytes, nil, 41, }, { "Binary/invalid length", - TypeBinary, + bson.TypeBinary, []byte{0x01, 0x02, 0x03}, io.EOF, 0, }, { "Binary/not enough bytes", - TypeBinary, + bson.TypeBinary, []byte{0x0F, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, - io.EOF, 4, + io.EOF, 0, }, { "Binary/success", - TypeBinary, + bson.TypeBinary, []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, nil, 8, }, { "Boolean/invalid length", - TypeBoolean, + bson.TypeBoolean, []byte{}, io.EOF, 0, }, { "Boolean/success", - TypeBoolean, + bson.TypeBoolean, []byte{0x01}, nil, 1, }, { "DBPointer/invalid length", - TypeDBPointer, + bson.TypeDBPointer, []byte{0x01, 0x02, 0x03}, io.EOF, 0, }, { "DBPointer/not enough bytes", - TypeDBPointer, + bson.TypeDBPointer, []byte{0x0F, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, - io.EOF, 4, + io.EOF, 0, }, { "DBPointer/success", - TypeDBPointer, + bson.TypeDBPointer, []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, nil, 17, }, - {"DBPointer/not enough bytes", TypeDateTime, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, - {"DBPointer/success", TypeDateTime, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, - {"Double/not enough bytes", TypeDouble, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, - {"Double/success", TypeDouble, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, - {"Int64/not enough bytes", TypeInt64, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, - {"Int64/success", TypeInt64, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, - {"Timestamp/not enough bytes", TypeTimestamp, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, - {"Timestamp/success", TypeTimestamp, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, + {"DBPointer/not enough bytes", bson.TypeDateTime, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, + {"DBPointer/success", bson.TypeDateTime, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, + {"Double/not enough bytes", bson.TypeDouble, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, + {"Double/success", bson.TypeDouble, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, + {"Int64/not enough bytes", bson.TypeInt64, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, + {"Int64/success", bson.TypeInt64, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, + {"Timestamp/not enough bytes", bson.TypeTimestamp, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0}, + {"Timestamp/success", bson.TypeTimestamp, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, nil, 8}, { "Decimal128/not enough bytes", - TypeDecimal128, + bson.TypeDecimal128, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0, }, { "Decimal128/success", - TypeDecimal128, + bson.TypeDecimal128, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}, nil, 16, }, - {"Int32/not enough bytes", TypeInt32, []byte{0x01, 0x02}, io.EOF, 0}, - {"Int32/success", TypeInt32, []byte{0x01, 0x02, 0x03, 0x04}, nil, 4}, - {"Javascript/invalid length", TypeJavaScript, strbytes[:2], io.EOF, 0}, - {"Javascript/not enough bytes", TypeJavaScript, strbytes[:5], io.EOF, 4}, - {"Javascript/success", TypeJavaScript, strbytes, nil, 8}, - {"String/invalid length", TypeString, strbytes[:2], io.EOF, 0}, - {"String/not enough bytes", TypeString, strbytes[:5], io.EOF, 4}, - {"String/success", TypeString, strbytes, nil, 8}, - {"Symbol/invalid length", TypeSymbol, strbytes[:2], io.EOF, 0}, - {"Symbol/not enough bytes", TypeSymbol, strbytes[:5], io.EOF, 4}, - {"Symbol/success", TypeSymbol, strbytes, nil, 8}, - {"MaxKey/success", TypeMaxKey, []byte{0x01}, nil, 0}, - {"MinKey/success", TypeMinKey, []byte{0x01}, nil, 0}, - {"Null/success", TypeNull, []byte{0x01}, nil, 0}, - {"Undefined/success", TypeUndefined, []byte{0x01}, nil, 0}, + {"Int32/not enough bytes", bson.TypeInt32, []byte{0x01, 0x02}, io.EOF, 0}, + {"Int32/success", bson.TypeInt32, []byte{0x01, 0x02, 0x03, 0x04}, nil, 4}, + {"Javascript/invalid length", bson.TypeJavaScript, strbytes[:2], io.EOF, 0}, + {"Javascript/not enough bytes", bson.TypeJavaScript, strbytes[:5], io.EOF, 0}, + {"Javascript/success", bson.TypeJavaScript, strbytes, nil, 8}, + {"String/invalid length", bson.TypeString, strbytes[:2], io.EOF, 0}, + {"String/not enough bytes", bson.TypeString, strbytes[:5], io.EOF, 0}, + {"String/success", bson.TypeString, strbytes, nil, 8}, + {"Symbol/invalid length", bson.TypeSymbol, strbytes[:2], io.EOF, 0}, + {"Symbol/not enough bytes", bson.TypeSymbol, strbytes[:5], io.EOF, 0}, + {"Symbol/success", bson.TypeSymbol, strbytes, nil, 8}, + {"MaxKey/success", bson.TypeMaxKey, []byte{}, nil, 0}, + {"MinKey/success", bson.TypeMinKey, []byte{}, nil, 0}, + {"Null/success", bson.TypeNull, []byte{}, nil, 0}, + {"Undefined/success", bson.TypeUndefined, []byte{}, nil, 0}, { "ObjectID/not enough bytes", - TypeObjectID, + bson.TypeObjectID, []byte{0x01, 0x02, 0x03, 0x04}, io.EOF, 0, }, { "ObjectID/success", - TypeObjectID, + bson.TypeObjectID, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, nil, 12, }, { "Regex/not enough bytes (first string)", - TypeRegex, + bson.TypeRegex, []byte{'f', 'o', 'o'}, io.EOF, 0, }, { "Regex/not enough bytes (second string)", - TypeRegex, + bson.TypeRegex, []byte{'f', 'o', 'o', 0x00, 'b', 'a', 'r'}, - io.EOF, 4, + io.EOF, 0, }, { "Regex/success", - TypeRegex, + bson.TypeRegex, []byte{'f', 'o', 'o', 0x00, 'b', 'a', 'r', 0x00}, nil, 8, }, { "Unknown Type", - Type(0), + bson.Type(0), nil, - fmt.Errorf("attempted to skip unknown BSON type %v", Type(0)), 0, + fmt.Errorf("attempted to read bytes of unknown BSON type %v", bson.Type(0)), 0, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - vr := &valueReader{ - d: tc.data, - stack: []vrState{ - {mode: mTopLevel}, - {mode: mElement, vType: tc.t}, - }, - frame: 1, - } + t.Run("Skip", func(t *testing.T) { + vr := &valueReader{ + d: tc.data, + stack: []vrState{ + {mode: mTopLevel}, + {mode: mElement, vType: tc.t}, + }, + frame: 1, + } - err := vr.Skip() - if !errequal(t, err, tc.err) { - t.Errorf("Did not receive expected error; got %v; want %v", err, tc.err) - } - if vr.offset != tc.offset { - t.Errorf("Offset not set at correct position; got %d; want %d", vr.offset, tc.offset) - } + err := vr.Skip() + if !errequal(t, err, tc.err) { + t.Errorf("Did not receive expected error; got %v; want %v", err, tc.err) + } + if tc.err == nil && vr.offset != tc.offset { + t.Errorf("Offset not set at correct position; got %d; want %d", vr.offset, tc.offset) + } + }) + t.Run("ReadBytes", func(t *testing.T) { + vr := &valueReader{ + d: tc.data, + stack: []vrState{ + {mode: mTopLevel}, + {mode: mElement, vType: tc.t}, + }, + frame: 1, + } + + _, got, err := vr.ReadValueBytes(nil) + if !errequal(t, err, tc.err) { + t.Errorf("Did not receive expected error; got %v; want %v", err, tc.err) + } + if tc.err == nil && vr.offset != tc.offset { + t.Errorf("Offset not set at correct position; got %d; want %d", vr.offset, tc.offset) + } + if tc.err == nil && !bytes.Equal(got, tc.data) { + t.Errorf("Did not receive expected bytes. got %v; want %v", got, tc.data) + } + }) }) } }) t.Run("invalid transition", func(t *testing.T) { - vr := &valueReader{stack: []vrState{{mode: mTopLevel}}} - wanterr := (&valueReader{stack: []vrState{{mode: mTopLevel}}}).invalidTransitionErr(0) - goterr := vr.Skip() - if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { - t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr) - } + t.Run("Skip", func(t *testing.T) { + vr := &valueReader{stack: []vrState{{mode: mTopLevel}}} + wanterr := (&valueReader{stack: []vrState{{mode: mTopLevel}}}).invalidTransitionErr(0) + goterr := vr.Skip() + if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { + t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr) + } + }) + t.Run("ReadBytes", func(t *testing.T) { + vr := &valueReader{stack: []vrState{{mode: mTopLevel}}} + wanterr := (&valueReader{stack: []vrState{{mode: mTopLevel}}}).invalidTransitionErr(0) + _, _, goterr := vr.ReadValueBytes(nil) + if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { + t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr) + } + }) }) } diff --git a/bson/value_writer.go b/bson/bsoncodec/value_writer.go similarity index 95% rename from bson/value_writer.go rename to bson/bsoncodec/value_writer.go index 5c044e3fb1..ede02bc06c 100644 --- a/bson/value_writer.go +++ b/bson/bsoncodec/value_writer.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "errors" @@ -8,6 +8,7 @@ import ( "strconv" "sync" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/internal/llbson" "github.com/mongodb/mongo-go-driver/bson/objectid" @@ -165,10 +166,12 @@ func (vw *valueWriter) reset(buf []byte) { vw.stack = vw.stack[:1] vw.stack[0] = vwState{mode: mTopLevel} vw.buf = buf + vw.frame = 0 + vw.w = nil } func (vw *valueWriter) invalidTransitionError(destination mode) error { - te := transitionError{ + te := TransitionError{ current: vw.stack[vw.frame].mode, destination: destination, } @@ -192,6 +195,16 @@ func (vw *valueWriter) writeElementHeader(t llbson.Type, destination mode) error return nil } +func (vw *valueWriter) WriteValueBytes(t bson.Type, b []byte) error { + err := vw.writeElementHeader(llbson.Type(t), mode(0)) + if err != nil { + return err + } + vw.buf = append(vw.buf, b...) + vw.pop() + return nil +} + func (vw *valueWriter) WriteArray() (ArrayWriter, error) { if err := vw.writeElementHeader(llbson.TypeArray, mArray); err != nil { return nil, err @@ -210,9 +223,9 @@ func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { if err := vw.writeElementHeader(llbson.TypeBinary, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendBinary(vw.buf, btype, b) + vw.pop() return nil } @@ -220,9 +233,9 @@ func (vw *valueWriter) WriteBoolean(b bool) error { if err := vw.writeElementHeader(llbson.TypeBoolean, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendBoolean(vw.buf, b) + vw.pop() return nil } @@ -247,9 +260,9 @@ func (vw *valueWriter) WriteDBPointer(ns string, oid objectid.ObjectID) error { if err := vw.writeElementHeader(llbson.TypeDBPointer, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendDBPointer(vw.buf, ns, oid) + vw.pop() return nil } @@ -257,9 +270,9 @@ func (vw *valueWriter) WriteDateTime(dt int64) error { if err := vw.writeElementHeader(llbson.TypeDateTime, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendDateTime(vw.buf, dt) + vw.pop() return nil } @@ -267,9 +280,9 @@ func (vw *valueWriter) WriteDecimal128(d128 decimal.Decimal128) error { if err := vw.writeElementHeader(llbson.TypeDecimal128, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendDecimal128(vw.buf, d128) + vw.pop() return nil } @@ -277,9 +290,9 @@ func (vw *valueWriter) WriteDouble(f float64) error { if err := vw.writeElementHeader(llbson.TypeDouble, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendDouble(vw.buf, f) + vw.pop() return nil } @@ -287,9 +300,9 @@ func (vw *valueWriter) WriteInt32(i32 int32) error { if err := vw.writeElementHeader(llbson.TypeInt32, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendInt32(vw.buf, i32) + vw.pop() return nil } @@ -297,9 +310,9 @@ func (vw *valueWriter) WriteInt64(i64 int64) error { if err := vw.writeElementHeader(llbson.TypeInt64, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendInt64(vw.buf, i64) + vw.pop() return nil } @@ -307,9 +320,9 @@ func (vw *valueWriter) WriteJavascript(code string) error { if err := vw.writeElementHeader(llbson.TypeJavaScript, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendJavaScript(vw.buf, code) + vw.pop() return nil } @@ -317,8 +330,8 @@ func (vw *valueWriter) WriteMaxKey() error { if err := vw.writeElementHeader(llbson.TypeMaxKey, mode(0)); err != nil { return err } - defer vw.pop() + vw.pop() return nil } @@ -326,8 +339,8 @@ func (vw *valueWriter) WriteMinKey() error { if err := vw.writeElementHeader(llbson.TypeMinKey, mode(0)); err != nil { return err } - defer vw.pop() + vw.pop() return nil } @@ -335,8 +348,8 @@ func (vw *valueWriter) WriteNull() error { if err := vw.writeElementHeader(llbson.TypeNull, mode(0)); err != nil { return err } - defer vw.pop() + vw.pop() return nil } @@ -344,9 +357,9 @@ func (vw *valueWriter) WriteObjectID(oid objectid.ObjectID) error { if err := vw.writeElementHeader(llbson.TypeObjectID, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendObjectID(vw.buf, oid) + vw.pop() return nil } @@ -354,9 +367,9 @@ func (vw *valueWriter) WriteRegex(pattern string, options string) error { if err := vw.writeElementHeader(llbson.TypeRegex, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendRegex(vw.buf, pattern, options) + vw.pop() return nil } @@ -364,9 +377,9 @@ func (vw *valueWriter) WriteString(s string) error { if err := vw.writeElementHeader(llbson.TypeString, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendString(vw.buf, s) + vw.pop() return nil } @@ -387,9 +400,9 @@ func (vw *valueWriter) WriteSymbol(symbol string) error { if err := vw.writeElementHeader(llbson.TypeSymbol, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendSymbol(vw.buf, symbol) + vw.pop() return nil } @@ -397,9 +410,9 @@ func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error { if err := vw.writeElementHeader(llbson.TypeTimestamp, mode(0)); err != nil { return err } - defer vw.pop() vw.buf = llbson.AppendTimestamp(vw.buf, t, i) + vw.pop() return nil } @@ -407,8 +420,8 @@ func (vw *valueWriter) WriteUndefined() error { if err := vw.writeElementHeader(llbson.TypeUndefined, mode(0)); err != nil { return err } - defer vw.pop() + vw.pop() return nil } diff --git a/bson/value_writer_test.go b/bson/bsoncodec/value_writer_test.go similarity index 84% rename from bson/value_writer_test.go rename to bson/bsoncodec/value_writer_test.go index 8b510bae17..3e65bba248 100644 --- a/bson/value_writer_test.go +++ b/bson/bsoncodec/value_writer_test.go @@ -1,4 +1,4 @@ -package bson +package bsoncodec import ( "bytes" @@ -9,12 +9,13 @@ import ( "reflect" "testing" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/internal/llbson" "github.com/mongodb/mongo-go-driver/bson/objectid" ) -func bytesFromDoc(doc *Document) []byte { +func bytesFromDoc(doc *bson.Document) []byte { b, err := doc.MarshalBSON() if err != nil { panic(fmt.Errorf("Couldn't marshal BSON document: %v", err)) @@ -210,7 +211,7 @@ func TestValueWriter(t *testing.T) { vw = newValueWriter(ioutil.Discard) results := fn.Call(params) got := results[0].Interface().(error) - want := transitionError{current: mTopLevel} + want := TransitionError{current: mTopLevel} if !compareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -221,7 +222,7 @@ func TestValueWriter(t *testing.T) { t.Run("WriteArray", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mArray) - want := transitionError{current: mArray, destination: mArray, parent: mTopLevel} + want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel} _, got := vw.WriteArray() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) @@ -230,7 +231,7 @@ func TestValueWriter(t *testing.T) { t.Run("WriteCodeWithScope", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mArray) - want := transitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel} + want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel} _, got := vw.WriteCodeWithScope("") if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) @@ -239,7 +240,7 @@ func TestValueWriter(t *testing.T) { t.Run("WriteDocument", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mArray) - want := transitionError{current: mArray, destination: mDocument, parent: mTopLevel} + want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel} _, got := vw.WriteDocument() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) @@ -248,7 +249,7 @@ func TestValueWriter(t *testing.T) { t.Run("WriteDocumentElement", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mElement) - want := transitionError{current: mElement, destination: mElement, parent: mTopLevel} + want := TransitionError{current: mElement, destination: mElement, parent: mTopLevel} _, got := vw.WriteDocumentElement("") if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) @@ -281,7 +282,7 @@ func TestValueWriter(t *testing.T) { t.Run("WriteArrayElement", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mElement) - want := transitionError{current: mElement, destination: mValue, parent: mTopLevel} + want := TransitionError{current: mElement, destination: mValue, parent: mTopLevel} _, got := vw.WriteArrayElement() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) @@ -305,6 +306,37 @@ func TestValueWriter(t *testing.T) { } maxSize = math.MaxInt32 }) + + t.Run("WriteBytes", func(t *testing.T) { + t.Run("writeElementHeader error", func(t *testing.T) { + vw := newValueWriterFromSlice(nil) + want := TransitionError{current: mTopLevel, destination: mode(0)} + got := vw.WriteValueBytes(bson.TypeEmbeddedDocument, nil) + if !compareErrors(got, want) { + t.Errorf("Did not received expected error. got %v; want %v", got, want) + } + }) + t.Run("success", func(t *testing.T) { + vw := newValueWriterFromSlice(make([]byte, 0, 512)) + _, err := vw.WriteDocument() + noerr(t, err) + _, err = vw.WriteDocumentElement("foo") + noerr(t, err) + doc := bson.NewDocument(bson.EC.String("hello", "world")) + b, err := doc.MarshalBSON() + noerr(t, err) + err = vw.WriteValueBytes(bson.TypeEmbeddedDocument, b) + noerr(t, err) + err = vw.WriteDocumentEnd() + noerr(t, err) + want, err := bson.NewDocument(bson.EC.SubDocument("foo", doc)).MarshalBSON() + noerr(t, err) + got := vw.buf + if !bytes.Equal(got, want) { + t.Errorf("Bytes are not equal. got %v; want %v", bson.Reader(got), bson.Reader(want)) + } + }) + }) } type errWriter struct { @@ -322,22 +354,22 @@ func TestValueWriterOLD(t *testing.T) { { "simple document", vwBasicDoc, - bytesFromDoc(NewDocument(EC.Boolean("foo", true))), + bytesFromDoc(bson.NewDocument(bson.EC.Boolean("foo", true))), }, { "nested document", vwNestedDoc, - bytesFromDoc(NewDocument(EC.SubDocumentFromElements("foo", EC.Boolean("bar", true)))), + bytesFromDoc(bson.NewDocument(bson.EC.SubDocumentFromElements("foo", bson.EC.Boolean("bar", true)))), }, { "simple array", vwBasicArray, - bytesFromDoc(NewDocument(EC.ArrayFromElements("foo", VC.Boolean(true)))), + bytesFromDoc(bson.NewDocument(bson.EC.ArrayFromElements("foo", bson.VC.Boolean(true)))), }, { "code with scope", vwCodeWithScopeNoNested, - bytesFromDoc(NewDocument(EC.CodeWithScope("foo", "var hello = world;", NewDocument(EC.Boolean("bar", false))))), + bytesFromDoc(bson.NewDocument(bson.EC.CodeWithScope("foo", "var hello = world;", bson.NewDocument(bson.EC.Boolean("bar", false))))), }, } @@ -347,7 +379,7 @@ func TestValueWriterOLD(t *testing.T) { vw := newValueWriter(&got) tc.fn(t, vw) if !bytes.Equal(got, tc.want) { - t.Errorf("Documents are not equal. got %v; want %v", Reader(got), Reader(tc.want)) + t.Errorf("Documents are not equal. got %v; want %v", bson.Reader(got), bson.Reader(tc.want)) t.Errorf("Bytes:\n%v\n%v", got, tc.want) } }) diff --git a/bson/writer.go b/bson/bsoncodec/writer.go similarity index 85% rename from bson/writer.go rename to bson/bsoncodec/writer.go index c87ae8450f..bfebd7263d 100644 --- a/bson/writer.go +++ b/bson/bsoncodec/writer.go @@ -1,6 +1,7 @@ -package bson +package bsoncodec import ( + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -49,6 +50,13 @@ type ValueWriter interface { WriteUndefined() error } +// BytesWriter is the interface used to write BSON bytes to a ValueWriter. +// This interface is meant to be a superset of ValueWriter, so that types that +// implement ValueWriter may also implement this interface. +type BytesWriter interface { + WriteValueBytes(t bson.Type, b []byte) error +} + type writer []byte func (w *writer) Write(p []byte) (int, error) { diff --git a/bson/codec.go b/bson/codec.go deleted file mode 100644 index 934fa885a8..0000000000 --- a/bson/codec.go +++ /dev/null @@ -1,2055 +0,0 @@ -package bson - -import ( - "encoding/json" - "errors" - "fmt" - "math" - "net/url" - "reflect" - "strconv" - "strings" - "time" - - "github.com/mongodb/mongo-go-driver/bson/decimal" - "github.com/mongodb/mongo-go-driver/bson/objectid" -) - -var defaultBoolCodec = &BooleanCodec{} -var defaultIntCodec = &IntCodec{} -var defaultUintCodec = &UintCodec{} -var defaultFloatCodec = &FloatCodec{} -var defaultStringCodec = &StringCodec{} -var defaultDocumentCodec = &DocumentCodec{} -var defaultArrayCodec = &ArrayCodec{} -var defaultTimeCodec = &TimeCodec{} -var defaultElementCodec = &elementCodec{} -var defaultValueCodec = &ValueCodec{} -var defaultByteSliceCodec = &ByteSliceCodec{} -var defaultBinaryCodec = &BinaryCodec{} -var defaultUndefinedCodec = &UndefinedCodec{} -var defaultObjectIDCodec = &ObjectIDCodec{} -var defaultDateTimeCodec = &DateTimeCodec{} -var defaultNullCodec = &NullCodec{} -var defaultRegexCodec = &RegexCodec{} -var defaultDBPointerCodec = &DBPointerCodec{} -var defaultCodeWithScopeCodec = &CodeWithScopeCodec{} -var defaultTimestampCodec = &TimestampCodec{} -var defaultDecimal128Codec = &Decimal128Codec{} -var defaultMinKeyCodec = &MinKeyCodec{} -var defaultMaxKeyCodec = &MaxKeyCodec{} -var defaultJSONNumberCodec = &JSONNumberCodec{} -var defaultURLCodec = &URLCodec{} -var defaultReaderCodec = &ReaderCodec{} -var defaultElementSliceCodec = &ElementSliceCodec{} - -var ptBool = reflect.TypeOf((*bool)(nil)) -var ptInt8 = reflect.TypeOf((*int8)(nil)) -var ptInt16 = reflect.TypeOf((*int16)(nil)) -var ptInt32 = reflect.TypeOf((*int32)(nil)) -var ptInt64 = reflect.TypeOf((*int64)(nil)) -var ptInt = reflect.TypeOf((*int)(nil)) -var ptUint8 = reflect.TypeOf((*uint8)(nil)) -var ptUint16 = reflect.TypeOf((*uint16)(nil)) -var ptUint32 = reflect.TypeOf((*uint32)(nil)) -var ptUint64 = reflect.TypeOf((*uint64)(nil)) -var ptUint = reflect.TypeOf((*uint)(nil)) -var ptFloat32 = reflect.TypeOf((*float32)(nil)) -var ptFloat64 = reflect.TypeOf((*float64)(nil)) -var ptString = reflect.TypeOf((*string)(nil)) - -// CodecEncodeError is an error returned from a Codec's EncodeValue method when -// the provided value can't be encoded with the given Codec. -type CodecEncodeError struct { - Codec interface{} - Types []interface{} - Received interface{} -} - -func (cee CodecEncodeError) Error() string { - types := make([]string, 0, len(cee.Types)) - for _, t := range cee.Types { - types = append(types, fmt.Sprintf("%T", t)) - } - return fmt.Sprintf("%T can only process %s, but got a %T", cee.Codec, strings.Join(types, ", "), cee.Received) -} - -// CodecDecodeError is an error returned from a Codec's DecodeValue method when -// the provided value can't be decoded with the given Codec. -type CodecDecodeError struct { - Codec interface{} - Types []interface{} - Received interface{} -} - -func (dee CodecDecodeError) Error() string { - types := make([]string, 0, len(dee.Types)) - for _, t := range dee.Types { - types = append(types, fmt.Sprintf("%T", t)) - } - return fmt.Sprintf("%T can only process %s, but got a %T", dee.Codec, strings.Join(types, ", "), dee.Received) -} - -// EncodeContext is the contextual information required for a Codec to encode a -// value. -type EncodeContext struct { - *Registry - MinSize bool -} - -// DecodeContext is the contextual information required for a Codec to decode a -// value. -type DecodeContext struct { - *Registry - Truncate bool -} - -// Codec implementations handle encoding and decoding values. They can be -// registered in a registry which will handle invoking them. Callers of the -// DecodeValue methods pass in a pointer to the value, and -// implementations operate on a pointer to the value. This is true of pointer -// values as well, so a caller of DecodeValue for a pointer type *Foo will pass -// in **Foo. -type Codec interface { - EncodeValue(EncodeContext, ValueWriter, interface{}) error - DecodeValue(DecodeContext, ValueReader, interface{}) error -} - -// CodecZeroer is the interface implemented by Codecs that can also determine if -// a value of the type that would be encoded is zero. -type CodecZeroer interface { - Codec - IsZero(interface{}) bool -} - -// BooleanCodec is the Codec used for bool values. -type BooleanCodec struct{} - -var _ Codec = &BooleanCodec{} - -// EncodeValue implements the Codec interface. -func (bc *BooleanCodec) EncodeValue(ectx EncodeContext, vw ValueWriter, i interface{}) error { - b, ok := i.(bool) - if !ok { - if reflect.TypeOf(i).Kind() != reflect.Bool { - return CodecEncodeError{Codec: bc, Types: []interface{}{bool(true)}, Received: i} - } - - b = reflect.ValueOf(i).Bool() - } - - return vw.WriteBoolean(b) -} - -// DecodeValue implements the Codec interface. -func (bc *BooleanCodec) DecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeBoolean { - return fmt.Errorf("cannot decode %v into a boolean", vr.Type()) - } - - var err error - if target, ok := i.(*bool); ok && target != nil { // if it is nil, we go the slow path. - *target, err = vr.ReadBoolean() - return err - } - - val := reflect.ValueOf(i) - if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { - return fmt.Errorf("%T can only be used to decode settable (non-nil) values", bc) - } - val = val.Elem() - if val.Type().Kind() != reflect.Bool { - return CodecDecodeError{Codec: bc, Types: []interface{}{bool(true)}, Received: i} - } - - b, err := vr.ReadBoolean() - val.SetBool(b) - return err -} - -// IntCodec is the Codec used for int8, int16, int32, int64, and int values. -type IntCodec struct{} - -var _ Codec = &IntCodec{} - -// EncodeValue implements the Codec interface. -func (ic *IntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - switch t := i.(type) { - case int8: - return vw.WriteInt32(int32(t)) - case int16: - return vw.WriteInt32(int32(t)) - case int32: - return vw.WriteInt32(t) - case int64: - if ec.MinSize && t <= math.MaxInt32 { - return vw.WriteInt32(int32(t)) - } - return vw.WriteInt64(t) - case int: - if ec.MinSize && t <= math.MaxInt32 { - return vw.WriteInt32(int32(t)) - } - return vw.WriteInt64(int64(t)) - } - - val := reflect.ValueOf(i) - switch val.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32: - return vw.WriteInt32(int32(val.Int())) - case reflect.Int, reflect.Int64: - i64 := val.Int() - if ec.MinSize && i64 <= math.MaxInt32 { - return vw.WriteInt32(int32(i64)) - } - return vw.WriteInt64(i64) - } - - return CodecEncodeError{Codec: ic, Types: []interface{}{int8(0), int16(0), int32(0), int64(0), int(0)}, Received: i} -} - -// DecodeValue implements the Codec interface. -func (ic *IntCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - var i64 int64 - var err error - switch vr.Type() { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - if !dc.Truncate && math.Floor(f64) != f64 { - return fmt.Errorf("%T can only truncate float64 to an integer type when truncation is enabled", ic) - } - if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - default: - return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) - } - - switch target := i.(type) { - case *int8: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *int8", ic) - } - if i64 < math.MinInt8 || i64 > math.MaxInt8 { - return fmt.Errorf("%d overflows int8", i64) - } - *target = int8(i64) - return nil - case *int16: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *int16", ic) - } - if i64 < math.MinInt16 || i64 > math.MaxInt16 { - return fmt.Errorf("%d overflows int16", i64) - } - *target = int16(i64) - return nil - case *int32: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *int32", ic) - } - if i64 < math.MinInt32 || i64 > math.MaxInt32 { - return fmt.Errorf("%d overflows int32", i64) - } - *target = int32(i64) - return nil - case *int64: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *int64", ic) - } - *target = int64(i64) - return nil - case *int: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *int", ic) - } - if int64(int(i64)) != i64 { // Can we fit this inside of an int - return fmt.Errorf("%d overflows int", i64) - } - *target = int(i64) - return nil - } - - val := reflect.ValueOf(i) - if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { - return fmt.Errorf("%T can only be used to decode settable (non-nil) values", ic) - } - val = val.Elem() - - switch val.Type().Kind() { - case reflect.Int8: - if i64 < math.MinInt8 || i64 > math.MaxInt8 { - return fmt.Errorf("%d overflows int8", i64) - } - case reflect.Int16: - if i64 < math.MinInt16 || i64 > math.MaxInt16 { - return fmt.Errorf("%d overflows int16", i64) - } - case reflect.Int32: - if i64 < math.MinInt32 || i64 > math.MaxInt32 { - return fmt.Errorf("%d overflows int32", i64) - } - case reflect.Int64: - case reflect.Int: - if int64(int(i64)) != i64 { // Can we fit this inside of an int - return fmt.Errorf("%d overflows int", i64) - } - default: - return CodecDecodeError{ - Codec: ic, - Types: []interface{}{(*int8)(nil), (*int16)(nil), (*int32)(nil), (*int64)(nil), (*int)(nil)}, - Received: i, - } - } - - val.SetInt(i64) - return nil -} - -// UintCodec is the Codec used for uint8, uint16, uint32, uint64, and uint -// values. -type UintCodec struct{} - -var _ Codec = &UintCodec{} - -// EncodeValue implements the Codec interface. -func (uc *UintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - switch t := i.(type) { - case uint8: - return vw.WriteInt32(int32(t)) - case uint16: - return vw.WriteInt32(int32(t)) - case uint: - if ec.MinSize && t <= math.MaxInt32 { - return vw.WriteInt32(int32(t)) - } - if uint64(t) > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", t) - } - return vw.WriteInt64(int64(t)) - case uint32: - if ec.MinSize && t <= math.MaxInt32 { - return vw.WriteInt32(int32(t)) - } - return vw.WriteInt64(int64(t)) - case uint64: - if ec.MinSize && t <= math.MaxInt32 { - return vw.WriteInt32(int32(t)) - } - if t > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", t) - } - return vw.WriteInt64(int64(t)) - } - - val := reflect.ValueOf(i) - switch val.Type().Kind() { - case reflect.Uint8, reflect.Uint16: - return vw.WriteInt32(int32(val.Uint())) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - u64 := val.Uint() - if ec.MinSize && u64 <= math.MaxInt32 { - return vw.WriteInt32(int32(u64)) - } - if u64 > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", u64) - } - return vw.WriteInt64(int64(u64)) - } - - return CodecEncodeError{Codec: uc, Types: []interface{}{uint8(0), uint16(0), uint32(0), uint64(0), uint(0)}, Received: i} -} - -// DecodeValue implements the Codec interface. -func (uc *UintCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - var i64 int64 - var err error - switch vr.Type() { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - if !dc.Truncate && math.Floor(f64) != f64 { - return fmt.Errorf("%T can only truncate float64 to an integer type when truncation is enabled", uc) - } - if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - default: - return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) - } - - switch target := i.(type) { - case *uint8: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *uint8", uc) - } - if i64 < 0 || i64 > math.MaxUint8 { - return fmt.Errorf("%d overflows uint8", i64) - } - *target = uint8(i64) - return nil - case *uint16: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *uint16", uc) - } - if i64 < 0 || i64 > math.MaxUint16 { - return fmt.Errorf("%d overflows uint16", i64) - } - *target = uint16(i64) - return nil - case *uint32: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *uint32", uc) - } - if i64 < 0 || i64 > math.MaxUint32 { - return fmt.Errorf("%d overflows uint32", i64) - } - *target = uint32(i64) - return nil - case *uint64: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *uint64", uc) - } - if i64 < 0 { - return fmt.Errorf("%d overflows uint64", i64) - } - *target = uint64(i64) - return nil - case *uint: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *uint", uc) - } - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return fmt.Errorf("%d overflows uint", i64) - } - *target = uint(i64) - return nil - } - - val := reflect.ValueOf(i) - if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { - return fmt.Errorf("%T can only be used to decode settable (non-nil) values", uc) - } - val = val.Elem() - - switch val.Type().Kind() { - case reflect.Uint8: - if i64 < 0 || i64 > math.MaxUint8 { - return fmt.Errorf("%d overflows uint8", i64) - } - case reflect.Uint16: - if i64 < 0 || i64 > math.MaxUint16 { - return fmt.Errorf("%d overflows uint16", i64) - } - case reflect.Uint32: - if i64 < 0 || i64 > math.MaxUint32 { - return fmt.Errorf("%d overflows uint32", i64) - } - case reflect.Uint64: - if i64 < 0 { - return fmt.Errorf("%d overflows uint64", i64) - } - case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return fmt.Errorf("%d overflows uint", i64) - } - default: - return CodecDecodeError{ - Codec: uc, - Types: []interface{}{(*uint8)(nil), (*uint16)(nil), (*uint32)(nil), (*uint64)(nil), (*uint)(nil)}, - Received: i, - } - } - - val.SetUint(uint64(i64)) - return nil -} - -// FloatCodec is the Codec used for float32 and float64 values. -type FloatCodec struct{} - -var _ Codec = &FloatCodec{} - -// EncodeValue implements the Codec interface. -func (fc *FloatCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - switch t := i.(type) { - case float32: - return vw.WriteDouble(float64(t)) - case float64: - return vw.WriteDouble(t) - } - - val := reflect.ValueOf(i) - switch val.Type().Kind() { - case reflect.Float32, reflect.Float64: - return vw.WriteDouble(val.Float()) - } - - return CodecEncodeError{Codec: fc, Types: []interface{}{float32(0), float64(0)}, Received: i} -} - -// DecodeValue implements the Codec interface. -func (fc *FloatCodec) DecodeValue(ec DecodeContext, vr ValueReader, i interface{}) error { - var f float64 - var err error - switch vr.Type() { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - f = float64(i32) - case TypeInt64: - i64, err := vr.ReadInt64() - if err != nil { - return err - } - f = float64(i64) - case TypeDouble: - f, err = vr.ReadDouble() - if err != nil { - return err - } - default: - return fmt.Errorf("cannot decode %v into a float32 or float64 type", vr.Type()) - } - - switch target := i.(type) { - case *float32: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *float32", fc) - } - if !ec.Truncate && float64(float32(f)) != f { - return fmt.Errorf("%T can only convert float64 to float32 when truncation is allowed", fc) - } - *target = float32(f) - return nil - case *float64: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *float64", fc) - } - *target = f - return nil - } - - val := reflect.ValueOf(i) - if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { - return fmt.Errorf("%T can only be used to decode settable (non-nil) values", fc) - } - val = val.Elem() - - switch val.Type().Kind() { - case reflect.Float32: - if !ec.Truncate && float64(float32(f)) != f { - return fmt.Errorf("%T can only convert float64 to float32 when truncation is allowed", fc) - } - case reflect.Float64: - default: - return CodecDecodeError{Codec: fc, Types: []interface{}{(*float32)(nil), (*float64)(nil)}, Received: i} - } - - val.SetFloat(f) - return nil -} - -// StringCodec is the Codec used for string values. -type StringCodec struct{} - -var _ Codec = &StringCodec{} - -// EncodeValue implements the Codec interface. -func (sc *StringCodec) EncodeValue(ectx EncodeContext, vw ValueWriter, i interface{}) error { - switch t := i.(type) { - case string: - return vw.WriteString(t) - case JavaScriptCode: - return vw.WriteJavascript(string(t)) - case Symbol: - return vw.WriteSymbol(string(t)) - } - - val := reflect.ValueOf(i) - if val.Type().Kind() != reflect.String { - return CodecEncodeError{Codec: sc, Types: []interface{}{string(""), JavaScriptCode(""), Symbol("")}, Received: i} - } - - return vw.WriteString(val.String()) -} - -// DecodeValue implements the Codec interface. -func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { - var str string - var err error - switch vr.Type() { - case TypeString: - str, err = vr.ReadString() - if err != nil { - return err - } - case TypeJavaScript: - str, err = vr.ReadJavascript() - if err != nil { - return err - } - case TypeSymbol: - str, err = vr.ReadSymbol() - if err != nil { - return err - } - default: - return fmt.Errorf("cannot decode %v into a string type", vr.Type()) - } - - switch t := i.(type) { - case *string: - if t == nil { - return fmt.Errorf("%T can only be used to decode non-nil *string", sc) - } - *t = str - return nil - case *JavaScriptCode: - if t == nil { - return fmt.Errorf("%T can only be used to decode non-nil *JavaScriptCode", sc) - } - *t = JavaScriptCode(str) - return nil - case *Symbol: - if t == nil { - return fmt.Errorf("%T can only be used to decode non-nil *Symbol", sc) - } - *t = Symbol(str) - return nil - } - - val := reflect.ValueOf(i) - if !val.IsValid() || val.Kind() != reflect.Ptr || !val.Elem().CanSet() { - return fmt.Errorf("%T can only be used to decode settable (non-nil) values", sc) - } - val = val.Elem() - - if val.Type().Kind() != reflect.String { - return CodecDecodeError{Codec: sc, Types: []interface{}{(*string)(nil), (*JavaScriptCode)(nil), (*Symbol)(nil)}, Received: i} - } - - val.SetString(str) - return nil -} - -// DocumentCodec is the Codec used for *Document values. -type DocumentCodec struct{} - -var _ Codec = &DocumentCodec{} - -// EncodeValue implements the Codec interface. -func (dc *DocumentCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - doc, ok := i.(*Document) - if !ok { - return CodecEncodeError{Codec: dc, Types: []interface{}{(*Document)(nil)}, Received: i} - } - - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - return dc.encodeDocument(ec, dw, doc) -} - -// encodeDocument is a separate function that we use because CodeWithScope -// returns us a DocumentWriter and we need to do the same logic that we would do -// for a document but cannot use a Codec. -func (dc DocumentCodec) encodeDocument(ec EncodeContext, dw DocumentWriter, doc *Document) error { - itr := doc.Iterator() - - for itr.Next() { - elem := itr.Element() - dvw, err := dw.WriteDocumentElement(elem.Key()) - if err != nil { - return err - } - - val := elem.Value() - err = defaultValueCodec.encodeValue(ec, dvw, val) - - if err != nil { - return err - } - } - - if err := itr.Err(); err != nil { - return err - } - - return dw.WriteDocumentEnd() - -} - -// DecodeValue implements the Codec interface. -func (dc *DocumentCodec) DecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { - doc, ok := i.(**Document) - if !ok { - return CodecDecodeError{Codec: dc, Types: []interface{}{(**Document)(nil)}, Received: i} - } - - if doc == nil { - return fmt.Errorf("%T can only be used to decode non-nil **Document", dc) - } - - dr, err := vr.ReadDocument() - if err != nil { - return err - } - - return dc.decodeDocument(dctx, dr, doc) -} - -func (dc DocumentCodec) decodeDocument(dctx DecodeContext, dr DocumentReader, pdoc **Document) error { - doc := NewDocument() - for { - key, vr, err := dr.ReadElement() - if err == ErrEOD { - break - } - if err != nil { - return err - } - - var elem *Element - err = defaultElementCodec.decodeValue(dctx, vr, key, &elem) - if err != nil { - return err - } - - doc.Append(elem) - } - - *pdoc = doc - return nil -} - -// ArrayCodec is the Codec used for *Array values. -type ArrayCodec struct{} - -var _ Codec = &ArrayCodec{} - -// EncodeValue implements the Codec interface. -func (ac *ArrayCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - arr, ok := i.(*Array) - if !ok { - return CodecEncodeError{Codec: ac, Types: []interface{}{(*Array)(nil)}, Received: i} - } - - aw, err := vw.WriteArray() - if err != nil { - return err - } - - itr := newArrayIterator(arr) - - for itr.Next() { - val := itr.Value() - dvw, err := aw.WriteArrayElement() - if err != nil { - return err - } - - err = defaultValueCodec.encodeValue(ec, dvw, val) - - if err != nil { - return err - } - } - - if err := itr.Err(); err != nil { - return err - } - - return aw.WriteArrayEnd() -} - -// DecodeValue implements the Codec interface. -func (ac *ArrayCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - parr, ok := i.(**Array) - if !ok { - return CodecDecodeError{Codec: ac, Types: []interface{}{(**Array)(nil)}, Received: i} - } - - if parr == nil { - return fmt.Errorf("%T can only be used to decode non-nil **Array", ac) - } - - ar, err := vr.ReadArray() - if err != nil { - return err - } - - arr := NewArray() - for { - vr, err := ar.ReadValue() - if err == ErrEOA { - break - } - if err != nil { - return err - } - - var val *Value - err = defaultValueCodec.decodeValue(dc, vr, &val) - if err != nil { - return err - } - - arr.Append(val) - } - - *parr = arr - return nil -} - -// BinaryCodec is the Codec used for Binary values. -type BinaryCodec struct{} - -var _ Codec = &BinaryCodec{} - -// EncodeValue implements the Codec interface. -func (bc *BinaryCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var b Binary - switch t := i.(type) { - case Binary: - b = t - case *Binary: - b = *t - default: - return CodecEncodeError{Codec: bc, Types: []interface{}{Binary{}, (*Binary)(nil)}, Received: i} - } - - return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) -} - -// DecodeValue implements the Codec interface. -func (bc *BinaryCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeBinary { - return fmt.Errorf("cannot decode %v into a Binary", vr.Type()) - } - - data, subtype, err := vr.ReadBinary() - if err != nil { - return err - } - - if target, ok := i.(*Binary); ok && target != nil { - *target = Binary{Data: data, Subtype: subtype} - return nil - } - - if target, ok := i.(**Binary); ok && target != nil { - pb := *target - if pb == nil { - pb = new(Binary) - } - *pb = Binary{Data: data, Subtype: subtype} - *target = pb - return nil - } - - return fmt.Errorf("%T can only be used to decode non-nil *Binary values, got %T", bc, i) -} - -// UndefinedCodec is the Codec for Undefined values. -type UndefinedCodec struct{} - -var _ Codec = &UndefinedCodec{} - -// EncodeValue implements the Codec interface. -func (uc *UndefinedCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - switch i.(type) { - case Undefinedv2, *Undefinedv2: - default: - return CodecEncodeError{Codec: uc, Types: []interface{}{Undefinedv2{}, (*Undefinedv2)(nil)}, Received: i} - } - - return vw.WriteUndefined() -} - -// DecodeValue implements the Codec interface. -func (uc *UndefinedCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeUndefined { - return fmt.Errorf("cannot decode %v into an Undefined", vr.Type()) - } - - target, ok := i.(*Undefinedv2) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *Undefined values, got %T", uc, i) - } - - *target = Undefinedv2{} - return vr.ReadUndefined() -} - -// ObjectIDCodec is the Codec for objectid.ObjectID values. -type ObjectIDCodec struct{} - -var _ Codec = &ObjectIDCodec{} - -// EncodeValue implements the Codec interface. -func (oidc *ObjectIDCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var oid objectid.ObjectID - switch t := i.(type) { - case objectid.ObjectID: - oid = t - case *objectid.ObjectID: - oid = *t - default: - return CodecEncodeError{Codec: oidc, Types: []interface{}{objectid.ObjectID{}, (*objectid.ObjectID)(nil)}, Received: i} - } - - return vw.WriteObjectID(oid) -} - -// DecodeValue implements the Codec interface. -func (oidc *ObjectIDCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeObjectID { - return fmt.Errorf("cannot decode %v into an ObjectID", vr.Type()) - } - - target, ok := i.(*objectid.ObjectID) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *objectid.ObjectID values, got %T", oidc, i) - } - - oid, err := vr.ReadObjectID() - if err != nil { - return err - } - - *target = oid - return nil -} - -// DateTimeCodec is the Codec for DateTime values. -type DateTimeCodec struct{} - -var _ Codec = &DateTimeCodec{} - -// EncodeValue implements the Codec interface. -func (dtc *DateTimeCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var dt DateTime - switch t := i.(type) { - case DateTime: - dt = t - case *DateTime: - dt = *t - default: - return CodecEncodeError{Codec: dtc, Types: []interface{}{DateTime(0), (*DateTime)(nil)}, Received: i} - } - - return vw.WriteDateTime(int64(dt)) -} - -// DecodeValue implements the Codec interface. -func (dtc *DateTimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeDateTime { - return fmt.Errorf("cannot decode %v into a DateTime", vr.Type()) - } - - target, ok := i.(*DateTime) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *DateTime values, got %T", dtc, i) - } - - dt, err := vr.ReadDateTime() - if err != nil { - return err - } - - *target = DateTime(dt) - return nil -} - -// NullCodec is the Codec for Null values. -type NullCodec struct{} - -var _ Codec = &NullCodec{} - -// EncodeValue implements the Codec interface. -func (nc *NullCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - switch i.(type) { - case Nullv2, *Nullv2: - default: - return CodecEncodeError{Codec: nc, Types: []interface{}{Nullv2{}, (*Nullv2)(nil)}, Received: i} - } - - return vw.WriteNull() -} - -// DecodeValue implements the Codec interface. -func (nc *NullCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeNull { - return fmt.Errorf("cannot decode %v into a Null", vr.Type()) - } - - target, ok := i.(*Nullv2) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *Null values, got %T", nc, i) - } - - *target = Nullv2{} - return vr.ReadNull() -} - -// RegexCodec is the Codec for Regex values. -type RegexCodec struct{} - -var _ Codec = &RegexCodec{} - -// EncodeValue implements the Codec interface. -func (rc *RegexCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var regex Regex - switch t := i.(type) { - case Regex: - regex = t - case *Regex: - regex = *t - default: - return CodecEncodeError{Codec: rc, Types: []interface{}{Regex{}, (*Regex)(nil)}, Received: i} - } - - return vw.WriteRegex(regex.Pattern, regex.Options) -} - -// DecodeValue implements the Codec interface. -func (rc *RegexCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeRegex { - return fmt.Errorf("cannot decode %v into a Regex", vr.Type()) - } - - target, ok := i.(*Regex) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *Regex values, got %T", rc, i) - } - - pattern, options, err := vr.ReadRegex() - if err != nil { - return err - } - - *target = Regex{Pattern: pattern, Options: options} - return nil -} - -// DBPointerCodec is the Codec for DBPointer values. -type DBPointerCodec struct{} - -var _ Codec = &DBPointerCodec{} - -// EncodeValue implements the Codec interface. -func (dbpc *DBPointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var dbp DBPointer - switch t := i.(type) { - case DBPointer: - dbp = t - case *DBPointer: - dbp = *t - default: - return CodecEncodeError{Codec: dbpc, Types: []interface{}{DBPointer{}, (*DBPointer)(nil)}, Received: i} - } - - return vw.WriteDBPointer(dbp.DB, dbp.Pointer) -} - -// DecodeValue implements the Codec interface. -func (dbpc *DBPointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeDBPointer { - return fmt.Errorf("cannot decode %v into a DBPointer", vr.Type()) - } - - target, ok := i.(*DBPointer) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *DBPointer values, got %T", dbpc, i) - } - - ns, pointer, err := vr.ReadDBPointer() - if err != nil { - return err - } - - *target = DBPointer{DB: ns, Pointer: pointer} - return nil -} - -// CodeWithScopeCodec is the Codec for CodeWithScope values. -type CodeWithScopeCodec struct{} - -var _ Codec = &CodeWithScopeCodec{} - -// EncodeValue implements the Codec interface. -func (cwsc *CodeWithScopeCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var cws CodeWithScope - switch t := i.(type) { - case CodeWithScope: - cws = t - case *CodeWithScope: - cws = *t - default: - return CodecEncodeError{Codec: cwsc, Types: []interface{}{CodeWithScope{}, (*CodeWithScope)(nil)}, Received: i} - } - - dw, err := vw.WriteCodeWithScope(cws.Code) - if err != nil { - return err - } - return defaultDocumentCodec.encodeDocument(ec, dw, cws.Scope) -} - -// DecodeValue implements the Codec interface. -func (cwsc *CodeWithScopeCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeCodeWithScope { - return fmt.Errorf("cannot decode %v into a CodeWithScope", vr.Type()) - } - - target, ok := i.(*CodeWithScope) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *CodeWithScope values, got %T", cwsc, i) - } - - code, dr, err := vr.ReadCodeWithScope() - if err != nil { - return err - } - - var scope *Document - err = defaultDocumentCodec.decodeDocument(dc, dr, &scope) - if err != nil { - return err - } - - *target = CodeWithScope{Code: code, Scope: scope} - return nil -} - -// TimestampCodec is the Codec for Timestamp values. -type TimestampCodec struct{} - -var _ Codec = &TimestampCodec{} - -// EncodeValue implements the Codec interface. -func (tc *TimestampCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var ts Timestamp - switch t := i.(type) { - case Timestamp: - ts = t - case *Timestamp: - ts = *t - default: - return CodecEncodeError{Codec: tc, Types: []interface{}{Timestamp{}, (*Timestamp)(nil)}, Received: i} - } - - return vw.WriteTimestamp(ts.T, ts.I) -} - -// DecodeValue implements the Codec interface. -func (tc *TimestampCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeTimestamp { - return fmt.Errorf("cannot decode %v into a Timestamp", vr.Type()) - } - - target, ok := i.(*Timestamp) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *Timestamp values, got %T", tc, i) - } - - t, incr, err := vr.ReadTimestamp() - if err != nil { - return err - } - - *target = Timestamp{T: t, I: incr} - return nil -} - -// Decimal128Codec is the Codec for decimal.Decimal128 values. -type Decimal128Codec struct{} - -var _ Codec = &Decimal128Codec{} - -// EncodeValue implements the Codec interface. -func (dc *Decimal128Codec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var d128 decimal.Decimal128 - switch t := i.(type) { - case decimal.Decimal128: - d128 = t - case *decimal.Decimal128: - d128 = *t - default: - return CodecEncodeError{Codec: dc, Types: []interface{}{decimal.Decimal128{}, (*decimal.Decimal128)(nil)}, Received: i} - } - - return vw.WriteDecimal128(d128) -} - -// DecodeValue implements the Codec interface. -func (dc *Decimal128Codec) DecodeValue(dctx DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeDecimal128 { - return fmt.Errorf("cannot decode %v into a decimal.Decimal128", vr.Type()) - } - - target, ok := i.(*decimal.Decimal128) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *decimal.Decimal128 values, got %T", dc, i) - } - - d128, err := vr.ReadDecimal128() - if err != nil { - return err - } - - *target = d128 - return nil -} - -// MinKeyCodec is the Codec for MinKey values. -type MinKeyCodec struct{} - -var _ Codec = &MinKeyCodec{} - -// EncodeValue implements the Codec interface. -func (mkc *MinKeyCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - switch i.(type) { - case MinKeyv2, *MinKeyv2: - default: - return CodecEncodeError{Codec: mkc, Types: []interface{}{MinKeyv2{}, (*MinKeyv2)(nil)}, Received: i} - } - - return vw.WriteMinKey() -} - -// DecodeValue implements the Codec interface. -func (mkc *MinKeyCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeMinKey { - return fmt.Errorf("cannot decode %v into a MinKey", vr.Type()) - } - - target, ok := i.(*MinKeyv2) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *MinKey values, got %T", mkc, i) - } - - *target = MinKeyv2{} - return vr.ReadMinKey() -} - -// MaxKeyCodec is the Codec for MaxKey values. -type MaxKeyCodec struct{} - -var _ Codec = &MaxKeyCodec{} - -// EncodeValue implements the Codec interface. -func (mkc *MaxKeyCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - switch i.(type) { - case MaxKeyv2, *MaxKeyv2: - default: - return CodecEncodeError{Codec: mkc, Types: []interface{}{MaxKeyv2{}, (*MaxKeyv2)(nil)}, Received: i} - } - - return vw.WriteMaxKey() -} - -// DecodeValue implements the Codec interface. -func (mkc *MaxKeyCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeMaxKey { - return fmt.Errorf("cannot decode %v into a MaxKey", vr.Type()) - } - - target, ok := i.(*MaxKeyv2) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *MaxKey values, got %T", mkc, i) - } - - *target = MaxKeyv2{} - return vr.ReadMaxKey() -} - -// elementCodec is the Codec for *Element values. -// -// This is a codec used internally. -type elementCodec struct{} - -func (ec *elementCodec) EncodeValue(ectx EncodeContext, vw ValueWriter, i interface{}) error { - elem, ok := i.(*Element) - if !ok { - return CodecEncodeError{Codec: ec, Types: []interface{}{(*Element)(nil)}, Received: i} - } - - if _, err := elem.Validate(); err != nil { - return err - } - - return ec.encodeValue(ectx, vw, elem) -} - -func (*elementCodec) DecodeValue(DecodeContext, ValueReader, interface{}) error { - return errors.New("elementCodec's DecodeValue method should not be used directly") -} - -func (ec *elementCodec) encodeValue(ectx EncodeContext, vw ValueWriter, elem *Element) error { - return defaultValueCodec.encodeValue(ectx, vw, elem.value) -} - -func (ec *elementCodec) decodeValue(dc DecodeContext, vr ValueReader, key string, elem **Element) error { - switch vr.Type() { - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - *elem = EC.Double(key, f64) - case TypeString: - str, err := vr.ReadString() - if err != nil { - return err - } - *elem = EC.String(key, str) - case TypeEmbeddedDocument: - codec, err := dc.Lookup(tDocument) - if err != nil { - return err - } - var embeddedDoc *Document - err = codec.DecodeValue(dc, vr, &embeddedDoc) - if err != nil { - return err - } - *elem = EC.SubDocument(key, embeddedDoc) - case TypeArray: - codec, err := dc.Lookup(tArray) - if err != nil { - return err - } - var arr *Array - err = codec.DecodeValue(dc, vr, &arr) - if err != nil { - return err - } - *elem = EC.Array(key, arr) - case TypeBinary: - data, subtype, err := vr.ReadBinary() - if err != nil { - return err - } - *elem = EC.BinaryWithSubtype(key, data, subtype) - case TypeUndefined: - err := vr.ReadUndefined() - if err != nil { - return err - } - *elem = EC.Undefined(key) - case TypeObjectID: - oid, err := vr.ReadObjectID() - if err != nil { - return err - } - *elem = EC.ObjectID(key, oid) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return err - } - *elem = EC.Boolean(key, b) - case TypeDateTime: - dt, err := vr.ReadDateTime() - if err != nil { - return err - } - *elem = EC.DateTime(key, dt) - case TypeNull: - err := vr.ReadNull() - if err != nil { - return err - } - *elem = EC.Null(key) - case TypeRegex: - pattern, options, err := vr.ReadRegex() - if err != nil { - return err - } - *elem = EC.Regex(key, pattern, options) - case TypeDBPointer: - ns, pointer, err := vr.ReadDBPointer() - if err != nil { - return err - } - *elem = EC.DBPointer(key, ns, pointer) - case TypeJavaScript: - js, err := vr.ReadJavascript() - if err != nil { - return err - } - *elem = EC.JavaScript(key, js) - case TypeSymbol: - symbol, err := vr.ReadSymbol() - if err != nil { - return err - } - *elem = EC.Symbol(key, symbol) - case TypeCodeWithScope: - code, scope, err := vr.ReadCodeWithScope() - if err != nil { - return err - } - scopeDoc := new(*Document) - err = defaultDocumentCodec.decodeDocument(dc, scope, scopeDoc) - if err != nil { - return err - } - *elem = EC.CodeWithScope(key, code, *scopeDoc) - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - *elem = EC.Int32(key, i32) - case TypeTimestamp: - t, i, err := vr.ReadTimestamp() - if err != nil { - return err - } - *elem = EC.Timestamp(key, t, i) - case TypeInt64: - i64, err := vr.ReadInt64() - if err != nil { - return err - } - *elem = EC.Int64(key, i64) - case TypeDecimal128: - d128, err := vr.ReadDecimal128() - if err != nil { - return err - } - *elem = EC.Decimal128(key, d128) - case TypeMinKey: - err := vr.ReadMinKey() - if err != nil { - return err - } - *elem = EC.MinKey(key) - case TypeMaxKey: - err := vr.ReadMaxKey() - if err != nil { - return err - } - *elem = EC.MaxKey(key) - default: - return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type()) - } - - return nil -} - -// ValueCodec is the Codec for *Value values. -type ValueCodec struct{} - -var _ Codec = &ValueCodec{} - -// EncodeValue implements the Codec interface. -func (vc *ValueCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - val, ok := i.(*Value) - if !ok { - return CodecEncodeError{Codec: vc, Types: []interface{}{(*Value)(nil)}, Received: i} - } - - if _, err := val.validate(false); err != nil { - return err - } - - return vc.encodeValue(ec, vw, val) -} - -// encodeValue does not validation, and the callers must perform validation on val before calling -// this method. -func (vc *ValueCodec) encodeValue(ec EncodeContext, vw ValueWriter, val *Value) error { - var err error - switch val.Type() { - case TypeDouble: - err = vw.WriteDouble(val.Double()) - case TypeString: - err = vw.WriteString(val.StringValue()) - case TypeEmbeddedDocument: - var codec Codec - codec, err = ec.Lookup(tDocument) - if err != nil { - break - } - err = codec.EncodeValue(ec, vw, val.MutableDocument()) - case TypeArray: - var codec Codec - codec, err = ec.Lookup(tArray) - if err != nil { - break - } - err = codec.EncodeValue(ec, vw, val.MutableArray()) - case TypeBinary: - // TODO: FIX THIS (╯°□°)╯︵ ┻━┻ - subtype, data := val.Binary() - err = vw.WriteBinaryWithSubtype(data, subtype) - case TypeUndefined: - err = vw.WriteUndefined() - case TypeObjectID: - err = vw.WriteObjectID(val.ObjectID()) - case TypeBoolean: - err = vw.WriteBoolean(val.Boolean()) - case TypeDateTime: - err = vw.WriteDateTime(val.DateTime()) - case TypeNull: - err = vw.WriteNull() - case TypeRegex: - err = vw.WriteRegex(val.Regex()) - case TypeDBPointer: - err = vw.WriteDBPointer(val.DBPointer()) - case TypeJavaScript: - err = vw.WriteJavascript(val.JavaScript()) - case TypeSymbol: - err = vw.WriteSymbol(val.Symbol()) - case TypeCodeWithScope: - code, scope := val.MutableJavaScriptWithScope() - - var cwsw DocumentWriter - cwsw, err = vw.WriteCodeWithScope(code) - if err != nil { - break - } - - err = defaultDocumentCodec.encodeDocument(ec, cwsw, scope) - case TypeInt32: - err = vw.WriteInt32(val.Int32()) - case TypeTimestamp: - err = vw.WriteTimestamp(val.Timestamp()) - case TypeInt64: - err = vw.WriteInt64(val.Int64()) - case TypeDecimal128: - err = vw.WriteDecimal128(val.Decimal128()) - case TypeMinKey: - err = vw.WriteMinKey() - case TypeMaxKey: - err = vw.WriteMaxKey() - default: - err = fmt.Errorf("%T is not a valid BSON type to encode", val.Type()) - } - - return err -} - -// DecodeValue implements the Codec interface. -func (vc *ValueCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - pval, ok := i.(**Value) - if !ok { - return CodecDecodeError{Codec: vc, Types: []interface{}{(**Value)(nil)}, Received: i} - } - - if pval == nil { - return fmt.Errorf("%T can only be used to decode non-nil **Value", vc) - } - - return vc.decodeValue(dc, vr, pval) -} - -func (vc *ValueCodec) decodeValue(dc DecodeContext, vr ValueReader, val **Value) error { - switch vr.Type() { - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - *val = VC.Double(f64) - case TypeString: - str, err := vr.ReadString() - if err != nil { - return err - } - *val = VC.String(str) - case TypeEmbeddedDocument: - codec, err := dc.Lookup(tDocument) - if err != nil { - return err - } - var embeddedDoc *Document - err = codec.DecodeValue(dc, vr, &embeddedDoc) - if err != nil { - return err - } - *val = VC.Document(embeddedDoc) - case TypeArray: - codec, err := dc.Lookup(tArray) - if err != nil { - return err - } - var arr *Array - err = codec.DecodeValue(dc, vr, &arr) - if err != nil { - return err - } - *val = VC.Array(arr) - case TypeBinary: - data, subtype, err := vr.ReadBinary() - if err != nil { - return err - } - *val = VC.BinaryWithSubtype(data, subtype) - case TypeUndefined: - err := vr.ReadUndefined() - if err != nil { - return err - } - *val = VC.Undefined() - case TypeObjectID: - oid, err := vr.ReadObjectID() - if err != nil { - return err - } - *val = VC.ObjectID(oid) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return err - } - *val = VC.Boolean(b) - case TypeDateTime: - dt, err := vr.ReadDateTime() - if err != nil { - return err - } - *val = VC.DateTime(dt) - case TypeNull: - err := vr.ReadNull() - if err != nil { - return err - } - *val = VC.Null() - case TypeRegex: - pattern, options, err := vr.ReadRegex() - if err != nil { - return err - } - *val = VC.Regex(pattern, options) - case TypeDBPointer: - ns, pointer, err := vr.ReadDBPointer() - if err != nil { - return err - } - *val = VC.DBPointer(ns, pointer) - case TypeJavaScript: - js, err := vr.ReadJavascript() - if err != nil { - return err - } - *val = VC.JavaScript(js) - case TypeSymbol: - symbol, err := vr.ReadSymbol() - if err != nil { - return err - } - *val = VC.Symbol(symbol) - case TypeCodeWithScope: - code, scope, err := vr.ReadCodeWithScope() - if err != nil { - return err - } - scopeDoc := new(*Document) - err = defaultDocumentCodec.decodeDocument(dc, scope, scopeDoc) - if err != nil { - return err - } - *val = VC.CodeWithScope(code, *scopeDoc) - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - *val = VC.Int32(i32) - case TypeTimestamp: - t, i, err := vr.ReadTimestamp() - if err != nil { - return err - } - *val = VC.Timestamp(t, i) - case TypeInt64: - i64, err := vr.ReadInt64() - if err != nil { - return err - } - *val = VC.Int64(i64) - case TypeDecimal128: - d128, err := vr.ReadDecimal128() - if err != nil { - return err - } - *val = VC.Decimal128(d128) - case TypeMinKey: - err := vr.ReadMinKey() - if err != nil { - return err - } - *val = VC.MinKey() - case TypeMaxKey: - err := vr.ReadMaxKey() - if err != nil { - return err - } - *val = VC.MaxKey() - default: - return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type()) - } - - return nil -} - -// JSONNumberCodec is the Codec for json.Number values. -type JSONNumberCodec struct{} - -var _ Codec = &JSONNumberCodec{} - -// EncodeValue implements the Codec interface. -func (jnc *JSONNumberCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var jsnum json.Number - switch t := i.(type) { - case json.Number: - jsnum = t - case *json.Number: - jsnum = *t - default: - return CodecEncodeError{Codec: jnc, Types: []interface{}{json.Number(""), (*json.Number)(nil)}, Received: i} - } - - // Attempt int first, then float64 - if i64, err := jsnum.Int64(); err == nil { - return defaultIntCodec.EncodeValue(ec, vw, i64) - } - - f64, err := jsnum.Float64() - if err != nil { - return err - } - - return defaultFloatCodec.EncodeValue(ec, vw, f64) -} - -// DecodeValue implements the Codec interface. -func (jnc *JSONNumberCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - target, ok := i.(*json.Number) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *json.Number values, got %T", jnc, i) - } - - switch vr.Type() { - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - *target = json.Number(strconv.FormatFloat(f64, 'g', -1, 64)) - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - *target = json.Number(strconv.FormatInt(int64(i32), 10)) - case TypeInt64: - i64, err := vr.ReadInt64() - if err != nil { - return err - } - *target = json.Number(strconv.FormatInt(i64, 10)) - default: - return fmt.Errorf("cannot decode %v into a json.Number", vr.Type()) - } - - return nil -} - -// URLCodec is the Codec for url.URL values. -type URLCodec struct{} - -var _ Codec = &URLCodec{} - -// EncodeValue implements the Codec interface. -func (uc *URLCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var u *url.URL - switch t := i.(type) { - case url.URL: - u = &t - case *url.URL: - u = t - default: - return CodecEncodeError{Codec: uc, Types: []interface{}{url.URL{}, (*url.URL)(nil)}, Received: i} - } - - return vw.WriteString(u.String()) -} - -// DecodeValue implements the Codec interface. -func (uc *URLCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeString { - return fmt.Errorf("cannot decode %v into a *url.URL", vr.Type()) - } - - str, err := vr.ReadString() - if err != nil { - return err - } - - u, err := url.Parse(str) - if err != nil { - return err - } - - // It's valid to use either a *url.URL or a url.URL - switch target := i.(type) { - case *url.URL: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *url.URL values, got %T", uc, i) - } - *target = *u - case **url.URL: - if target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *url.URL values, got %T", uc, i) - } - *target = u - default: - return fmt.Errorf("%T can only be used to decode non-nil *url.URL values, got %T", uc, i) - } - return nil -} - -// TimeCodec is the Codec for time.Time values. -type TimeCodec struct{} - -var _ Codec = &TimeCodec{} - -// EncodeValue implements the Codec interface. -func (tc *TimeCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var tt time.Time - switch t := i.(type) { - case time.Time: - tt = t - case *time.Time: - tt = *t - default: - return CodecEncodeError{Codec: tc, Types: []interface{}{time.Time{}, (*time.Time)(nil)}, Received: i} - } - - return vw.WriteDateTime(tt.UnixNano() / int64(time.Millisecond)) -} - -// DecodeValue implements the Codec interface. -func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeDateTime { - return fmt.Errorf("cannot decode %v into a time.Time", vr.Type()) - } - - dt, err := vr.ReadDateTime() - if err != nil { - return err - } - - if target, ok := i.(*time.Time); ok && target != nil { - *target = time.Unix(dt/1000, dt%1000*1000000) - return nil - } - - if target, ok := i.(**time.Time); ok && target != nil { - tt := *target - if tt == nil { - tt = new(time.Time) - } - *tt = time.Unix(dt/1000, dt%1000*1000000) - *target = tt - return nil - } - - return fmt.Errorf("%T can only be used to decode non-nil *time.Time values, got %T", tc, i) -} - -// ReaderCodec is the Codec for Reader values. -type ReaderCodec struct{} - -var _ Codec = &ReaderCodec{} - -// EncodeValue implements the Codec interface. -func (rc *ReaderCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - rdr, ok := i.(Reader) - if !ok { - return CodecEncodeError{Codec: rc, Types: []interface{}{Reader{}}, Received: i} - } - - // TODO: Handle fast path, we should just copy the bytes into the - // *valueWriter and then do the cleanup of the state. - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - itr, err := rdr.Iterator() - if err != nil { - return err - } - - for itr.Next() { - elem := itr.Element() - dvw, err := dw.WriteDocumentElement(elem.Key()) - if err != nil { - return err - } - - val := elem.Value() - err = defaultValueCodec.encodeValue(ec, dvw, val) - - if err != nil { - return err - } - } - - if err := itr.Err(); err != nil { - return err - } - - return dw.WriteDocumentEnd() -} - -// DecodeValue implements the Codec interface. -func (rc *ReaderCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - rdr, ok := i.(*Reader) - if !ok { - return CodecDecodeError{Codec: rc, Types: []interface{}{(*Reader)(nil)}, Received: i} - } - - if rdr == nil { - return fmt.Errorf("%T can only be used to decode non-nil *Reader", rc) - } - - if *rdr == nil { - *rdr = make(Reader, 256) - } - - // TODO: handle the fast path, if we have a *valueReader, just copy the - // bytes. - vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) - - vw.reset((*rdr)[:0]) - - err := CopyDocument(vw, vr) - if err != nil { - return err - } - - *rdr = vw.buf - return nil -} - -// ByteSliceCodec is the Codec for []byte values. -type ByteSliceCodec struct{} - -var _ Codec = &ByteSliceCodec{} - -// EncodeValue implements the Codec interface. -func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var slcb []byte - switch t := i.(type) { - case []byte: - slcb = t - case *[]byte: - slcb = *t - default: - return CodecEncodeError{Codec: bsc, Types: []interface{}{[]byte{}, (*[]byte)(nil)}, Received: i} - } - - return vw.WriteBinary(slcb) -} - -// DecodeValue implements the Codec interface. -func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - if vr.Type() != TypeBinary { - return fmt.Errorf("cannot decode %v into a *[]byte", vr.Type()) - } - - target, ok := i.(*[]byte) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *[]byte values, got %T", bsc, i) - } - - data, subtype, err := vr.ReadBinary() - if err != nil { - return err - } - if subtype != 0x00 { - return fmt.Errorf("%T can only be used to decode subtype 0x00 for %s, got %v", bsc, TypeBinary, subtype) - } - - *target = data - return nil -} - -// ElementSliceCodec is the Codec for []*Element values. -type ElementSliceCodec struct{} - -var _ Codec = &ElementSliceCodec{} - -// EncodeValue implements the Codec interface. -func (esc *ElementSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - var slce []*Element - switch t := i.(type) { - case []*Element: - slce = t - case *[]*Element: - slce = *t - default: - return CodecEncodeError{Codec: esc, Types: []interface{}{[]*Element{}, (*[]*Element)(nil)}, Received: i} - } - - return defaultDocumentCodec.EncodeValue(ec, vw, &Document{elems: slce}) -} - -// DecodeValue implements the Codec interface. -func (esc *ElementSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - var doc *Document - err := defaultDocumentCodec.DecodeValue(dc, vr, &doc) - if err != nil { - return err - } - - target, ok := i.(*[]*Element) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *[]*Element values, got %T", esc, i) - } - - *target = doc.elems - return nil -} diff --git a/bson/codec_test.go b/bson/codec_test.go deleted file mode 100644 index dca38ecb90..0000000000 --- a/bson/codec_test.go +++ /dev/null @@ -1,3402 +0,0 @@ -package bson - -import ( - "encoding/json" - "errors" - "fmt" - "math" - "net/url" - "reflect" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/mongodb/mongo-go-driver/bson/decimal" - "github.com/mongodb/mongo-go-driver/bson/objectid" -) - -func TestProvidedCodecs(t *testing.T) { - var wrong = func(string, string) string { return "wrong" } - type mybool bool - type myint8 int8 - type myint16 int16 - type myint32 int32 - type myint64 int64 - type myint int - type myuint8 uint8 - type myuint16 uint16 - type myuint32 uint32 - type myuint64 uint64 - type myuint uint - type myfloat32 float32 - type myfloat64 float64 - type mystring string - - const cansetreflectiontest = "cansetreflectiontest" - - intAllowedTypes := []interface{}{int8(0), int16(0), int32(0), int64(0), int(0)} - intAllowedDecodeTypes := []interface{}{(*int8)(nil), (*int16)(nil), (*int32)(nil), (*int64)(nil), (*int)(nil)} - uintAllowedEncodeTypes := []interface{}{uint8(0), uint16(0), uint32(0), uint64(0), uint(0)} - uintAllowedDecodeTypes := []interface{}{(*uint8)(nil), (*uint16)(nil), (*uint32)(nil), (*uint64)(nil), (*uint)(nil)} - now := time.Now().Truncate(time.Millisecond) - pdatetime := new(DateTime) - *pdatetime = DateTime(1234567890) - pjsnum := new(json.Number) - *pjsnum = json.Number("3.14159") - d128 := decimal.NewDecimal128(12345, 67890) - - type enccase struct { - name string - val interface{} - ectx *EncodeContext - llvrw *llValueReaderWriter - invoke llvrwInvoked - err error - } - type deccase struct { - name string - val interface{} - dctx *DecodeContext - llvrw *llValueReaderWriter - invoke llvrwInvoked - err error - } - testCases := []struct { - name string - codec Codec - encodeCases []enccase - decodeCases []deccase - }{ - { - "BooleanCodec", - &BooleanCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &BooleanCodec{}, Types: []interface{}{bool(true)}, Received: wrong}, - }, - {"fast path", bool(true), nil, nil, llvrwWriteBoolean, nil}, - {"reflection path", mybool(true), nil, nil, llvrwWriteBoolean, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeBoolean}, - llvrwNothing, - CodecDecodeError{Codec: &BooleanCodec{}, Types: []interface{}{bool(true)}, Received: &wrong}, - }, - { - "type not boolean", - bool(false), - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a boolean", TypeString), - }, - { - "fast path", - bool(true), - nil, - &llValueReaderWriter{bsontype: TypeBoolean, readval: bool(true)}, - llvrwReadBoolean, - nil, - }, - { - "reflection path", - mybool(true), - nil, - &llValueReaderWriter{bsontype: TypeBoolean, readval: bool(true)}, - llvrwReadBoolean, - nil, - }, - { - "reflection path error", - mybool(true), - nil, - &llValueReaderWriter{bsontype: TypeBoolean, readval: bool(true), err: errors.New("ReadBoolean Error")}, - llvrwReadBoolean, errors.New("ReadBoolean Error"), - }, - { - "can set false", - cansetreflectiontest, - nil, - &llValueReaderWriter{bsontype: TypeBoolean}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode settable (non-nil) values", &BooleanCodec{}), - }, - }, - }, - { - "IntCodec", - &IntCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &IntCodec{}, Types: intAllowedTypes, Received: wrong}, - }, - {"int8/fast path", int8(127), nil, nil, llvrwWriteInt32, nil}, - {"int16/fast path", int16(32767), nil, nil, llvrwWriteInt32, nil}, - {"int32/fast path", int32(2147483647), nil, nil, llvrwWriteInt32, nil}, - {"int64/fast path", int64(1234567890987), nil, nil, llvrwWriteInt64, nil}, - {"int/fast path", int(1234567), nil, nil, llvrwWriteInt64, nil}, - {"int64/fast path - minsize", int64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"int/fast path - minsize", int(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"int64/fast path - minsize too large", int64(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"int/fast path - minsize too large", int(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"int8/reflection path", myint8(127), nil, nil, llvrwWriteInt32, nil}, - {"int16/reflection path", myint16(32767), nil, nil, llvrwWriteInt32, nil}, - {"int32/reflection path", myint32(2147483647), nil, nil, llvrwWriteInt32, nil}, - {"int64/reflection path", myint64(1234567890987), nil, nil, llvrwWriteInt64, nil}, - {"int/reflection path", myint(1234567890987), nil, nil, llvrwWriteInt64, nil}, - {"int64/reflection path - minsize", myint64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"int/reflection path - minsize", myint(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"int64/reflection path - minsize too large", myint64(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"int/reflection path - minsize too large", myint(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, - llvrwReadInt32, - CodecDecodeError{Codec: &IntCodec{}, Types: intAllowedDecodeTypes, Received: &wrong}, - }, - { - "type not int32/int64", - 0, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into an integer type", TypeString), - }, - { - "ReadInt32 error", - 0, - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0), err: errors.New("ReadInt32 error"), errAfter: llvrwReadInt32}, - llvrwReadInt32, - errors.New("ReadInt32 error"), - }, - { - "ReadInt64 error", - 0, - nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(0), err: errors.New("ReadInt64 error"), errAfter: llvrwReadInt64}, - llvrwReadInt64, - errors.New("ReadInt64 error"), - }, - { - "ReadDouble error", - 0, - nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(0), err: errors.New("ReadDouble error"), errAfter: llvrwReadDouble}, - llvrwReadDouble, - errors.New("ReadDouble error"), - }, - { - "ReadDouble", int64(3), &DecodeContext{}, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.00)}, llvrwReadDouble, - nil, - }, - { - "ReadDouble (truncate)", int64(3), &DecodeContext{Truncate: true}, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - nil, - }, - { - "ReadDouble (no truncate)", int64(0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - fmt.Errorf("%T can only truncate float64 to an integer type when truncation is enabled", &IntCodec{}), - }, - { - "ReadDouble overflows int64", int64(0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: math.MaxFloat64}, llvrwReadDouble, - fmt.Errorf("%g overflows int64", math.MaxFloat64), - }, - {"int8/fast path", int8(127), nil, &llValueReaderWriter{bsontype: TypeInt32, readval: int32(127)}, llvrwReadInt32, nil}, - {"int16/fast path", int16(32676), nil, &llValueReaderWriter{bsontype: TypeInt32, readval: int32(32676)}, llvrwReadInt32, nil}, - {"int32/fast path", int32(1234), nil, &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1234)}, llvrwReadInt32, nil}, - {"int64/fast path", int64(1234), nil, &llValueReaderWriter{bsontype: TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, - {"int/fast path", int(1234), nil, &llValueReaderWriter{bsontype: TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, - { - "int8/fast path - nil", (*int8)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *int8", &IntCodec{}), - }, - { - "int16/fast path - nil", (*int16)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *int16", &IntCodec{}), - }, - { - "int32/fast path - nil", (*int32)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *int32", &IntCodec{}), - }, - { - "int64/fast path - nil", (*int64)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *int64", &IntCodec{}), - }, - { - "int/fast path - nil", (*int)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *int", &IntCodec{}), - }, - { - "int8/fast path - overflow", int8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(129)}, llvrwReadInt32, - fmt.Errorf("%d overflows int8", 129), - }, - { - "int16/fast path - overflow", int16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(32768)}, llvrwReadInt32, - fmt.Errorf("%d overflows int16", 32768), - }, - { - "int32/fast path - overflow", int32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(2147483648)}, llvrwReadInt64, - fmt.Errorf("%d overflows int32", 2147483648), - }, - { - "int8/fast path - overflow (negative)", int8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-129)}, llvrwReadInt32, - fmt.Errorf("%d overflows int8", -129), - }, - { - "int16/fast path - overflow (negative)", int16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-32769)}, llvrwReadInt32, - fmt.Errorf("%d overflows int16", -32769), - }, - { - "int32/fast path - overflow (negative)", int32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-2147483649)}, llvrwReadInt64, - fmt.Errorf("%d overflows int32", -2147483649), - }, - { - "int8/reflection path", myint8(127), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(127)}, llvrwReadInt32, - nil, - }, - { - "int16/reflection path", myint16(255), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(255)}, llvrwReadInt32, - nil, - }, - { - "int32/reflection path", myint32(511), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(511)}, llvrwReadInt32, - nil, - }, - { - "int64/reflection path", myint64(1023), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1023)}, llvrwReadInt32, - nil, - }, - { - "int/reflection path", myint(2047), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(2047)}, llvrwReadInt32, - nil, - }, - { - "int8/reflection path - overflow", myint8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(129)}, llvrwReadInt32, - fmt.Errorf("%d overflows int8", 129), - }, - { - "int16/reflection path - overflow", myint16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(32768)}, llvrwReadInt32, - fmt.Errorf("%d overflows int16", 32768), - }, - { - "int32/reflection path - overflow", myint32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(2147483648)}, llvrwReadInt64, - fmt.Errorf("%d overflows int32", 2147483648), - }, - { - "int8/reflection path - overflow (negative)", myint8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-129)}, llvrwReadInt32, - fmt.Errorf("%d overflows int8", -129), - }, - { - "int16/reflection path - overflow (negative)", myint16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-32769)}, llvrwReadInt32, - fmt.Errorf("%d overflows int16", -32769), - }, - { - "int32/reflection path - overflow (negative)", myint32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-2147483649)}, llvrwReadInt64, - fmt.Errorf("%d overflows int32", -2147483649), - }, - { - "can set false", - cansetreflectiontest, - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode settable (non-nil) values", &IntCodec{}), - }, - }, - }, - { - "UintCodec", - &UintCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &UintCodec{}, Types: uintAllowedEncodeTypes, Received: wrong}, - }, - {"uint8/fast path", uint8(127), nil, nil, llvrwWriteInt32, nil}, - {"uint16/fast path", uint16(32767), nil, nil, llvrwWriteInt32, nil}, - {"uint32/fast path", uint32(2147483647), nil, nil, llvrwWriteInt64, nil}, - {"uint64/fast path", uint64(1234567890987), nil, nil, llvrwWriteInt64, nil}, - {"uint/fast path", uint(1234567), nil, nil, llvrwWriteInt64, nil}, - {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"uint/fast path - minsize", uint(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"uint64/fast path - overflow", uint64(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, - {"uint/fast path - overflow", uint(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, - {"uint8/reflection path", myuint8(127), nil, nil, llvrwWriteInt32, nil}, - {"uint16/reflection path", myuint16(32767), nil, nil, llvrwWriteInt32, nil}, - {"uint32/reflection path", myuint32(2147483647), nil, nil, llvrwWriteInt64, nil}, - {"uint64/reflection path", myuint64(1234567890987), nil, nil, llvrwWriteInt64, nil}, - {"uint/reflection path", myuint(1234567890987), nil, nil, llvrwWriteInt64, nil}, - {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{MinSize: true}, nil, llvrwWriteInt32, nil}, - {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{MinSize: true}, nil, llvrwWriteInt64, nil}, - {"uint64/reflection path - overflow", myuint64(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, - {"uint/reflection path - overflow", myuint(1 << 63), nil, nil, llvrwNothing, fmt.Errorf("%d overflows int64", uint(1<<63))}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, - llvrwReadInt32, - CodecDecodeError{Codec: &UintCodec{}, Types: uintAllowedDecodeTypes, Received: &wrong}, - }, - { - "type not int32/int64", - 0, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into an integer type", TypeString), - }, - { - "ReadInt32 error", - uint(0), - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0), err: errors.New("ReadInt32 error"), errAfter: llvrwReadInt32}, - llvrwReadInt32, - errors.New("ReadInt32 error"), - }, - { - "ReadInt64 error", - uint(0), - nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(0), err: errors.New("ReadInt64 error"), errAfter: llvrwReadInt64}, - llvrwReadInt64, - errors.New("ReadInt64 error"), - }, - { - "ReadDouble error", - 0, - nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(0), err: errors.New("ReadDouble error"), errAfter: llvrwReadDouble}, - llvrwReadDouble, - errors.New("ReadDouble error"), - }, - { - "ReadDouble", uint64(3), &DecodeContext{}, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.00)}, llvrwReadDouble, - nil, - }, - { - "ReadDouble (truncate)", uint64(3), &DecodeContext{Truncate: true}, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - nil, - }, - { - "ReadDouble (no truncate)", uint64(0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - fmt.Errorf("%T can only truncate float64 to an integer type when truncation is enabled", &UintCodec{}), - }, - { - "ReadDouble overflows int64", uint64(0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: math.MaxFloat64}, llvrwReadDouble, - fmt.Errorf("%g overflows int64", math.MaxFloat64), - }, - {"uint8/fast path", uint8(127), nil, &llValueReaderWriter{bsontype: TypeInt32, readval: int32(127)}, llvrwReadInt32, nil}, - {"uint16/fast path", uint16(255), nil, &llValueReaderWriter{bsontype: TypeInt32, readval: int32(255)}, llvrwReadInt32, nil}, - {"uint32/fast path", uint32(1234), nil, &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1234)}, llvrwReadInt32, nil}, - {"uint64/fast path", uint64(1234), nil, &llValueReaderWriter{bsontype: TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, - {"uint/fast path", uint(1234), nil, &llValueReaderWriter{bsontype: TypeInt64, readval: int64(1234)}, llvrwReadInt64, nil}, - { - "uint8/fast path - nil", (*uint8)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *uint8", &UintCodec{}), - }, - { - "uint16/fast path - nil", (*uint16)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *uint16", &UintCodec{}), - }, - { - "uint32/fast path - nil", (*uint32)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *uint32", &UintCodec{}), - }, - { - "uint64/fast path - nil", (*uint64)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *uint64", &UintCodec{}), - }, - { - "uint/fast path - nil", (*uint)(nil), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, llvrwReadInt32, - fmt.Errorf("%T can only be used to decode non-nil *uint", &UintCodec{}), - }, - { - "uint8/fast path - overflow", uint8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1 << 8)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint8", 1<<8), - }, - { - "uint16/fast path - overflow", uint16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1 << 16)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint16", 1<<16), - }, - { - "uint32/fast path - overflow", uint32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(1 << 32)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint32", 1<<32), - }, - { - "uint8/fast path - overflow (negative)", uint8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-1)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint8", -1), - }, - { - "uint16/fast path - overflow (negative)", uint16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-1)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint16", -1), - }, - { - "uint32/fast path - overflow (negative)", uint32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-1)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint32", -1), - }, - { - "uint64/fast path - overflow (negative)", uint64(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-1)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint64", -1), - }, - { - "uint/fast path - overflow (negative)", uint(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-1)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint", -1), - }, - { - "uint8/reflection path", myuint8(127), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(127)}, llvrwReadInt32, - nil, - }, - { - "uint16/reflection path", myuint16(255), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(255)}, llvrwReadInt32, - nil, - }, - { - "uint32/reflection path", myuint32(511), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(511)}, llvrwReadInt32, - nil, - }, - { - "uint64/reflection path", myuint64(1023), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1023)}, llvrwReadInt32, - nil, - }, - { - "uint/reflection path", myuint(2047), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(2047)}, llvrwReadInt32, - nil, - }, - { - "uint8/reflection path - overflow", myuint8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1 << 8)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint8", 1<<8), - }, - { - "uint16/reflection path - overflow", myuint16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(1 << 16)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint16", 1<<16), - }, - { - "uint32/reflection path - overflow", myuint32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(1 << 32)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint32", 1<<32), - }, - { - "uint8/reflection path - overflow (negative)", myuint8(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-1)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint8", -1), - }, - { - "uint16/reflection path - overflow (negative)", myuint16(0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(-1)}, llvrwReadInt32, - fmt.Errorf("%d overflows uint16", -1), - }, - { - "uint32/reflection path - overflow (negative)", myuint32(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-1)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint32", -1), - }, - { - "uint64/reflection path - overflow (negative)", myuint64(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-1)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint64", -1), - }, - { - "uint/reflection path - overflow (negative)", myuint(0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(-1)}, llvrwReadInt64, - fmt.Errorf("%d overflows uint", -1), - }, - { - "can set false", - cansetreflectiontest, - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode settable (non-nil) values", &UintCodec{}), - }, - }, - }, - { - "FloatCodec", - &FloatCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &FloatCodec{}, Types: []interface{}{float32(0), float64(0)}, Received: wrong}, - }, - {"float32/fast path", float32(3.14159), nil, nil, llvrwWriteDouble, nil}, - {"float64/fast path", float64(3.14159), nil, nil, llvrwWriteDouble, nil}, - {"float32/reflection path", myfloat32(3.14159), nil, nil, llvrwWriteDouble, nil}, - {"float64/reflection path", myfloat64(3.14159), nil, nil, llvrwWriteDouble, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(0)}, - llvrwReadDouble, - CodecDecodeError{Codec: &FloatCodec{}, Types: []interface{}{(*float32)(nil), (*float64)(nil)}, Received: &wrong}, - }, - { - "type not double", - 0, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a float32 or float64 type", TypeString), - }, - { - "ReadDouble error", - float64(0), - nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(0), err: errors.New("ReadDouble error"), errAfter: llvrwReadDouble}, - llvrwReadDouble, - errors.New("ReadDouble error"), - }, - { - "ReadInt32 error", - float64(0), - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0), err: errors.New("ReadInt32 error"), errAfter: llvrwReadInt32}, - llvrwReadInt32, - errors.New("ReadInt32 error"), - }, - { - "ReadInt64 error", - float64(0), - nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(0), err: errors.New("ReadInt64 error"), errAfter: llvrwReadInt64}, - llvrwReadInt64, - errors.New("ReadInt64 error"), - }, - { - "float64/int32", float32(32.0), nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(32)}, llvrwReadInt32, - nil, - }, - { - "float64/int64", float32(64.0), nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(64)}, llvrwReadInt64, - nil, - }, - { - "float32/fast path (equal)", float32(3.0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.0)}, llvrwReadDouble, - nil, - }, - { - "float64/fast path", float64(3.14159), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14159)}, llvrwReadDouble, - nil, - }, - { - "float32/fast path (truncate)", float32(3.14), &DecodeContext{Truncate: true}, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - nil, - }, - { - "float32/fast path (no truncate)", float32(0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - fmt.Errorf("%T can only convert float64 to float32 when truncation is allowed", &FloatCodec{}), - }, - { - "float32/fast path - nil", (*float32)(nil), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(0)}, llvrwReadDouble, - fmt.Errorf("%T can only be used to decode non-nil *float32", &FloatCodec{}), - }, - { - "float64/fast path - nil", (*float64)(nil), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(0)}, llvrwReadDouble, - fmt.Errorf("%T can only be used to decode non-nil *float64", &FloatCodec{}), - }, - { - "float32/reflection path (equal)", myfloat32(3.0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.0)}, llvrwReadDouble, - nil, - }, - { - "float64/reflection path", myfloat64(3.14159), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14159)}, llvrwReadDouble, - nil, - }, - { - "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{Truncate: true}, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - nil, - }, - { - "float32/reflection path (no truncate)", myfloat32(0), nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14)}, llvrwReadDouble, - fmt.Errorf("%T can only convert float64 to float32 when truncation is allowed", &FloatCodec{}), - }, - { - "can set false", - cansetreflectiontest, - nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(0)}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode settable (non-nil) values", &FloatCodec{}), - }, - }, - }, - { - "StringCodec", - &StringCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &StringCodec{}, Types: []interface{}{string(""), JavaScriptCode(""), Symbol("")}, Received: wrong}, - }, - {"string/fast path", string("foobar"), nil, nil, llvrwWriteString, nil}, - {"JavaScript/fast path", JavaScriptCode("foobar"), nil, nil, llvrwWriteJavascript, nil}, - {"Symbol/fast path", Symbol("foobar"), nil, nil, llvrwWriteSymbol, nil}, - {"reflection path", mystring("foobarbaz"), nil, nil, llvrwWriteString, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("")}, - llvrwReadString, - CodecDecodeError{Codec: &StringCodec{}, Types: []interface{}{(*string)(nil), (*JavaScriptCode)(nil), (*Symbol)(nil)}, Received: &wrong}, - }, - { - "type not string", - string(""), - nil, - &llValueReaderWriter{bsontype: TypeBoolean}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a string type", TypeBoolean), - }, - { - "ReadString error", - string(""), - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string(""), err: errors.New("ReadString error"), errAfter: llvrwReadString}, - llvrwReadString, - errors.New("ReadString error"), - }, - { - "ReadJavaScript error", - string(""), - nil, - &llValueReaderWriter{bsontype: TypeJavaScript, readval: string(""), err: errors.New("ReadJS error"), errAfter: llvrwReadJavascript}, - llvrwReadJavascript, - errors.New("ReadJS error"), - }, - { - "ReadSymbol error", - string(""), - nil, - &llValueReaderWriter{bsontype: TypeSymbol, readval: string(""), err: errors.New("ReadSymbol error"), errAfter: llvrwReadSymbol}, - llvrwReadSymbol, - errors.New("ReadSymbol error"), - }, - { - "string/fast path", - string("foobar"), - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("foobar")}, - llvrwReadString, - nil, - }, - { - "JavaScript/fast path", - JavaScriptCode("var hello = 'world';"), - nil, - &llValueReaderWriter{bsontype: TypeJavaScript, readval: string("var hello = 'world';")}, - llvrwReadJavascript, - nil, - }, - { - "Symbol/fast path", - Symbol("foobarbaz"), - nil, - &llValueReaderWriter{bsontype: TypeSymbol, readval: Symbol("foobarbaz")}, - llvrwReadSymbol, - nil, - }, - { - "string/fast path - nil", (*string)(nil), nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("")}, llvrwReadString, - fmt.Errorf("%T can only be used to decode non-nil *string", &StringCodec{}), - }, - { - "JavaScript/fast path - nil", (*JavaScriptCode)(nil), nil, - &llValueReaderWriter{bsontype: TypeJavaScript, readval: string("")}, llvrwReadJavascript, - fmt.Errorf("%T can only be used to decode non-nil *JavaScriptCode", &StringCodec{}), - }, - { - "Symbol/fast path - nil", (*Symbol)(nil), nil, - &llValueReaderWriter{bsontype: TypeSymbol, readval: Symbol("")}, llvrwReadSymbol, - fmt.Errorf("%T can only be used to decode non-nil *Symbol", &StringCodec{}), - }, - { - "reflection path", - mystring("foobarbaz"), - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("foobarbaz")}, - llvrwReadString, - nil, - }, - { - "reflection path error", - mystring("foobarbazqux"), - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("foobarbazqux"), err: errors.New("ReadString Error"), errAfter: llvrwReadString}, - llvrwReadString, errors.New("ReadString Error"), - }, - { - "can set false", - cansetreflectiontest, - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("")}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode settable (non-nil) values", &StringCodec{}), - }, - }, - }, - { - "TimeCodec", - &TimeCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &TimeCodec{}, Types: []interface{}{time.Time{}, (*time.Time)(nil)}, Received: wrong}, - }, - {"time.Time", now, nil, nil, llvrwWriteDateTime, nil}, - {"*time.Time", &now, nil, nil, llvrwWriteDateTime, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(0)}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a time.Time", TypeInt32), - }, - { - "type not *time.Time", - int64(0), - nil, - &llValueReaderWriter{bsontype: TypeDateTime, readval: int64(1234567890)}, - llvrwReadDateTime, - fmt.Errorf("%T can only be used to decode non-nil *time.Time values, got %T", &TimeCodec{}, (*int64)(nil)), - }, - { - "ReadDateTime error", - time.Time{}, - nil, - &llValueReaderWriter{bsontype: TypeDateTime, readval: int64(0), err: errors.New("ReadDateTime error"), errAfter: llvrwReadDateTime}, - llvrwReadDateTime, - errors.New("ReadDateTime error"), - }, - { - "time.Time", - now, - nil, - &llValueReaderWriter{bsontype: TypeDateTime, readval: int64(now.UnixNano() / int64(time.Millisecond))}, - llvrwReadDateTime, - nil, - }, - { - "*time.Time", - &now, - nil, - &llValueReaderWriter{bsontype: TypeDateTime, readval: int64(now.UnixNano() / int64(time.Millisecond))}, - llvrwReadDateTime, - nil, - }, - }, - }, - { - "MapCodec", - &MapCodec{}, - []enccase{ - { - "wrong kind", - wrong, - nil, - nil, - llvrwNothing, - fmt.Errorf("%T can only encode maps with string keys", &MapCodec{}), - }, - { - "wrong kind (non-string key)", - map[int]interface{}{}, - nil, - nil, - llvrwNothing, - fmt.Errorf("%T can only encode maps with string keys", &MapCodec{}), - }, - { - "WriteDocument Error", - map[string]interface{}{}, - nil, - &llValueReaderWriter{err: errors.New("wd error"), errAfter: llvrwWriteDocument}, - llvrwWriteDocument, - errors.New("wd error"), - }, - { - "Lookup Error", - map[string]interface{}{}, - &EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{}, - llvrwWriteDocument, - ErrNoCodec{Type: reflect.TypeOf((*interface{})(nil)).Elem()}, - }, - { - "WriteDocumentElement Error", - map[string]interface{}{"foo": "bar"}, - &EncodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{err: errors.New("wde error"), errAfter: llvrwWriteDocumentElement}, - llvrwWriteDocumentElement, - errors.New("wde error"), - }, - { - "EncodeValue Error", - map[string]interface{}{"foo": "bar"}, - &EncodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{err: errors.New("ev error"), errAfter: llvrwWriteString}, - llvrwWriteString, - errors.New("ev error"), - }, - }, - []deccase{ - { - "wrong kind", - wrong, - nil, - &llValueReaderWriter{}, - llvrwNothing, - fmt.Errorf("%T can only decode settable maps with string keys", &MapCodec{}), - }, - { - "wrong kind (non-string key)", - map[int]interface{}{}, - nil, - &llValueReaderWriter{}, - llvrwNothing, - fmt.Errorf("%T can only decode settable maps with string keys", &MapCodec{}), - }, - { - "ReadDocument Error", - make(map[string]interface{}), - nil, - &llValueReaderWriter{err: errors.New("rd error"), errAfter: llvrwReadDocument}, - llvrwReadDocument, - errors.New("rd error"), - }, - { - "Lookup Error", - map[string]string{}, - &DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{}, - llvrwReadDocument, - ErrNoCodec{Type: reflect.TypeOf(string(""))}, - }, - { - "ReadElement Error", - make(map[string]interface{}), - &DecodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{err: errors.New("re error"), errAfter: llvrwReadElement}, - llvrwReadElement, - errors.New("re error"), - }, - { - "DecodeValue Error", - map[string]string{"foo": "bar"}, - &DecodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{bsontype: TypeString, err: errors.New("dv error"), errAfter: llvrwReadString}, - llvrwReadString, - errors.New("dv error"), - }, - }, - }, - { - "SliceCodec", - &SliceCodec{}, - []enccase{ - { - "wrong kind", - wrong, - nil, - nil, - llvrwNothing, - fmt.Errorf("%T can only encode arrays and slices", &SliceCodec{}), - }, - { - "WriteArray Error", - []string{}, - nil, - &llValueReaderWriter{err: errors.New("wa error"), errAfter: llvrwWriteArray}, - llvrwWriteArray, - errors.New("wa error"), - }, - { - "Lookup Error", - []interface{}{}, - &EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{}, - llvrwWriteArray, - ErrNoCodec{Type: reflect.TypeOf((*interface{})(nil)).Elem()}, - }, - { - "WriteArrayElement Error", - []string{"foo"}, - &EncodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{err: errors.New("wae error"), errAfter: llvrwWriteArrayElement}, - llvrwWriteArrayElement, - errors.New("wae error"), - }, - { - "EncodeValue Error", - []string{"foo"}, - &EncodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{err: errors.New("ev error"), errAfter: llvrwWriteString}, - llvrwWriteString, - errors.New("ev error"), - }, - }, - []deccase{ - { - "wrong kind", - wrong, - nil, - &llValueReaderWriter{}, - llvrwNothing, - fmt.Errorf("%T can only decode settable slice and array values, got %T", &SliceCodec{}, &wrong), - }, - { - "can set false", - (*[]string)(nil), - nil, - &llValueReaderWriter{}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil pointers to slice or array values, got %T", &SliceCodec{}, (*[]string)(nil)), - }, - { - "Not Type Array", - []interface{}{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - errors.New("cannot decode string into a slice"), - }, - { - "ReadArray Error", - []interface{}{}, - nil, - &llValueReaderWriter{err: errors.New("ra error"), errAfter: llvrwReadArray, bsontype: TypeArray}, - llvrwReadArray, - errors.New("ra error"), - }, - { - "Lookup Error", - []string{}, - &DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{bsontype: TypeArray}, - llvrwReadArray, - ErrNoCodec{Type: reflect.TypeOf(string(""))}, - }, - { - "ReadValue Error", - []string{}, - &DecodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{err: errors.New("rv error"), errAfter: llvrwReadValue, bsontype: TypeArray}, - llvrwReadValue, - errors.New("rv error"), - }, - { - "DecodeValue Error", - []string{}, - &DecodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{bsontype: TypeArray}, - llvrwReadValue, - errors.New("cannot decode array into a string type"), - }, - }, - }, - { - "BinaryCodec", - &BinaryCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &BinaryCodec{}, Types: []interface{}{Binary{}, (*Binary)(nil)}, Received: wrong}, - }, - {"Binary/success", Binary{Data: []byte{0x01, 0x02}, Subtype: 0xFF}, nil, nil, llvrwWriteBinaryWithSubtype, nil}, - {"*Binary/success", &Binary{Data: []byte{0x01, 0x02}, Subtype: 0xFF}, nil, nil, llvrwWriteBinaryWithSubtype, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeBinary, readval: Binary{}}, - llvrwReadBinary, - fmt.Errorf("%T can only be used to decode non-nil *Binary values, got %T", &BinaryCodec{}, &wrong), - }, - { - "type not binary", - Binary{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a Binary", TypeString), - }, - { - "ReadBinary Error", - Binary{}, - nil, - &llValueReaderWriter{bsontype: TypeBinary, err: errors.New("rb error"), errAfter: llvrwReadBinary}, - llvrwReadBinary, - errors.New("rb error"), - }, - { - "Binary/success", - Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}, - nil, - &llValueReaderWriter{bsontype: TypeBinary, readval: Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}}, - llvrwReadBinary, - nil, - }, - { - "*Binary/success", - &Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}, - nil, - &llValueReaderWriter{bsontype: TypeBinary, readval: Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}}, - llvrwReadBinary, - nil, - }, - }, - }, - { - "UndefinedCodec", - &UndefinedCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &UndefinedCodec{}, Types: []interface{}{Undefinedv2{}, (*Undefinedv2)(nil)}, Received: wrong}, - }, - {"Undefined/success", Undefinedv2{}, nil, nil, llvrwWriteUndefined, nil}, - {"*Undefined/success", &Undefinedv2{}, nil, nil, llvrwWriteUndefined, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeUndefined}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *Undefined values, got %T", &UndefinedCodec{}, &wrong), - }, - { - "type not undefined", - Undefinedv2{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into an Undefined", TypeString), - }, - { - "ReadUndefined Error", - Undefinedv2{}, - nil, - &llValueReaderWriter{bsontype: TypeUndefined, err: errors.New("ru error"), errAfter: llvrwReadUndefined}, - llvrwReadUndefined, - errors.New("ru error"), - }, - { - "ReadUndefined/success", - Undefinedv2{}, - nil, - &llValueReaderWriter{bsontype: TypeUndefined}, - llvrwReadUndefined, - nil, - }, - }, - }, - { - "ObjectIDCodec", - &ObjectIDCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &ObjectIDCodec{}, Types: []interface{}{objectid.ObjectID{}, (*objectid.ObjectID)(nil)}, Received: wrong}, - }, - { - "objectid.ObjectID/success", - objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - nil, nil, llvrwWriteObjectID, nil, - }, - { - "*objectid.ObjectID/success", - &objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - nil, nil, llvrwWriteObjectID, nil, - }, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeObjectID}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *objectid.ObjectID values, got %T", &ObjectIDCodec{}, &wrong), - }, - { - "type not objectID", - objectid.ObjectID{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into an ObjectID", TypeString), - }, - { - "ReadObjectID Error", - objectid.ObjectID{}, - nil, - &llValueReaderWriter{bsontype: TypeObjectID, err: errors.New("roid error"), errAfter: llvrwReadObjectID}, - llvrwReadObjectID, - errors.New("roid error"), - }, - { - "success", - objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - nil, - &llValueReaderWriter{ - bsontype: TypeObjectID, - readval: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - }, - llvrwReadObjectID, - nil, - }, - }, - }, - { - "DateTimeCodec", - &DateTimeCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &DateTimeCodec{}, Types: []interface{}{DateTime(0), (*DateTime)(nil)}, Received: wrong}, - }, - {"DateTime/success", DateTime(1234567890), nil, nil, llvrwWriteDateTime, nil}, - {"*DateTime/success", pdatetime, nil, nil, llvrwWriteDateTime, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeDateTime}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *DateTime values, got %T", &DateTimeCodec{}, &wrong), - }, - { - "type not datetime", - DateTime(0), - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a DateTime", TypeString), - }, - { - "ReadDateTime Error", - DateTime(0), - nil, - &llValueReaderWriter{bsontype: TypeDateTime, err: errors.New("rdt error"), errAfter: llvrwReadDateTime}, - llvrwReadDateTime, - errors.New("rdt error"), - }, - { - "success", - DateTime(1234567890), - nil, - &llValueReaderWriter{bsontype: TypeDateTime, readval: int64(1234567890)}, - llvrwReadDateTime, - nil, - }, - }, - }, - { - "NullCodec", - &NullCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &NullCodec{}, Types: []interface{}{Nullv2{}, (*Nullv2)(nil)}, Received: wrong}, - }, - {"Null/success", Nullv2{}, nil, nil, llvrwWriteNull, nil}, - {"*Null/success", &Nullv2{}, nil, nil, llvrwWriteNull, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeNull}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *Null values, got %T", &NullCodec{}, &wrong), - }, - { - "type not null", - Nullv2{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a Null", TypeString), - }, - { - "ReadNull Error", - Nullv2{}, - nil, - &llValueReaderWriter{bsontype: TypeNull, err: errors.New("rn error"), errAfter: llvrwReadNull}, - llvrwReadNull, - errors.New("rn error"), - }, - { - "success", - Nullv2{}, - nil, - &llValueReaderWriter{bsontype: TypeNull}, - llvrwReadNull, - nil, - }, - }, - }, - { - "RegexCodec", - &RegexCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &RegexCodec{}, Types: []interface{}{Regex{}, (*Regex)(nil)}, Received: wrong}, - }, - {"Regex/success", Regex{Pattern: "foo", Options: "bar"}, nil, nil, llvrwWriteRegex, nil}, - {"*Regex/success", &Regex{Pattern: "foo", Options: "bar"}, nil, nil, llvrwWriteRegex, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeRegex}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *Regex values, got %T", &RegexCodec{}, &wrong), - }, - { - "type not regex", - Regex{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a Regex", TypeString), - }, - { - "ReadRegex Error", - Regex{}, - nil, - &llValueReaderWriter{bsontype: TypeRegex, err: errors.New("rr error"), errAfter: llvrwReadRegex}, - llvrwReadRegex, - errors.New("rr error"), - }, - { - "success", - Regex{Pattern: "foo", Options: "bar"}, - nil, - &llValueReaderWriter{bsontype: TypeRegex, readval: Regex{Pattern: "foo", Options: "bar"}}, - llvrwReadRegex, - nil, - }, - }, - }, - { - "DBPointerCodec", - &DBPointerCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &DBPointerCodec{}, Types: []interface{}{DBPointer{}, (*DBPointer)(nil)}, Received: wrong}, - }, - { - "DBPointer/success", - DBPointer{ - DB: "foobar", - Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - }, - nil, nil, llvrwWriteDBPointer, nil, - }, - { - "*DBPointer/success", - &DBPointer{ - DB: "foobar", - Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - }, - nil, nil, llvrwWriteDBPointer, nil, - }, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeDBPointer}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *DBPointer values, got %T", &DBPointerCodec{}, &wrong), - }, - { - "type not dbpointer", - DBPointer{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a DBPointer", TypeString), - }, - { - "ReadDBPointer Error", - DBPointer{}, - nil, - &llValueReaderWriter{bsontype: TypeDBPointer, err: errors.New("rdbp error"), errAfter: llvrwReadDBPointer}, - llvrwReadDBPointer, - errors.New("rdbp error"), - }, - { - "success", - DBPointer{ - DB: "foobar", - Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - }, - nil, - &llValueReaderWriter{ - bsontype: TypeDBPointer, - readval: DBPointer{ - DB: "foobar", - Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - }, - }, - llvrwReadDBPointer, - nil, - }, - }, - }, - { - "CodeWithScopeCodec", - &CodeWithScopeCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &CodeWithScopeCodec{}, Types: []interface{}{CodeWithScope{}, (*CodeWithScope)(nil)}, Received: wrong}, - }, - { - "WriteCodeWithScope error", - CodeWithScope{}, - nil, - &llValueReaderWriter{err: errors.New("wcws error"), errAfter: llvrwWriteCodeWithScope}, - llvrwWriteCodeWithScope, - errors.New("wcws error"), - }, - { - "CodeWithScope/success", - CodeWithScope{ - Code: "var hello = 'world';", - Scope: NewDocument(), - }, - nil, nil, llvrwWriteDocumentEnd, nil, - }, - { - "*CodeWithScope/success", - &CodeWithScope{ - Code: "var hello = 'world';", - Scope: NewDocument(), - }, - nil, nil, llvrwWriteDocumentEnd, nil, - }, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeCodeWithScope}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *CodeWithScope values, got %T", &CodeWithScopeCodec{}, &wrong), - }, - { - "type not codewithscope", - CodeWithScope{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a CodeWithScope", TypeString), - }, - { - "ReadCodeWithScope Error", - CodeWithScope{}, - nil, - &llValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("rcws error"), errAfter: llvrwReadCodeWithScope}, - llvrwReadCodeWithScope, - errors.New("rcws error"), - }, - { - "decodeDocument Error", - CodeWithScope{ - Code: "var hello = 'world';", - Scope: NewDocument(EC.Null("foo")), - }, - nil, - &llValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("dd error"), errAfter: llvrwReadElement}, - llvrwReadElement, - errors.New("dd error"), - }, - }, - }, - { - "TimestampCodec", - &TimestampCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &TimestampCodec{}, Types: []interface{}{Timestamp{}, (*Timestamp)(nil)}, Received: wrong}, - }, - {"Timestamp/success", Timestamp{T: 12345, I: 67890}, nil, nil, llvrwWriteTimestamp, nil}, - {"*Timestamp/success", &Timestamp{T: 12345, I: 67890}, nil, nil, llvrwWriteTimestamp, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeTimestamp}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *Timestamp values, got %T", &TimestampCodec{}, &wrong), - }, - { - "type not timestamp", - Timestamp{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a Timestamp", TypeString), - }, - { - "ReadTimestamp Error", - Timestamp{}, - nil, - &llValueReaderWriter{bsontype: TypeTimestamp, err: errors.New("rt error"), errAfter: llvrwReadTimestamp}, - llvrwReadTimestamp, - errors.New("rt error"), - }, - { - "success", - Timestamp{T: 12345, I: 67890}, - nil, - &llValueReaderWriter{bsontype: TypeTimestamp, readval: Timestamp{T: 12345, I: 67890}}, - llvrwReadTimestamp, - nil, - }, - }, - }, - { - "Decimal128Codec", - &Decimal128Codec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &Decimal128Codec{}, Types: []interface{}{decimal.Decimal128{}, (*decimal.Decimal128)(nil)}, Received: wrong}, - }, - {"Decimal128/success", d128, nil, nil, llvrwWriteDecimal128, nil}, - {"*Decimal128/success", &d128, nil, nil, llvrwWriteDecimal128, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeDecimal128}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *decimal.Decimal128 values, got %T", &Decimal128Codec{}, &wrong), - }, - { - "type not decimal128", - decimal.Decimal128{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a decimal.Decimal128", TypeString), - }, - { - "ReadDecimal128 Error", - decimal.Decimal128{}, - nil, - &llValueReaderWriter{bsontype: TypeDecimal128, err: errors.New("rd128 error"), errAfter: llvrwReadDecimal128}, - llvrwReadDecimal128, - errors.New("rd128 error"), - }, - { - "success", - d128, - nil, - &llValueReaderWriter{bsontype: TypeDecimal128, readval: d128}, - llvrwReadDecimal128, - nil, - }, - }, - }, - { - "MinKeyCodec", - &MinKeyCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &MinKeyCodec{}, Types: []interface{}{MinKeyv2{}, (*MinKeyv2)(nil)}, Received: wrong}, - }, - {"MinKey/success", MinKeyv2{}, nil, nil, llvrwWriteMinKey, nil}, - {"*MinKey/success", &MinKeyv2{}, nil, nil, llvrwWriteMinKey, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeMinKey}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *MinKey values, got %T", &MinKeyCodec{}, &wrong), - }, - { - "type not null", - MinKeyv2{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a MinKey", TypeString), - }, - { - "ReadMinKey Error", - MinKeyv2{}, - nil, - &llValueReaderWriter{bsontype: TypeMinKey, err: errors.New("rn error"), errAfter: llvrwReadMinKey}, - llvrwReadMinKey, - errors.New("rn error"), - }, - { - "success", - MinKeyv2{}, - nil, - &llValueReaderWriter{bsontype: TypeMinKey}, - llvrwReadMinKey, - nil, - }, - }, - }, - { - "MaxKeyCodec", - &MaxKeyCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &MaxKeyCodec{}, Types: []interface{}{MaxKeyv2{}, (*MaxKeyv2)(nil)}, Received: wrong}, - }, - {"MaxKey/success", MaxKeyv2{}, nil, nil, llvrwWriteMaxKey, nil}, - {"*MaxKey/success", &MaxKeyv2{}, nil, nil, llvrwWriteMaxKey, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeMaxKey}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *MaxKey values, got %T", &MaxKeyCodec{}, &wrong), - }, - { - "type not null", - MaxKeyv2{}, - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a MaxKey", TypeString), - }, - { - "ReadMaxKey Error", - MaxKeyv2{}, - nil, - &llValueReaderWriter{bsontype: TypeMaxKey, err: errors.New("rn error"), errAfter: llvrwReadMaxKey}, - llvrwReadMaxKey, - errors.New("rn error"), - }, - { - "success", - MaxKeyv2{}, - nil, - &llValueReaderWriter{bsontype: TypeMaxKey}, - llvrwReadMaxKey, - nil, - }, - }, - }, - { - "elementCodec", - &elementCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &elementCodec{}, Types: []interface{}{(*Element)(nil)}, Received: wrong}, - }, - {"invalid element", (*Element)(nil), nil, nil, llvrwNothing, ErrNilElement}, - { - "success", - EC.Null("foo"), - &EncodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{}, - llvrwWriteNull, - nil, - }, - }, - []deccase{ - { - "cannot use directly", - (*Element)(nil), nil, nil, llvrwNothing, - errors.New("elementCodec's DecodeValue method should not be used directly"), - }, - }, - }, - { - "ValueCodec", - &ValueCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &ValueCodec{}, Types: []interface{}{(*Value)(nil)}, Received: wrong}, - }, - {"invalid value", &Value{}, nil, nil, llvrwNothing, ErrUninitializedElement}, - { - "success", - VC.Null(), - &EncodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{}, - llvrwWriteNull, - nil, - }, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecDecodeError{Codec: &ValueCodec{}, Types: []interface{}{(**Value)(nil)}, Received: &wrong}, - }, - {"invalid value", (**Value)(nil), nil, nil, llvrwNothing, fmt.Errorf("%T can only be used to decode non-nil **Value", &ValueCodec{})}, - { - "success", - VC.Double(3.14159), - &DecodeContext{Registry: NewRegistryBuilder().Build()}, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14159)}, - llvrwReadDouble, - nil, - }, - }, - }, - { - "JSONNumberCodec", - &JSONNumberCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &JSONNumberCodec{}, Types: []interface{}{json.Number(""), (*json.Number)(nil)}, Received: wrong}, - }, - { - "json.Number/invalid", - json.Number("hello world"), - nil, nil, llvrwNothing, errors.New(`strconv.ParseFloat: parsing "hello world": invalid syntax`), - }, - { - "json.Number/int64/success", - json.Number("1234567890"), - nil, nil, llvrwWriteInt64, nil, - }, - { - "json.Number/float64/success", - json.Number("3.14159"), - nil, nil, llvrwWriteDouble, nil, - }, - { - "*json.Number/int64/success", - pjsnum, - nil, nil, llvrwWriteDouble, nil, - }, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeObjectID}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *json.Number values, got %T", &JSONNumberCodec{}, &wrong), - }, - { - "type not double/int32/int64", - json.Number(""), - nil, - &llValueReaderWriter{bsontype: TypeString}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a json.Number", TypeString), - }, - { - "ReadDouble Error", - json.Number(""), - nil, - &llValueReaderWriter{bsontype: TypeDouble, err: errors.New("rd error"), errAfter: llvrwReadDouble}, - llvrwReadDouble, - errors.New("rd error"), - }, - { - "ReadInt32 Error", - json.Number(""), - nil, - &llValueReaderWriter{bsontype: TypeInt32, err: errors.New("ri32 error"), errAfter: llvrwReadInt32}, - llvrwReadInt32, - errors.New("ri32 error"), - }, - { - "ReadInt64 Error", - json.Number(""), - nil, - &llValueReaderWriter{bsontype: TypeInt64, err: errors.New("ri64 error"), errAfter: llvrwReadInt64}, - llvrwReadInt64, - errors.New("ri64 error"), - }, - { - "success/double", - json.Number("3.14159"), - nil, - &llValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14159)}, - llvrwReadDouble, - nil, - }, - { - "success/int32", - json.Number("12345"), - nil, - &llValueReaderWriter{bsontype: TypeInt32, readval: int32(12345)}, - llvrwReadInt32, - nil, - }, - { - "success/int64", - json.Number("1234567890"), - nil, - &llValueReaderWriter{bsontype: TypeInt64, readval: int64(1234567890)}, - llvrwReadInt64, - nil, - }, - }, - }, - { - "URLCodec", - &URLCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &URLCodec{}, Types: []interface{}{url.URL{}, (*url.URL)(nil)}, Received: wrong}, - }, - {"url.URL", url.URL{Scheme: "http", Host: "example.com"}, nil, nil, llvrwWriteString, nil}, - {"*url.URL", &url.URL{Scheme: "http", Host: "example.com"}, nil, nil, llvrwWriteString, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeInt32}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a *url.URL", TypeInt32), - }, - { - "type not *url.URL", - int64(0), - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("http://example.com")}, - llvrwReadString, - fmt.Errorf("%T can only be used to decode non-nil *url.URL values, got %T", &URLCodec{}, (*int64)(nil)), - }, - { - "ReadString error", - url.URL{}, - nil, - &llValueReaderWriter{bsontype: TypeString, err: errors.New("rs error"), errAfter: llvrwReadString}, - llvrwReadString, - errors.New("rs error"), - }, - { - "url.Parse error", - url.URL{}, - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("not-valid-%%%%://")}, - llvrwReadString, - errors.New("parse not-valid-%%%%://: first path segment in URL cannot contain colon"), - }, - { - "nil *url.URL", - (*url.URL)(nil), - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("http://example.com")}, - llvrwReadString, - fmt.Errorf("%T can only be used to decode non-nil *url.URL values, got %T", &URLCodec{}, (*url.URL)(nil)), - }, - { - "nil **url.URL", - (**url.URL)(nil), - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("http://example.com")}, - llvrwReadString, - fmt.Errorf("%T can only be used to decode non-nil *url.URL values, got %T", &URLCodec{}, (**url.URL)(nil)), - }, - { - "url.URL", - url.URL{Scheme: "http", Host: "example.com"}, - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("http://example.com")}, - llvrwReadString, - nil, - }, - { - "*url.URL", - &url.URL{Scheme: "http", Host: "example.com"}, - nil, - &llValueReaderWriter{bsontype: TypeString, readval: string("http://example.com")}, - llvrwReadString, - nil, - }, - }, - }, - { - "ReaderCodec", - &ReaderCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &ReaderCodec{}, Types: []interface{}{Reader{}}, Received: wrong}, - }, - { - "WriteDocument Error", - Reader{}, - nil, - &llValueReaderWriter{err: errors.New("wd error"), errAfter: llvrwWriteDocument}, - llvrwWriteDocument, - errors.New("wd error"), - }, - { - "Reader.Iterator Error", - Reader{0xFF, 0x00, 0x00, 0x00, 0x00}, - nil, - &llValueReaderWriter{}, - llvrwWriteDocument, - ErrInvalidLength, - }, - { - "WriteDocumentElement Error", - Reader(bytesFromDoc(NewDocument(EC.Null("foo")))), - nil, - &llValueReaderWriter{err: errors.New("wde error"), errAfter: llvrwWriteDocumentElement}, - llvrwWriteDocumentElement, - errors.New("wde error"), - }, - { - "encodeValue error", - Reader(bytesFromDoc(NewDocument(EC.Null("foo")))), - nil, - &llValueReaderWriter{err: errors.New("ev error"), errAfter: llvrwWriteNull}, - llvrwWriteNull, - errors.New("ev error"), - }, - { - "iterator error", - Reader{0x0C, 0x00, 0x00, 0x00, 0x01, 'f', 'o', 'o', 0x00, 0x01, 0x02, 0x03}, - nil, - &llValueReaderWriter{}, - llvrwWriteDocument, - NewErrTooSmall(), - }, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{}, - llvrwNothing, - CodecDecodeError{Codec: &ReaderCodec{}, Types: []interface{}{(*Reader)(nil)}, Received: &wrong}, - }, - { - "*Reader is nil", - (*Reader)(nil), - nil, - nil, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *Reader", &ReaderCodec{}), - }, - { - "Copy error", - Reader{}, - nil, - &llValueReaderWriter{err: errors.New("copy error"), errAfter: llvrwReadDocument}, - llvrwReadDocument, - errors.New("copy error"), - }, - }, - }, - { - "ByteSliceCodec", - &ByteSliceCodec{}, - []enccase{ - { - "wrong type", - wrong, - nil, - nil, - llvrwNothing, - CodecEncodeError{Codec: &ByteSliceCodec{}, Types: []interface{}{[]byte{}, (*[]byte)(nil)}, Received: wrong}, - }, - {"[]byte", []byte{0x01, 0x02, 0x03}, nil, nil, llvrwWriteBinary, nil}, - {"*[]byte", &([]byte{0x01, 0x02, 0x03}), nil, nil, llvrwWriteBinary, nil}, - }, - []deccase{ - { - "wrong type", - wrong, - nil, - &llValueReaderWriter{bsontype: TypeInt32}, - llvrwNothing, - fmt.Errorf("cannot decode %v into a *[]byte", TypeInt32), - }, - { - "type not *[]byte", - int64(0), - nil, - &llValueReaderWriter{bsontype: TypeBinary, readval: Binary{}}, - llvrwNothing, - fmt.Errorf("%T can only be used to decode non-nil *[]byte values, got %T", &ByteSliceCodec{}, (*int64)(nil)), - }, - { - "ReadBinary error", - []byte{}, - nil, - &llValueReaderWriter{bsontype: TypeBinary, err: errors.New("rb error"), errAfter: llvrwReadBinary}, - llvrwReadBinary, - errors.New("rb error"), - }, - { - "incorrect subtype", - []byte{}, - nil, - &llValueReaderWriter{bsontype: TypeBinary, readval: Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}}, - llvrwReadBinary, - fmt.Errorf("%T can only be used to decode subtype 0x00 for %s, got %v", &ByteSliceCodec{}, TypeBinary, byte(0xFF)), - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Run("EncodeValue", func(t *testing.T) { - for _, rc := range tc.encodeCases { - t.Run(rc.name, func(t *testing.T) { - var ec EncodeContext - if rc.ectx != nil { - ec = *rc.ectx - } - llvrw := new(llValueReaderWriter) - if rc.llvrw != nil { - llvrw = rc.llvrw - } - llvrw.t = t - err := tc.codec.EncodeValue(ec, llvrw, rc.val) - if !compareErrors(err, rc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, rc.err) - } - invoked := llvrw.invoked - if !cmp.Equal(invoked, rc.invoke) { - t.Errorf("Incorrect method invoked. got %v; want %v", invoked, rc.invoke) - } - }) - } - }) - t.Run("DecodeValue", func(t *testing.T) { - for _, rc := range tc.decodeCases { - t.Run(rc.name, func(t *testing.T) { - var dc DecodeContext - if rc.dctx != nil { - dc = *rc.dctx - } - llvrw := new(llValueReaderWriter) - if rc.llvrw != nil { - llvrw = rc.llvrw - } - llvrw.t = t - var got interface{} - if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test - err := tc.codec.DecodeValue(dc, llvrw, nil) - if !compareErrors(err, rc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, rc.err) - } - - val := reflect.New(reflect.TypeOf(rc.val)).Elem().Interface() - err = tc.codec.DecodeValue(dc, llvrw, val) - if !compareErrors(err, rc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, rc.err) - } - return - } - var unwrap bool - rtype := reflect.TypeOf(rc.val) - if rtype.Kind() == reflect.Ptr { - if reflect.ValueOf(rc.val).IsNil() { - got = rc.val - } else { - val := reflect.New(rtype).Elem() - elem := reflect.New(rtype.Elem()) - val.Set(elem) - got = val.Addr().Interface() - unwrap = true - } - } else { - unwrap = true - got = reflect.New(reflect.TypeOf(rc.val)).Interface() - } - want := rc.val - err := tc.codec.DecodeValue(dc, llvrw, got) - if !compareErrors(err, rc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, rc.err) - } - invoked := llvrw.invoked - if !cmp.Equal(invoked, rc.invoke) { - t.Errorf("Incorrect method invoked. got %v; want %v", invoked, rc.invoke) - } - if unwrap { - got = reflect.ValueOf(got).Elem().Interface() - } - if rc.err == nil && !cmp.Equal(got, want, cmp.Comparer(compareDecimal128), cmp.Comparer(compareValues)) { - t.Errorf("Values do not match. got (%T)%v; want (%T)%v", got, got, want, want) - } - }) - } - }) - }) - } - - t.Run("MapCodec/DecodeValue/non-settable", func(t *testing.T) { - var dc DecodeContext - llvrw := new(llValueReaderWriter) - llvrw.t = t - - want := fmt.Errorf("%T can only be used to decode non-nil pointers to map values, got %T", &MapCodec{}, nil) - got := (&MapCodec{}).DecodeValue(dc, llvrw, nil) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - - want = fmt.Errorf("%T can only be used to decode non-nil pointers to map values, got %T", &MapCodec{}, string("")) - - val := reflect.New(reflect.TypeOf(string(""))).Elem().Interface() - got = (&MapCodec{}).DecodeValue(dc, llvrw, val) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - return - }) - - t.Run("CodeWithScopeCodec/DecodeValue/success", func(t *testing.T) { - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} - dvr := newDocumentValueReader(NewDocument(EC.CodeWithScope("foo", "var hello = 'world';", NewDocument(EC.Null("bar"))))) - dr, err := dvr.ReadDocument() - noerr(t, err) - _, vr, err := dr.ReadElement() - noerr(t, err) - - want := CodeWithScope{ - Code: "var hello = 'world';", - Scope: NewDocument(EC.Null("bar")), - } - var got CodeWithScope - err = (&CodeWithScopeCodec{}).DecodeValue(dc, vr, &got) - noerr(t, err) - - if !cmp.Equal(got, want) { - t.Errorf("CodeWithScopes do not match. got %v; want %v", got, want) - } - }) - - t.Run("DocumentCodec", func(t *testing.T) { - t.Run("EncodeValue", func(t *testing.T) { - t.Run("CodecEncodeError", func(t *testing.T) { - val := bool(true) - want := CodecEncodeError{Codec: &DocumentCodec{}, Types: []interface{}{(*Document)(nil)}, Received: val} - got := (&DocumentCodec{}).EncodeValue(EncodeContext{}, nil, val) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("WriteDocument Error", func(t *testing.T) { - want := errors.New("WriteDocument Error") - llvrw := &llValueReaderWriter{ - t: t, - err: want, - errAfter: llvrwWriteDocument, - } - got := (&DocumentCodec{}).EncodeValue(EncodeContext{}, llvrw, NewDocument()) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("encodeDocument errors", func(t *testing.T) { - ec := EncodeContext{} - err := errors.New("encodeDocument error") - oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} - badelem := EC.Null("foo") - badelem.value.data[0] = 0x00 - testCases := []struct { - name string - ec EncodeContext - llvrw *llValueReaderWriter - doc *Document - err error - }{ - { - "WriteDocumentElement", - ec, - &llValueReaderWriter{t: t, err: errors.New("wde error"), errAfter: llvrwWriteDocumentElement}, - NewDocument(EC.Null("foo")), - errors.New("wde error"), - }, - { - "WriteDouble", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDouble}, - NewDocument(EC.Double("foo", 3.14159)), err, - }, - { - "WriteString", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteString}, - NewDocument(EC.String("foo", "bar")), err, - }, - { - "WriteDocument (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t}, - NewDocument(EC.SubDocument("foo", NewDocument(EC.Null("bar")))), - ErrNoCodec{Type: tDocument}, - }, - { - "WriteArray (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t}, - NewDocument(EC.Array("foo", NewArray(VC.Null()))), - ErrNoCodec{Type: tArray}, - }, - { - "WriteBinary", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBinaryWithSubtype}, - NewDocument(EC.BinaryWithSubtype("foo", []byte{0x01, 0x02, 0x03}, 0xFF)), err, - }, - { - "WriteUndefined", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteUndefined}, - NewDocument(EC.Undefined("foo")), err, - }, - { - "WriteObjectID", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteObjectID}, - NewDocument(EC.ObjectID("foo", oid)), err, - }, - { - "WriteBoolean", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBoolean}, - NewDocument(EC.Boolean("foo", true)), err, - }, - { - "WriteDateTime", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDateTime}, - NewDocument(EC.DateTime("foo", 1234567890)), err, - }, - { - "WriteNull", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteNull}, - NewDocument(EC.Null("foo")), err, - }, - { - "WriteRegex", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteRegex}, - NewDocument(EC.Regex("foo", "bar", "baz")), err, - }, - { - "WriteDBPointer", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDBPointer}, - NewDocument(EC.DBPointer("foo", "bar", oid)), err, - }, - { - "WriteJavascript", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteJavascript}, - NewDocument(EC.JavaScript("foo", "var hello = 'world';")), err, - }, - { - "WriteSymbol", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteSymbol}, - NewDocument(EC.Symbol("foo", "symbolbaz")), err, - }, - { - "WriteCodeWithScope (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteCodeWithScope}, - NewDocument(EC.CodeWithScope("foo", "var hello = 'world';", NewDocument(EC.Null("bar")))), - err, - }, - { - "WriteInt32", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt32}, - NewDocument(EC.Int32("foo", 12345)), err, - }, - { - "WriteInt64", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt64}, - NewDocument(EC.Int64("foo", 1234567890)), err, - }, - { - "WriteTimestamp", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteTimestamp}, - NewDocument(EC.Timestamp("foo", 10, 20)), err, - }, - { - "WriteDecimal128", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDecimal128}, - NewDocument(EC.Decimal128("foo", decimal.NewDecimal128(10, 20))), err, - }, - { - "WriteMinKey", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMinKey}, - NewDocument(EC.MinKey("foo")), err, - }, - { - "WriteMaxKey", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMaxKey}, - NewDocument(EC.MaxKey("foo")), err, - }, - { - "Invalid Type", ec, - &llValueReaderWriter{t: t, bsontype: Type(0)}, - NewDocument(badelem), - ErrInvalidElement, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := (&DocumentCodec{}).EncodeValue(tc.ec, tc.llvrw, tc.doc) - if !compareErrors(err, tc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, tc.err) - } - }) - } - }) - - t.Run("success", func(t *testing.T) { - oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} - d128 := decimal.NewDecimal128(10, 20) - want := NewDocument( - EC.Double("a", 3.14159), EC.String("b", "foo"), EC.SubDocumentFromElements("c", EC.Null("aa")), - EC.ArrayFromElements("d", VC.Null()), - EC.BinaryWithSubtype("e", []byte{0x01, 0x02, 0x03}, 0xFF), EC.Undefined("f"), - EC.ObjectID("g", oid), EC.Boolean("h", true), EC.DateTime("i", 1234567890), EC.Null("j"), EC.Regex("k", "foo", "bar"), - EC.DBPointer("l", "foobar", oid), EC.JavaScript("m", "var hello = 'world';"), EC.Symbol("n", "bazqux"), - EC.CodeWithScope("o", "var hello = 'world';", NewDocument(EC.Null("ab"))), EC.Int32("p", 12345), - EC.Timestamp("q", 10, 20), EC.Int64("r", 1234567890), EC.Decimal128("s", d128), EC.MinKey("t"), EC.MaxKey("u"), - ) - got := NewDocument() - ec := EncodeContext{Registry: NewRegistryBuilder().Build()} - err := (&DocumentCodec{}).EncodeValue(ec, newDocumentValueWriter(got), want) - noerr(t, err) - if !got.Equal(want) { - t.Error("Documents do not match") - t.Errorf("\ngot :%v\nwant:%v", got, want) - } - }) - }) - - t.Run("DecodeValue", func(t *testing.T) { - t.Run("CodecDecodeError", func(t *testing.T) { - val := bool(true) - want := CodecDecodeError{Codec: &DocumentCodec{}, Types: []interface{}{(**Document)(nil)}, Received: val} - got := (&DocumentCodec{}).DecodeValue(DecodeContext{}, &llValueReaderWriter{bsontype: TypeEmbeddedDocument}, val) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("ReadDocument Error", func(t *testing.T) { - want := errors.New("ReadDocument Error") - llvrw := &llValueReaderWriter{ - t: t, - err: want, - errAfter: llvrwReadDocument, - bsontype: TypeEmbeddedDocument, - } - got := (&DocumentCodec{}).DecodeValue(DecodeContext{}, llvrw, new(*Document)) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("decodeDocument errors", func(t *testing.T) { - dc := DecodeContext{} - err := errors.New("decodeDocument error") - badelem := EC.Null("foo") - badelem.value.data[0] = 0x00 - testCases := []struct { - name string - dc DecodeContext - llvrw *llValueReaderWriter - err error - }{ - { - "ReadElement", - dc, - &llValueReaderWriter{t: t, err: errors.New("re error"), errAfter: llvrwReadElement}, - errors.New("re error"), - }, - {"ReadDouble", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDouble, bsontype: TypeDouble}, err}, - {"ReadString", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadString, bsontype: TypeString}, err}, - { - "ReadDocument (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, bsontype: TypeEmbeddedDocument}, - ErrNoCodec{Type: tDocument}, - }, - { - "ReadArray (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, bsontype: TypeArray}, - ErrNoCodec{Type: tArray}, - }, - {"ReadBinary", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBinary, bsontype: TypeBinary}, err}, - {"ReadUndefined", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadUndefined, bsontype: TypeUndefined}, err}, - {"ReadObjectID", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadObjectID, bsontype: TypeObjectID}, err}, - {"ReadBoolean", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBoolean, bsontype: TypeBoolean}, err}, - {"ReadDateTime", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDateTime, bsontype: TypeDateTime}, err}, - {"ReadNull", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadNull, bsontype: TypeNull}, err}, - {"ReadRegex", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadRegex, bsontype: TypeRegex}, err}, - {"ReadDBPointer", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDBPointer, bsontype: TypeDBPointer}, err}, - {"ReadJavascript", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadJavascript, bsontype: TypeJavaScript}, err}, - {"ReadSymbol", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadSymbol, bsontype: TypeSymbol}, err}, - { - "ReadCodeWithScope (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadCodeWithScope, bsontype: TypeCodeWithScope}, - err, - }, - {"ReadInt32", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt32, bsontype: TypeInt32}, err}, - {"ReadInt64", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt64, bsontype: TypeInt64}, err}, - {"ReadTimestamp", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadTimestamp, bsontype: TypeTimestamp}, err}, - {"ReadDecimal128", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDecimal128, bsontype: TypeDecimal128}, err}, - {"ReadMinKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMinKey, bsontype: TypeMinKey}, err}, - {"ReadMaxKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMaxKey, bsontype: TypeMaxKey}, err}, - {"Invalid Type", dc, &llValueReaderWriter{t: t, bsontype: Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", Type(0))}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := (&DocumentCodec{}).decodeDocument(tc.dc, tc.llvrw, new(*Document)) - if !compareErrors(err, tc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, tc.err) - } - }) - } - }) - - t.Run("success", func(t *testing.T) { - oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} - d128 := decimal.NewDecimal128(10, 20) - want := NewDocument( - EC.Double("a", 3.14159), EC.String("b", "foo"), EC.SubDocumentFromElements("c", EC.Null("aa")), - EC.ArrayFromElements("d", VC.Null()), - EC.BinaryWithSubtype("e", []byte{0x01, 0x02, 0x03}, 0xFF), EC.Undefined("f"), - EC.ObjectID("g", oid), EC.Boolean("h", true), EC.DateTime("i", 1234567890), EC.Null("j"), EC.Regex("k", "foo", "bar"), - EC.DBPointer("l", "foobar", oid), EC.JavaScript("m", "var hello = 'world';"), EC.Symbol("n", "bazqux"), - EC.CodeWithScope("o", "var hello = 'world';", NewDocument(EC.Null("ab"))), EC.Int32("p", 12345), - EC.Timestamp("q", 10, 20), EC.Int64("r", 1234567890), EC.Decimal128("s", d128), EC.MinKey("t"), EC.MaxKey("u"), - ) - var got *Document - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} - err := (&DocumentCodec{}).DecodeValue(dc, newDocumentValueReader(want), &got) - noerr(t, err) - if !got.Equal(want) { - t.Error("Documents do not match") - t.Errorf("\ngot :%v\nwant:%v", got, want) - } - }) - }) - }) - - t.Run("ArrayCodec", func(t *testing.T) { - t.Run("EncodeValue", func(t *testing.T) { - t.Run("CodecEncodeError", func(t *testing.T) { - val := bool(true) - want := CodecEncodeError{Codec: &ArrayCodec{}, Types: []interface{}{(*Array)(nil)}, Received: val} - got := (&ArrayCodec{}).EncodeValue(EncodeContext{}, nil, val) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("WriteArray Error", func(t *testing.T) { - want := errors.New("WriteArray Error") - llvrw := &llValueReaderWriter{ - t: t, - err: want, - errAfter: llvrwWriteArray, - } - got := (&ArrayCodec{}).EncodeValue(EncodeContext{}, llvrw, NewArray()) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("encode array errors", func(t *testing.T) { - ec := EncodeContext{} - err := errors.New("encode array error") - oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} - badval := VC.Null() - badval.data[0] = 0x00 - testCases := []struct { - name string - ec EncodeContext - llvrw *llValueReaderWriter - arr *Array - err error - }{ - { - "WriteDocumentElement", - ec, - &llValueReaderWriter{t: t, err: errors.New("wde error"), errAfter: llvrwWriteArrayElement}, - NewArray(VC.Null()), - errors.New("wde error"), - }, - { - "WriteDouble", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDouble}, - NewArray(VC.Double(3.14159)), err, - }, - { - "WriteString", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteString}, - NewArray(VC.String("bar")), err, - }, - { - "WriteDocument (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t}, - NewArray(VC.Document(NewDocument(EC.Null("bar")))), - ErrNoCodec{Type: tDocument}, - }, - { - "WriteArray (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t}, - NewArray(VC.Array(NewArray(VC.Null()))), - ErrNoCodec{Type: tArray}, - }, - { - "WriteBinary", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBinaryWithSubtype}, - NewArray(VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xFF)), err, - }, - { - "WriteUndefined", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteUndefined}, - NewArray(VC.Undefined()), err, - }, - { - "WriteObjectID", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteObjectID}, - NewArray(VC.ObjectID(oid)), err, - }, - { - "WriteBoolean", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteBoolean}, - NewArray(VC.Boolean(true)), err, - }, - { - "WriteDateTime", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDateTime}, - NewArray(VC.DateTime(1234567890)), err, - }, - { - "WriteNull", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteNull}, - NewArray(VC.Null()), err, - }, - { - "WriteRegex", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteRegex}, - NewArray(VC.Regex("bar", "baz")), err, - }, - { - "WriteDBPointer", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDBPointer}, - NewArray(VC.DBPointer("bar", oid)), err, - }, - { - "WriteJavascript", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteJavascript}, - NewArray(VC.JavaScript("var hello = 'world';")), err, - }, - { - "WriteSymbol", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteSymbol}, - NewArray(VC.Symbol("symbolbaz")), err, - }, - { - "WriteCodeWithScope (Lookup)", EncodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteCodeWithScope}, - NewArray(VC.CodeWithScope("var hello = 'world';", NewDocument(EC.Null("bar")))), - err, - }, - { - "WriteInt32", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt32}, - NewArray(VC.Int32(12345)), err, - }, - { - "WriteInt64", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteInt64}, - NewArray(VC.Int64(1234567890)), err, - }, - { - "WriteTimestamp", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteTimestamp}, - NewArray(VC.Timestamp(10, 20)), err, - }, - { - "WriteDecimal128", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteDecimal128}, - NewArray(VC.Decimal128(decimal.NewDecimal128(10, 20))), err, - }, - { - "WriteMinKey", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMinKey}, - NewArray(VC.MinKey()), err, - }, - { - "WriteMaxKey", ec, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwWriteMaxKey}, - NewArray(VC.MaxKey()), err, - }, - { - "Invalid Type", ec, - &llValueReaderWriter{t: t, bsontype: Type(0)}, - NewArray(badval), - ErrInvalidElement, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := (&ArrayCodec{}).EncodeValue(tc.ec, tc.llvrw, tc.arr) - if !compareErrors(err, tc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, tc.err) - } - }) - } - }) - - t.Run("success", func(t *testing.T) { - oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} - d128 := decimal.NewDecimal128(10, 20) - want := NewArray( - VC.Double(3.14159), VC.String("foo"), VC.DocumentFromElements(EC.Null("aa")), - VC.ArrayFromValues(VC.Null()), - VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xFF), VC.Undefined(), - VC.ObjectID(oid), VC.Boolean(true), VC.DateTime(1234567890), VC.Null(), VC.Regex("foo", "bar"), - VC.DBPointer("foobar", oid), VC.JavaScript("var hello = 'world';"), VC.Symbol("bazqux"), - VC.CodeWithScope("var hello = 'world';", NewDocument(EC.Null("ab"))), VC.Int32(12345), - VC.Timestamp(10, 20), VC.Int64(1234567890), VC.Decimal128(d128), VC.MinKey(), VC.MaxKey(), - ) - - ec := EncodeContext{Registry: NewRegistryBuilder().Build()} - - doc := NewDocument() - dvw := newDocumentValueWriter(doc) - dr, err := dvw.WriteDocument() - noerr(t, err) - vr, err := dr.WriteDocumentElement("foo") - noerr(t, err) - - err = (&ArrayCodec{}).EncodeValue(ec, vr, want) - noerr(t, err) - - got := doc.Lookup("foo").MutableArray() - if !got.Equal(want) { - t.Error("Documents do not match") - t.Errorf("\ngot :%v\nwant:%v", got, want) - } - }) - }) - - t.Run("DecodeValue", func(t *testing.T) { - t.Run("CodecDecodeError", func(t *testing.T) { - val := bool(true) - want := CodecDecodeError{Codec: &ArrayCodec{}, Types: []interface{}{(**Array)(nil)}, Received: val} - got := (&ArrayCodec{}).DecodeValue(DecodeContext{}, &llValueReaderWriter{bsontype: TypeArray}, val) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("ReadArray Error", func(t *testing.T) { - want := errors.New("ReadArray Error") - llvrw := &llValueReaderWriter{ - t: t, - err: want, - errAfter: llvrwReadArray, - bsontype: TypeArray, - } - got := (&ArrayCodec{}).DecodeValue(DecodeContext{}, llvrw, new(*Array)) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("decode array errors", func(t *testing.T) { - dc := DecodeContext{} - err := errors.New("decode array error") - badval := VC.Null() - badval.data[0] = 0x00 - testCases := []struct { - name string - dc DecodeContext - llvrw *llValueReaderWriter - err error - }{ - { - "ReadValue", - dc, - &llValueReaderWriter{t: t, err: errors.New("re error"), errAfter: llvrwReadValue}, - errors.New("re error"), - }, - {"ReadDouble", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDouble, bsontype: TypeDouble}, err}, - {"ReadString", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadString, bsontype: TypeString}, err}, - { - "ReadDocument (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, bsontype: TypeEmbeddedDocument}, - ErrNoCodec{Type: tDocument}, - }, - { - "ReadArray (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, bsontype: TypeArray}, - ErrNoCodec{Type: tArray}, - }, - {"ReadBinary", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBinary, bsontype: TypeBinary}, err}, - {"ReadUndefined", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadUndefined, bsontype: TypeUndefined}, err}, - {"ReadObjectID", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadObjectID, bsontype: TypeObjectID}, err}, - {"ReadBoolean", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadBoolean, bsontype: TypeBoolean}, err}, - {"ReadDateTime", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDateTime, bsontype: TypeDateTime}, err}, - {"ReadNull", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadNull, bsontype: TypeNull}, err}, - {"ReadRegex", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadRegex, bsontype: TypeRegex}, err}, - {"ReadDBPointer", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDBPointer, bsontype: TypeDBPointer}, err}, - {"ReadJavascript", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadJavascript, bsontype: TypeJavaScript}, err}, - {"ReadSymbol", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadSymbol, bsontype: TypeSymbol}, err}, - { - "ReadCodeWithScope (Lookup)", DecodeContext{Registry: NewEmptyRegistryBuilder().Build()}, - &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadCodeWithScope, bsontype: TypeCodeWithScope}, - err, - }, - {"ReadInt32", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt32, bsontype: TypeInt32}, err}, - {"ReadInt64", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadInt64, bsontype: TypeInt64}, err}, - {"ReadTimestamp", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadTimestamp, bsontype: TypeTimestamp}, err}, - {"ReadDecimal128", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadDecimal128, bsontype: TypeDecimal128}, err}, - {"ReadMinKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMinKey, bsontype: TypeMinKey}, err}, - {"ReadMaxKey", dc, &llValueReaderWriter{t: t, err: err, errAfter: llvrwReadMaxKey, bsontype: TypeMaxKey}, err}, - {"Invalid Type", dc, &llValueReaderWriter{t: t, bsontype: Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", Type(0))}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := (&ArrayCodec{}).DecodeValue(tc.dc, tc.llvrw, new(*Array)) - if !compareErrors(err, tc.err) { - t.Errorf("Errors do not match. got %v; want %v", err, tc.err) - } - }) - } - }) - - t.Run("success", func(t *testing.T) { - oid := objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} - d128 := decimal.NewDecimal128(10, 20) - want := NewArray( - VC.Double(3.14159), VC.String("foo"), VC.DocumentFromElements(EC.Null("aa")), - VC.ArrayFromValues(VC.Null()), - VC.BinaryWithSubtype([]byte{0x01, 0x02, 0x03}, 0xFF), VC.Undefined(), - VC.ObjectID(oid), VC.Boolean(true), VC.DateTime(1234567890), VC.Null(), VC.Regex("foo", "bar"), - VC.DBPointer("foobar", oid), VC.JavaScript("var hello = 'world';"), VC.Symbol("bazqux"), - VC.CodeWithScope("var hello = 'world';", NewDocument(EC.Null("ab"))), VC.Int32(12345), - VC.Timestamp(10, 20), VC.Int64(1234567890), VC.Decimal128(d128), VC.MinKey(), VC.MaxKey(), - ) - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} - - dvr := newDocumentValueReader(NewDocument(EC.Array("", want))) - dr, err := dvr.ReadDocument() - noerr(t, err) - _, vr, err := dr.ReadElement() - noerr(t, err) - - var got *Array - err = (&ArrayCodec{}).DecodeValue(dc, vr, &got) - noerr(t, err) - if !got.Equal(want) { - t.Error("Documents do not match") - t.Errorf("\ngot :%v\nwant:%v", got, want) - } - }) - }) - }) - t.Run("SliceCodec/DecodeValue/can't set slice", func(t *testing.T) { - var val []string - want := fmt.Errorf("%T can only be used to decode non-nil pointers to slice or array values, got %T", &SliceCodec{}, val) - got := (&SliceCodec{}).DecodeValue(DecodeContext{}, nil, val) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - t.Run("SliceCodec/DecodeValue/too many elements", func(t *testing.T) { - dvr := newDocumentValueReader(NewDocument(EC.ArrayFromElements("foo", VC.String("foo"), VC.String("bar")))) - dr, err := dvr.ReadDocument() - noerr(t, err) - _, vr, err := dr.ReadElement() - noerr(t, err) - var val [1]string - want := fmt.Errorf("more elements returned in array than can fit inside %T", val) - - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} - got := (&SliceCodec{}).DecodeValue(dc, vr, &val) - if !compareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - - t.Run("success path", func(t *testing.T) { - oid := objectid.New() - oids := []objectid.ObjectID{objectid.New(), objectid.New(), objectid.New()} - var str = new(string) - *str = "bar" - now := time.Now().Truncate(time.Millisecond) - murl, err := url.Parse("https://mongodb.com/random-url?hello=world") - if err != nil { - t.Errorf("Error parsing URL: %v", err) - t.FailNow() - } - decimal128, err := decimal.ParseDecimal128("1.5e10") - if err != nil { - t.Errorf("Error parsing decimal128: %v", err) - t.FailNow() - } - - testCases := []struct { - name string - value interface{} - b []byte - err error - }{ - { - "map[string]int", - map[string]int32{"foo": 1}, - []byte{ - 0x0E, 0x00, 0x00, 0x00, - 0x10, 'f', 'o', 'o', 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - }, - nil, - }, - { - "map[string]objectid.ObjectID", - map[string]objectid.ObjectID{"foo": oid}, - docToBytes(NewDocument(EC.ObjectID("foo", oid))), - nil, - }, - { - "map[string][]*Element", - map[string][]*Element{"Z": {EC.Int32("A", 1), EC.Int32("B", 2), EC.Int32("EC", 3)}}, - docToBytes(NewDocument( - EC.SubDocumentFromElements("Z", EC.Int32("A", 1), EC.Int32("B", 2), EC.Int32("EC", 3)), - )), - nil, - }, - { - "map[string][]*Value", - map[string][]*Value{"Z": {VC.Int32(1), VC.Int32(2), VC.Int32(3)}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int32(1), VC.Int32(2), VC.Int32(3)), - )), - nil, - }, - { - "map[string]*Element", - map[string]*Element{"Z": EC.Int32("Z", 12345)}, - docToBytes(NewDocument( - EC.Int32("Z", 12345), - )), - nil, - }, - { - "map[string]*Document", - map[string]*Document{"Z": NewDocument(EC.Null("foo"))}, - docToBytes(NewDocument( - EC.SubDocumentFromElements("Z", EC.Null("foo")), - )), - nil, - }, - { - "map[string]Reader", - map[string]Reader{"Z": {0x05, 0x00, 0x00, 0x00, 0x00}}, - docToBytes(NewDocument( - EC.SubDocumentFromReader("Z", Reader{0x05, 0x00, 0x00, 0x00, 0x00}), - )), - nil, - }, - { - "map[string][]int32", - map[string][]int32{"Z": {1, 2, 3}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int32(1), VC.Int32(2), VC.Int32(3)), - )), - nil, - }, - { - "map[string][]objectid.ObjectID", - map[string][]objectid.ObjectID{"Z": oids}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.ObjectID(oids[0]), VC.ObjectID(oids[1]), VC.ObjectID(oids[2])), - )), - nil, - }, - { - "map[string][]json.Number(int64)", - map[string][]json.Number{"Z": {json.Number("5"), json.Number("10")}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int64(5), VC.Int64(10)), - )), - nil, - }, - { - "map[string][]json.Number(float64)", - map[string][]json.Number{"Z": {json.Number("5"), json.Number("10.1")}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int64(5), VC.Double(10.1)), - )), - nil, - }, - { - "map[string][]*url.URL", - map[string][]*url.URL{"Z": {murl}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.String(murl.String())), - )), - nil, - }, - { - "map[string][]decimal.Decimal128", - map[string][]decimal.Decimal128{"Z": {decimal128}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Decimal128(decimal128)), - )), - nil, - }, - { - "-", - struct { - A string `bson:"-"` - }{ - A: "", - }, - docToBytes(NewDocument()), - nil, - }, - { - "omitempty", - struct { - A string `bson:",omitempty"` - }{ - A: "", - }, - docToBytes(NewDocument()), - nil, - }, - { - "omitempty, empty time", - struct { - A time.Time `bson:",omitempty"` - }{ - A: time.Time{}, - }, - docToBytes(NewDocument()), - nil, - }, - { - "no private fields", - noPrivateFields{a: "should be empty"}, - docToBytes(NewDocument()), - nil, - }, - { - "minsize", - struct { - A int64 `bson:",minsize"` - }{ - A: 12345, - }, - docToBytes(NewDocument(EC.Int32("a", 12345))), - nil, - }, - { - "inline", - struct { - Foo struct { - A int64 `bson:",minsize"` - } `bson:",inline"` - }{ - Foo: struct { - A int64 `bson:",minsize"` - }{ - A: 12345, - }, - }, - docToBytes(NewDocument(EC.Int32("a", 12345))), - nil, - }, - { - "inline map", - struct { - Foo map[string]string `bson:",inline"` - }{ - Foo: map[string]string{"foo": "bar"}, - }, - docToBytes(NewDocument(EC.String("foo", "bar"))), - nil, - }, - { - "alternate name bson:name", - struct { - A string `bson:"foo"` - }{ - A: "bar", - }, - docToBytes(NewDocument(EC.String("foo", "bar"))), - nil, - }, - { - "alternate name", - struct { - A string `bson:"foo"` - }{ - A: "bar", - }, - docToBytes(NewDocument(EC.String("foo", "bar"))), - nil, - }, - { - "inline, omitempty", - struct { - A string - Foo zeroTest `bson:"omitempty,inline"` - }{ - A: "bar", - Foo: zeroTest{true}, - }, - docToBytes(NewDocument(EC.String("a", "bar"))), - nil, - }, - { - "struct{}", - struct { - A bool - B int32 - C int64 - D uint16 - E uint64 - F float64 - G string - H map[string]string - I []byte - K [2]string - L struct { - M string - } - N *Element - O *Document - P Reader - Q objectid.ObjectID - T []struct{} - Y json.Number - Z time.Time - AA json.Number - AB *url.URL - AC decimal.Decimal128 - AD *time.Time - }{ - A: true, - B: 123, - C: 456, - D: 789, - E: 101112, - F: 3.14159, - G: "Hello, world", - H: map[string]string{"foo": "bar"}, - I: []byte{0x01, 0x02, 0x03}, - K: [2]string{"baz", "qux"}, - L: struct { - M string - }{ - M: "foobar", - }, - N: EC.Null("n"), - O: NewDocument(EC.Int64("countdown", 9876543210)), - P: Reader{0x05, 0x00, 0x00, 0x00, 0x00}, - Q: oid, - T: nil, - Y: json.Number("5"), - Z: now, - AA: json.Number("10.1"), - AB: murl, - AC: decimal128, - AD: &now, - }, - docToBytes(NewDocument( - EC.Boolean("a", true), - EC.Int32("b", 123), - EC.Int64("c", 456), - EC.Int32("d", 789), - EC.Int64("e", 101112), - EC.Double("f", 3.14159), - EC.String("g", "Hello, world"), - EC.SubDocumentFromElements("h", EC.String("foo", "bar")), - EC.Binary("i", []byte{0x01, 0x02, 0x03}), - EC.ArrayFromElements("k", VC.String("baz"), VC.String("qux")), - EC.SubDocumentFromElements("l", EC.String("m", "foobar")), - EC.Null("n"), - EC.SubDocumentFromElements("o", EC.Int64("countdown", 9876543210)), - EC.SubDocumentFromElements("p"), - EC.ObjectID("q", oid), - EC.Null("t"), - EC.Int64("y", 5), - EC.DateTime("z", now.UnixNano()/int64(time.Millisecond)), - EC.Double("aa", 10.1), - EC.String("ab", murl.String()), - EC.Decimal128("ac", decimal128), - EC.DateTime("ad", now.UnixNano()/int64(time.Millisecond)), - )), - nil, - }, - { - "struct{[]interface{}}", - struct { - A []bool - B []int32 - C []int64 - D []uint16 - E []uint64 - F []float64 - G []string - H []map[string]string - I [][]byte - K [1][2]string - L []struct { - M string - } - N [][]string - O []*Element - P []*Document - Q []Reader - R []objectid.ObjectID - T []struct{} - W []map[string]struct{} - X []map[string]struct{} - Y []map[string]struct{} - Z []time.Time - AA []json.Number - AB []*url.URL - AC []decimal.Decimal128 - AD []*time.Time - }{ - A: []bool{true}, - B: []int32{123}, - C: []int64{456}, - D: []uint16{789}, - E: []uint64{101112}, - F: []float64{3.14159}, - G: []string{"Hello, world"}, - H: []map[string]string{{"foo": "bar"}}, - I: [][]byte{{0x01, 0x02, 0x03}}, - K: [1][2]string{{"baz", "qux"}}, - L: []struct { - M string - }{ - { - M: "foobar", - }, - }, - N: [][]string{{"foo", "bar"}}, - O: []*Element{EC.Null("N")}, - P: []*Document{NewDocument(EC.Int64("countdown", 9876543210))}, - Q: []Reader{{0x05, 0x00, 0x00, 0x00, 0x00}}, - R: oids, - T: nil, - W: nil, - X: []map[string]struct{}{}, // Should be empty BSON Array - Y: []map[string]struct{}{{}}, // Should be BSON array with one element, an empty BSON SubDocument - Z: []time.Time{now, now}, - AA: []json.Number{json.Number("5"), json.Number("10.1")}, - AB: []*url.URL{murl}, - AC: []decimal.Decimal128{decimal128}, - AD: []*time.Time{&now, &now}, - }, - docToBytes(NewDocument( - EC.ArrayFromElements("a", VC.Boolean(true)), - EC.ArrayFromElements("b", VC.Int32(123)), - EC.ArrayFromElements("c", VC.Int64(456)), - EC.ArrayFromElements("d", VC.Int32(789)), - EC.ArrayFromElements("e", VC.Int64(101112)), - EC.ArrayFromElements("f", VC.Double(3.14159)), - EC.ArrayFromElements("g", VC.String("Hello, world")), - EC.ArrayFromElements("h", VC.DocumentFromElements(EC.String("foo", "bar"))), - EC.ArrayFromElements("i", VC.Binary([]byte{0x01, 0x02, 0x03})), - EC.ArrayFromElements("k", VC.ArrayFromValues(VC.String("baz"), VC.String("qux"))), - EC.ArrayFromElements("l", VC.DocumentFromElements(EC.String("m", "foobar"))), - EC.ArrayFromElements("n", VC.ArrayFromValues(VC.String("foo"), VC.String("bar"))), - EC.SubDocumentFromElements("o", EC.Null("N")), - EC.ArrayFromElements("p", VC.DocumentFromElements(EC.Int64("countdown", 9876543210))), - EC.ArrayFromElements("q", VC.DocumentFromElements()), - EC.ArrayFromElements("r", VC.ObjectID(oids[0]), VC.ObjectID(oids[1]), VC.ObjectID(oids[2])), - EC.Null("t"), - EC.Null("w"), - EC.Array("x", NewArray()), - EC.ArrayFromElements("y", VC.Document(NewDocument())), - EC.ArrayFromElements("z", VC.DateTime(now.UnixNano()/int64(time.Millisecond)), VC.DateTime(now.UnixNano()/int64(time.Millisecond))), - EC.ArrayFromElements("aa", VC.Int64(5), VC.Double(10.10)), - EC.ArrayFromElements("ab", VC.String(murl.String())), - EC.ArrayFromElements("ac", VC.Decimal128(decimal128)), - EC.ArrayFromElements("ad", VC.DateTime(now.UnixNano()/int64(time.Millisecond)), VC.DateTime(now.UnixNano()/int64(time.Millisecond))), - )), - nil, - }, - } - - t.Run("Encode", func(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - b := make([]byte, 0, 512) - vw := newValueWriterFromSlice(b) - enc, err := NewEncoderv2(NewRegistryBuilder().Build(), vw) - noerr(t, err) - err = enc.Encode(tc.value) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - b = vw.buf - if diff := cmp.Diff(b, tc.b); diff != "" { - t.Errorf("Bytes written differ: (-got +want)\n%s", diff) - t.Errorf("Bytes\ngot: %v\nwant:%v\n", b, tc.b) - t.Errorf("Readers\ngot: %v\nwant:%v\n", Reader(b), Reader(tc.b)) - } - }) - } - }) - - t.Run("Decode", func(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - vr := newValueReader(tc.b) - dec, err := NewDecoderv2(NewRegistryBuilder().Build(), vr) - noerr(t, err) - gotVal := reflect.New(reflect.TypeOf(tc.value)) - err = dec.Decode(gotVal.Interface()) - noerr(t, err) - got := gotVal.Elem().Interface() - want := tc.value - if diff := cmp.Diff( - got, want, - cmp.Comparer(compareElements), - cmp.Comparer(compareValues), - cmp.Comparer(compareDecimal128), - cmp.Comparer(compareNoPrivateFields), - cmp.Comparer(compareZeroTest), - ); diff != "" { - t.Errorf("difference:\n%s", diff) - t.Errorf("Values are not equal.\ngot: %#v\nwant:%#v", got, want) - } - }) - } - }) - }) -} - -func compareValues(v1, v2 *Value) bool { return v1.equal(v2) } -func compareElements(e1, e2 *Element) bool { return e1.equal(e2) } -func compareStrings(s1, s2 string) bool { return s1 == s2 } - -type noPrivateFields struct { - a string -} - -func compareNoPrivateFields(npf1, npf2 noPrivateFields) bool { - return npf1.a != npf2.a // We don't want these to be equal -} diff --git a/bson/constructor.go b/bson/constructor.go index 1142b9568e..bbc4fee751 100644 --- a/bson/constructor.go +++ b/bson/constructor.go @@ -7,10 +7,10 @@ package bson import ( + "bytes" "errors" "fmt" "math" - "reflect" "time" "github.com/mongodb/mongo-go-driver/bson/decimal" @@ -102,15 +102,7 @@ func (ElementConstructor) Interface(key string, value interface{}) *Element { elem = EC.Null(key) } default: - var err error - enc := new(encoder) - val := reflect.ValueOf(value) - val = enc.underlyingVal(val) - - elem, err = enc.elemFromValue(key, val, true) - if err != nil { - elem = EC.Null(key) - } + elem = EC.Null(key) } return elem @@ -149,11 +141,7 @@ func (c ElementConstructor) InterfaceErr(key string, value interface{}) (*Elemen err = errors.New("invalid *Value provided, cannot convert to *Element") } default: - enc := new(encoder) - val := reflect.ValueOf(value) - val = enc.underlyingVal(val) - - elem, err = enc.elemFromValue(key, val, true) + err = fmt.Errorf("Cannot create element for type %T, try using bsoncodec.ConstructElementErr", value) } if err != nil { @@ -543,6 +531,44 @@ func (ElementConstructor) MaxKey(key string) *Element { return elem } +// FromBytes constructs an element from the bytes provided. If the bytes are not +// a valid element, this method will panic. +func (ElementConstructor) FromBytes(src []byte) *Element { + elem, err := EC.FromBytesErr(src) + if err != nil { + panic(err) + } + return elem +} + +// FromValue constructs an element using the underlying value. +func (ElementConstructor) FromValue(key string, value *Value) *Element { + return convertValueToElem(key, value) +} + +// FromBytesErr constructs an element from the bytes provided, but unlike +// FromBytes this method will return an error and not panic if the bytes are not +// a valid element. +func (ElementConstructor) FromBytesErr(src []byte) (*Element, error) { + // TODO: once we have llbson developed, use that to validate the bytes + idx := bytes.IndexByte(src, 0x00) + if idx < 0 { + return nil, errors.New("not a valid element: does not contain a valid key") + } + + if len(src) < 2 { + return nil, errors.New("not a valid element: not enough bytes") + } + + data := make([]byte, len(src)) + copy(data, src) + elem := &Element{value: &Value{start: 0, offset: uint32(idx) + 1, data: data}} + if _, err := elem.Validate(); err != nil { + return nil, err + } + return elem, nil +} + // Double creates a double element with the given value. func (ValueConstructor) Double(f float64) *Value { return EC.Double("", f).value diff --git a/bson/decode.go b/bson/decode.go index c93334985e..d2d1b6cdef 100644 --- a/bson/decode.go +++ b/bson/decode.go @@ -7,17 +7,9 @@ package bson import ( - "fmt" - "io" - "math" - "net/url" "reflect" - "strconv" - "strings" "time" - "bytes" - "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -66,870 +58,3 @@ var zeroVal reflect.Value // objects correctly to match the legacy bson library's handling of // time.Time values. const zeroEpochMs = int64(62135596800000) - -// Unmarshaler describes a type that can unmarshal itself from BSON bytes. -type Unmarshaler interface { - UnmarshalBSON([]byte) error -} - -// DocumentUnmarshaler describes a type that can unmarshal itself from a bson.Document. -type DocumentUnmarshaler interface { - UnmarshalBSONDocument(*Document) error -} - -// Decoder describes a BSON representation that can decodes itself into a value. -type Decoder interface { - Decode(interface{}) error -} - -// decoder facilitates decoding a value from an io.Reader yielding a BSON document as bytes. -type decoder struct { - pReader *peekLengthReader - bsonReader Reader -} - -type peekLengthReader struct { - io.Reader - length [4]byte - pos int32 -} - -func newPeekLengthReader(r io.Reader) *peekLengthReader { - return &peekLengthReader{Reader: r, pos: -1} -} - -func (r *peekLengthReader) peekLength() (int32, error) { - _, err := io.ReadFull(r, r.length[:]) - if err != nil { - return 0, err - } - - // Mark that the length has been read. - r.pos = 0 - - return readi32(r.length[:]), nil -} - -func (r *peekLengthReader) Read(b []byte) (int, error) { - // If either peekLength hasn't been called or the length has been read past, read from the - // io.Reader. - if r.pos < 0 || r.pos > 3 { - return r.Reader.Read(b) - } - - // Read as much of the length as possible into the buffer - bytesToRead := 4 - r.pos - if len(b) < int(bytesToRead) { - bytesToRead = int32(len(b)) - } - - r.pos += int32(copy(b, r.length[r.pos:r.pos+bytesToRead])) - - // Because we use io.ReadFull everywhere, we don't need to read any further since it will be - // read in a subsequent call to Read. - return int(bytesToRead), nil -} - -func convertToPtr(val reflect.Value) reflect.Value { - valPtr := reflect.New(val.Type()) - valPtr.Elem().Set(val) - return valPtr -} - -// NewDecoder constructs a new default Decoder implementation from the given io.Reader. -// -// In this implementation, the value can be any one of the following types: -// -// - bson.Unmarshaler -// - io.Writer -// - []byte -// - bson.Reader -// - any map with string keys -// - a struct (possibly with tags) -// -// In the case of struct values, only exported fields will be deserialized. The lowercased field -// name is used as the key for each exported field, but this behavior may be changed using a struct -// tag. The tag may also contain flags to adjust the unmarshaling behavior for the field. The tag -// formats accepted are: -// -// "[][,[,]]" -// -// `(...) bson:"[][,[,]]" (...)` -// -// The target field or element types of out may not necessarily match the BSON values of the -// provided data. The following conversions are made automatically: -// -// - Numeric types are converted if at least the integer part of the value would be preserved -// correctly -// -// If the value would not fit the type and cannot be converted, it is silently skipped. -// -// Pointer values are initialized when necessary. -func NewDecoder(r io.Reader) Decoder { - return newDecoder(r) -} - -func newDecoder(r io.Reader) *decoder { - return &decoder{pReader: newPeekLengthReader(r)} -} - -// Decode decodes the BSON document from the underlying io.Reader into the given value. -func (d *decoder) Decode(v interface{}) error { - switch t := v.(type) { - case Unmarshaler: - err := d.decodeToReader() - if err != nil { - return err - } - - return t.UnmarshalBSON(d.bsonReader) - case io.Writer: - err := d.decodeToReader() - if err != nil { - return err - } - - _, err = t.Write(d.bsonReader) - return err - case []byte: - length, err := d.pReader.peekLength() - if err != nil { - return err - } - - if len(t) < int(length) { - return NewErrTooSmall() - } - - _, err = io.ReadFull(d.pReader, t) - if err != nil { - return err - } - - _, err = Reader(t).Validate() - return err - case Reader: - length, err := d.pReader.peekLength() - if err != nil { - return err - } - - if len(t) < int(length) { - return NewErrTooSmall() - } - - _, err = io.ReadAtLeast(d.pReader, t, int(length)) - if err != nil { - return err - } - - _, err = t.Validate() - return err - - default: - rval := reflect.ValueOf(v) - return d.reflectDecode(rval) - } -} - -func (d *decoder) decodeToReader() error { - var err error - d.bsonReader, err = NewFromIOReader(d.pReader) - if err != nil { - return err - } - - _, err = d.bsonReader.Validate() - return err - -} - -func (d *decoder) reflectDecode(val reflect.Value) (err error) { - defer func() { - if e := recover(); e != nil { - err = fmt.Errorf("%s", e) - } - }() - - switch val.Kind() { - case reflect.Map: - return d.decodeIntoMap(val) - case reflect.Slice, reflect.Array: - return d.decodeIntoElementSlice(val) - case reflect.Struct: - return d.decodeIntoStruct(val) - case reflect.Ptr: - v := val.Elem() - - if v.Kind() == reflect.Struct { - return d.decodeIntoStruct(v) - } - - fallthrough - default: - return fmt.Errorf("cannot decode BSON document to type %s", val.Type()) - } -} - -func (d *decoder) createEmptyValue(r Reader, t reflect.Type) (reflect.Value, error) { - var val reflect.Value - - if t == tReader { - return reflect.ValueOf(r), nil - } - - switch t.Kind() { - case reflect.Map: - val = reflect.MakeMap(t) - case reflect.Ptr: - if t == tDocument { - val = reflect.ValueOf(NewDocument()) - break - } - - empty, err := d.createEmptyValue(r, t.Elem()) - if err != nil { - return val, err - } - - val = reflect.New(empty.Type()) - val.Elem().Set(empty) - case reflect.Slice: - length := 0 - _, err := r.readElements(func(_ *Element) error { - length++ - return nil - }) - - if err != nil { - return val, err - } - - val = reflect.MakeSlice(t.Elem(), length, length) - case reflect.Struct: - val = reflect.New(t) - default: - val = reflect.Zero(t) - } - - return val, nil -} - -func (d *decoder) getReflectValue(v *Value, containerType reflect.Type, outer reflect.Type) (reflect.Value, error) { - var val reflect.Value - - isPtr := (containerType.Kind() == reflect.Ptr) - - for containerType.Kind() == reflect.Ptr { - containerType = containerType.Elem() - } - - switch v.Type() { - case 0x1: - f := v.Double() - - switch containerType { - case tUint8: - if f > 0 && math.Floor(f) == f && f <= float64(math.MaxUint8) { - val = reflect.ValueOf(uint8(f)) - } - case tUint16: - if f > 0 && math.Floor(f) == f && f <= float64(math.MaxUint16) { - val = reflect.ValueOf(uint16(f)) - } - case tUint32: - if f > 0 && math.Floor(f) == f && f <= float64(math.MaxUint32) { - val = reflect.ValueOf(uint32(f)) - } - case tUint64: - if f > 0 && math.Floor(f) == f && f <= float64(math.MaxUint64) { - val = reflect.ValueOf(uint64(f)) - } - case tUint: - if f < 0 || math.Floor(f) != f || f > float64(math.MaxUint64) { - break - } - - u := uint64(f) - if uint64(uint(u)) == u { - val = reflect.ValueOf(uint(f)) - } - case tInt8: - if math.Floor(f) == f && f >= float64(math.MinInt8) && f <= float64(math.MaxInt8) { - val = reflect.ValueOf(int8(f)) - } - case tInt16: - if math.Floor(f) == f && f >= float64(math.MinInt16) && f <= float64(math.MaxInt16) { - val = reflect.ValueOf(int16(f)) - } - case tInt32: - if math.Floor(f) == f && f >= float64(math.MinInt32) && f <= float64(math.MaxInt32) { - val = reflect.ValueOf(int32(f)) - } - case tInt64: - if math.Floor(f) == f && f >= float64(math.MinInt64) && f <= float64(math.MaxInt64) { - val = reflect.ValueOf(int64(f)) - } - case tInt: - if math.Floor(f) != f || f < float64(math.MinInt64) || f > float64(math.MaxInt64) { - break - } - - i := int64(f) - if int64(int(i)) == i { - val = reflect.ValueOf(int(i)) - } - - case tFloat32: - val = reflect.ValueOf(float32(f)) - - case tFloat64, tEmpty: - val = reflect.ValueOf(f) - case tJSONNumber: - val = reflect.ValueOf(strconv.FormatFloat(f, 'f', -1, 64)).Convert(tJSONNumber) - default: - return val, nil - } - - case 0x2: - str := v.StringValue() - switch containerType { - case tString, tEmpty: - val = reflect.ValueOf(str) - case tJSONNumber: - _, err := strconv.ParseFloat(str, 64) - if err != nil { - return val, err - } - val = reflect.ValueOf(str).Convert(tJSONNumber) - - case tURL: - u, err := url.Parse(str) - if err != nil { - return val, err - } - val = reflect.ValueOf(u).Elem() - default: - return val, nil - } - case 0x4: - if containerType == tEmpty { - d := newDecoder(bytes.NewBuffer(v.ReaderArray())) - newVal, err := d.decodeBSONArrayToSlice(tEmptySlice) - if err != nil { - return val, err - } - - if isPtr { - val = convertToPtr(newVal) - isPtr = false - } else { - val = newVal - } - - break - } - - if containerType.Kind() == reflect.Slice { - d := newDecoder(bytes.NewBuffer(v.ReaderArray())) - newVal, err := d.decodeBSONArrayToSlice(containerType) - if err != nil { - return val, err - } - - if isPtr { - val = convertToPtr(newVal) - isPtr = false - } else { - val = newVal - } - - break - } - - if containerType.Kind() == reflect.Array { - d := newDecoder(bytes.NewBuffer(v.ReaderArray())) - newVal, err := d.decodeBSONArrayIntoArray(containerType) - if err != nil { - return val, err - } - - if isPtr { - val = convertToPtr(newVal) - isPtr = false - } else { - val = newVal - } - - break - } - - fallthrough - - case 0x3: - r := v.ReaderDocument() - - typeToCreate := containerType - if typeToCreate == tEmpty { - typeToCreate = outer - } - - empty, err := d.createEmptyValue(r, typeToCreate) - if err != nil { - return val, err - } - - d := NewDecoder(bytes.NewBuffer(r)) - err = d.Decode(empty.Interface()) - if err != nil { - return val, err - } - - if reflect.PtrTo(typeToCreate) == empty.Type() { - empty = empty.Elem() - if isPtr { - empty = convertToPtr(empty) - isPtr = false - } - } - - val = empty - - case 0x5: - switch containerType { - case tByteSlice: - _, data := v.Binary() - val = reflect.ValueOf(data) - case tEmpty, tBinary: - st, data := v.Binary() - val = reflect.ValueOf(Binary{Subtype: st, Data: data}) - } - - case 0x6: - if containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(Undefined) - case 0x7: - if containerType != tOID && containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(v.ObjectID()) - case 0x8: - if containerType != tBool && containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(v.Boolean()) - case 0x9: - if containerType != tTime && containerType != tEmpty { - return val, nil - } - - if int64(v.getUint64()) == -zeroEpochMs { - val = reflect.ValueOf(time.Time{}) - } else { - val = reflect.ValueOf(v.DateTime()) - } - - case 0xA: - if containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(Null) - case 0xB: - if containerType != tRegex && containerType != tEmpty { - return val, nil - } - - p, o := v.Regex() - val = reflect.ValueOf(Regex{Pattern: p, Options: o}) - case 0xC: - if containerType != tDBPointer && containerType != tEmpty { - return val, nil - } - - db, p := v.DBPointer() - val = reflect.ValueOf(DBPointer{DB: db, Pointer: p}) - case 0xD: - if containerType != tJavaScriptCode && containerType != tString && containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(v.JavaScript()) - case 0xE: - if containerType != tSymbol && containerType != tString && containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(v.Symbol()) - case 0xF: - if containerType != tCodeWithScope && containerType != tEmpty { - return val, nil - } - - code, scope := v.MutableJavaScriptWithScope() - val = reflect.ValueOf(CodeWithScope{Code: code, Scope: scope}) - case 0x10: - i := v.Int32() - - switch containerType { - case tInt8: - if i >= int32(math.MinInt8) && i <= int32(math.MaxInt8) { - val = reflect.ValueOf(int8(i)) - } - - case tInt16: - if i >= int32(math.MinInt16) && i <= int32(math.MaxInt16) { - val = reflect.ValueOf(int16(i)) - } - - case tUint8: - if i >= 0 && i <= int32(math.MaxUint8) { - val = reflect.ValueOf(uint8(i)) - } - - case tUint16: - if i >= 0 && i <= int32(math.MaxUint16) { - val = reflect.ValueOf(uint16(i)) - } - - case tUint32: - if i < 0 { - return val, nil - } - - val = reflect.ValueOf(uint32(i)) - case tUint64: - if i < 0 { - return val, nil - } - - val = reflect.ValueOf(uint64(i)) - case tUint: - if i < 0 { - return val, nil - } - - val = reflect.ValueOf(uint(i)) - case tEmpty, tInt32, tInt64, tInt, tFloat32, tFloat64: - val = reflect.ValueOf(i).Convert(containerType) - case tJSONNumber: - val = reflect.ValueOf(strconv.FormatInt(int64(i), 10)).Convert(tJSONNumber) - default: - return val, nil - } - - case 0x11: - if containerType != tTimestamp && containerType != tEmpty { - return val, nil - } - - t, i := v.Timestamp() - val = reflect.ValueOf(Timestamp{T: t, I: i}) - case 0x12: - i := v.Int64() - - switch containerType { - case tInt8: - if i >= int64(math.MinInt8) && i <= int64(math.MaxInt8) { - val = reflect.ValueOf(int8(i)) - } - - case tInt16: - if i >= int64(math.MinInt16) && i <= int64(math.MaxInt16) { - val = reflect.ValueOf(int16(i)) - } - - case tUint8: - if i >= 0 && i <= int64(math.MaxUint8) { - val = reflect.ValueOf(uint8(i)) - } - - case tUint16: - if i >= 0 && i <= int64(math.MaxUint16) { - val = reflect.ValueOf(uint16(i)) - } - - case tUint32: - if i >= 0 && i <= math.MaxUint32 { - val = reflect.ValueOf(uint32(i)) - } - case tUint64: - if i >= 0 { - val = reflect.ValueOf(uint64(i)) - } - case tUint: - if i >= 0 && int64(uint(i)) == i { - val = reflect.ValueOf(uint(i)) - } - case tInt32: - if i >= int64(math.MinInt32) && i <= int64(math.MaxInt32) { - val = reflect.ValueOf(int32(i)) - } - case tInt: - // Check the value can fit in an int - if int64(int(i)) == i { - val = reflect.ValueOf(int(i)) - } - case tInt64, tEmpty: - val = reflect.ValueOf(i) - case tFloat32: - val = reflect.ValueOf(float32(i)) - case tFloat64: - val = reflect.ValueOf(float64(i)) - case tJSONNumber: - val = reflect.ValueOf(strconv.FormatInt(i, 10)).Convert(tJSONNumber) - } - - case 0x13: - if containerType != tDecimal && containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(v.Decimal128()) - case 0xFF: - if containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(MinKey) - case 0x7f: - if containerType != tEmpty { - return val, nil - } - - val = reflect.ValueOf(MaxKey) - default: - return val, fmt.Errorf("invalid BSON type: %s", v.Type()) - } - - if isPtr && val.IsValid() && !val.CanAddr() { - val = convertToPtr(val) - } - return val, nil -} - -func (d *decoder) decodeIntoMap(mapVal reflect.Value) error { - err := d.decodeToReader() - if err != nil { - return err - } - - itr, err := d.bsonReader.Iterator() - if err != nil { - return err - } - - valType := mapVal.Type().Elem() - - for itr.Next() { - elem := itr.Element() - - v, err := d.getReflectValue(elem.value, valType, mapVal.Type()) - if err != nil { - return err - } - - k := reflect.ValueOf(elem.Key()) - mapVal.SetMapIndex(k, v) - } - - return itr.Err() -} - -func (d *decoder) decodeBSONArrayToSlice(sliceType reflect.Type) (reflect.Value, error) { - var out reflect.Value - - elems := make([]reflect.Value, 0) - - err := d.decodeToReader() - if err != nil { - return out, err - } - - itr, err := d.bsonReader.Iterator() - if err != nil { - return out, err - } - - for itr.Next() { - v, err := d.getReflectValue( - itr.Element().Clone().Value(), - sliceType.Elem(), - sliceType, - ) - if err != nil { - return out, err - } - if !v.IsValid() { - continue - } - - elems = append(elems, v) - } - - out = reflect.MakeSlice(sliceType, len(elems), len(elems)) - - for i, elem := range elems { - if i >= out.Len() { - break - } - - if sliceType.Elem().Kind() == reflect.Ptr { - if elem.CanAddr() { - elem = elem.Addr() - } else { - elem = elem.Elem().Addr() - } - } - out.Index(i).Set(elem) - } - - return out, nil -} - -func (d *decoder) decodeBSONArrayIntoArray(arrayType reflect.Type) (reflect.Value, error) { - length := arrayType.Len() - arrayVal := reflect.New(arrayType) - - err := d.decodeToReader() - if err != nil { - return arrayVal, err - } - - itr, err := d.bsonReader.Iterator() - if err != nil { - return arrayVal, err - } - - i := 0 - for itr.Next() { - if i >= length { - break - } - - v, err := d.getReflectValue( - itr.Element().Clone().Value(), - arrayType.Elem(), - arrayType, - ) - if err != nil { - return arrayVal, err - } - - arrayVal.Elem().Index(i).Set(v) - i++ - } - - if err = itr.Err(); err != nil { - return arrayVal, err - } - - return arrayVal.Elem(), nil -} - -func (d *decoder) decodeIntoElementSlice(sliceVal reflect.Value) error { - if sliceVal.Type().Elem() != tElement { - return nil - } - - sliceLength := sliceVal.Len() - - err := d.decodeToReader() - if err != nil { - return err - } - - itr, err := d.bsonReader.Iterator() - if err != nil { - return err - } - - i := 0 - for itr.Next() { - if i >= sliceLength { - return NewErrTooSmall() - } - - elem := reflect.ValueOf(itr.Element().Clone()) - sliceVal.Index(i).Set(elem) - i++ - } - - return itr.Err() -} - -func matchesField(key string, field string, sType reflect.Type) bool { - sField, found := sType.FieldByName(field) - if !found { - return false - } - - tag, ok := sField.Tag.Lookup("bson") - if !ok { - // Get the full tag string - tag = string(sField.Tag) - - if len(sField.Tag) == 0 || strings.ContainsRune(tag, ':') { - return strings.ToLower(key) == strings.ToLower(field) - } - } - - var fieldKey string - i := strings.IndexRune(tag, ',') - if i == -1 { - fieldKey = tag - } else { - fieldKey = tag[:i] - } - - return fieldKey == key -} - -func (d *decoder) decodeIntoStruct(structVal reflect.Value) error { - err := d.decodeToReader() - if err != nil { - return err - } - - itr, err := d.bsonReader.Iterator() - if err != nil { - return err - } - - sType := structVal.Type() - - for itr.Next() { - elem := itr.Element() - - field := structVal.FieldByNameFunc(func(field string) bool { - return matchesField(elem.Key(), field, sType) - }) - if field == zeroVal { - continue - } - - v, err := d.getReflectValue(elem.value, field.Type(), structVal.Type()) - if err != nil { - return err - } - - if v != zeroVal { - if field.Type().Kind() == reflect.Ptr { - if v.CanAddr() { - v = v.Addr() - } else { - v = v.Elem().Addr() - } - } - - field.Set(v) - } - } - - return itr.Err() -} diff --git a/bson/decode_test.go b/bson/decode_test.go index 215cb21221..8523982ee6 100644 --- a/bson/decode_test.go +++ b/bson/decode_test.go @@ -7,14 +7,8 @@ package bson import ( - "bytes" - "encoding/json" - "net/url" - "reflect" "testing" - "github.com/google/go-cmp/cmp" - "github.com/mongodb/mongo-go-driver/bson/decimal" "github.com/stretchr/testify/require" ) @@ -36,3862 +30,6 @@ func requireErrEqual(t *testing.T, err1 error, err2 error) { require.Equal(t, err1, err2) } -func TestDecoder(t *testing.T) { - t.Run("byte slice", func(t *testing.T) { - testCases := []struct { - name string - reader *bytes.Buffer - expected []byte - actual []byte - err error - }{ - { - "nil", - bytes.NewBuffer([]byte{0x5, 0x0, 0x0, 0x0, 0x0}), - nil, - nil, - NewErrTooSmall(), - }, - { - "empty slice", - bytes.NewBuffer([]byte{0x5, 0x0, 0x0, 0x0}), - nil, - []byte{}, - NewErrTooSmall(), - }, - { - "too small", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - nil, - make([]byte, 0x4), - NewErrTooSmall(), - }, - { - "empty doc", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - []byte{0x5, 0x0, 0x0, 0x0, 0x0}, - make([]byte, 0x5), - nil, - }, - { - "non-empty doc", - bytes.NewBuffer([]byte{ - // length - 0x17, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - null - 0xa, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }), - []byte{ - // length - 0x17, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - null - 0xa, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }, - make([]byte, 0x17), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, bytes.Equal(tc.expected, tc.actual)) - }) - } - }) - - t.Run("Reader", func(t *testing.T) { - testCases := []struct { - name string - reader *bytes.Buffer - expected Reader - actual Reader - err error - }{ - { - "nil", - bytes.NewBuffer([]byte{0x5, 0x0, 0x0, 0x0, 0x0}), - nil, - nil, - NewErrTooSmall(), - }, - { - "empty slice", - bytes.NewBuffer([]byte{0x5, 0x0, 0x0, 0x0}), - nil, - []byte{}, - NewErrTooSmall(), - }, - { - "too small", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - nil, - make([]byte, 0x4), - NewErrTooSmall(), - }, - { - "empty doc", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - []byte{0x5, 0x0, 0x0, 0x0, 0x0}, - make([]byte, 0x5), - nil, - }, - { - "non-empty doc", - bytes.NewBuffer([]byte{ - // length - 0x17, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - null - 0xa, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }), - []byte{ - // length - 0x17, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - null - 0xa, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }, - make([]byte, 0x17), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, bytes.Equal(tc.expected, tc.actual)) - }) - } - }) - - t.Run("io.Writer", func(t *testing.T) { - testCases := []struct { - name string - reader *bytes.Buffer - expected *bytes.Buffer - actual *bytes.Buffer - err error - }{ - { - "empty doc", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - bytes.NewBuffer([]byte{}), - nil, - }, - { - "non-empty doc", - bytes.NewBuffer([]byte{ - // length - 0x17, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - null - 0xa, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }), - bytes.NewBuffer([]byte{ - // length - 0x17, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - null - 0xa, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }), - bytes.NewBuffer([]byte{}), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.Equal(t, tc.expected, tc.actual) - }) - } - }) - - t.Run("Unmarshaler", func(t *testing.T) { - testCases := []struct { - name string - reader *bytes.Buffer - expected *Document - actual *Document - err error - }{ - { - "empty doc", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - NewDocument(), - NewDocument(), - nil, - }, - { - "non-empty doc", - bytes.NewBuffer([]byte{ - // length - 0x17, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - null - 0xa, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }), - NewDocument( - EC.String("foo", "bar"), - EC.Null("baz"), - ), - NewDocument(), - nil, - }, - { - "nested doc", - bytes.NewBuffer([]byte{ - // length - 0x26, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - document - 0x3, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // -- begin subdocument -- - - // length - 0xf, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "bang" - 0x62, 0x61, 0x6e, 0x67, 0x0, - // value - int32(12) - 0xc, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - - // -- end subdocument - - // null terminator - 0x0, - }), - NewDocument( - EC.String("foo", "bar"), - EC.SubDocumentFromElements("baz", - EC.Int32("bang", 12), - ), - ), - NewDocument(), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, documentComparer(tc.expected, tc.actual)) - }) - } - }) - - t.Run("map", func(t *testing.T) { - testCases := []struct { - name string - reader *bytes.Buffer - expected map[string]interface{} - actual map[string]interface{} - err error - }{ - { - "empty doc", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - make(map[string]interface{}), - make(map[string]interface{}), - nil, - }, - { - "non-empty doc", - bytes.NewBuffer([]byte{ - // length - 0x1b, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(32) - 0x20, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - map[string]interface{}{ - "foo": "bar", - "baz": int32(32), - }, - make(map[string]interface{}), - nil, - }, - { - "containing array", - bytes.NewBuffer( - []byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // ----- begin array ----- - - // length - 0x10, 0x0, 0x0, 0x0, - - // type string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // null terminator - 0x0, - - // ----- end array ----- - - // null terminator - 0x0, - }, - ), - map[string]interface{}{ - "foo": []interface{}{"bar"}, - }, - make(map[string]interface{}), - nil, - }, - { - "nested doc", - bytes.NewBuffer([]byte{ - // length - 0x26, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - document - 0x3, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // -- begin subdocument -- - - // length - 0xf, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "bang" - 0x62, 0x61, 0x6e, 0x67, 0x0, - // value - int32(12) - 0xc, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - - // -- end subdocument - - // null terminator - 0x0, - }), - map[string]interface{}{ - "foo": "bar", - "baz": map[string]interface{}{ - "bang": int32(12), - }, - }, - make(map[string]interface{}), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, cmp.Equal(tc.expected, tc.actual)) - }) - } - }) - - t.Run("element slice", func(t *testing.T) { - testCases := []struct { - name string - reader *bytes.Buffer - expected []*Element - actual []*Element - err error - }{ - { - "empty doc", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - []*Element{}, - []*Element{}, - nil, - }, - { - "non-empty doc", - bytes.NewBuffer([]byte{ - // length - 0x1b, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(32) - 0x20, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - []*Element{ - EC.String("foo", "bar"), - EC.Int32("baz", 32), - }, - make([]*Element, 2), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - elementSliceEqual(t, tc.expected, tc.actual) - }) - } - }) - - t.Run("struct", func(t *testing.T) { - stringValue := "bar" - int32Value := int32(32) - int32Value12 := int32(12) - testCases := []struct { - name string - reader *bytes.Buffer - expected interface{} - actual interface{} - err error - }{ - { - "empty doc", - bytes.NewBuffer([]byte{ - 0x5, 0x0, 0x0, 0x0, 0x0, - }), - &struct{}{}, - &struct{}{}, - nil, - }, - { - "non-empty doc", - bytes.NewBuffer([]byte{ - // length - 0x25, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(32) - 0x20, 0x0, 0x0, 0x0, - - // type - regex - 0xb, - // key - "r" - 0x72, 0x0, - // value - pattern("WoRd") - 0x57, 0x6f, 0x52, 0x64, 0x0, - // value - options("i") - 0x69, 0x0, - - // null terminator - 0x0, - }), - &struct { - Foo string - Baz int32 - R Regex - }{ - "bar", - 32, - Regex{Pattern: "WoRd", Options: "i"}, - }, - &struct { - Foo string - Baz int32 - R Regex - }{}, - nil, - }, - { - "non-empty doc pointers", - bytes.NewBuffer([]byte{ - // length - 0x25, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(32) - 0x20, 0x0, 0x0, 0x0, - - // type - regex - 0xb, - // key - "r" - 0x72, 0x0, - // value - pattern("WoRd") - 0x57, 0x6f, 0x52, 0x64, 0x0, - // value - options("i") - 0x69, 0x0, - - // null terminator - 0x0, - }), - &struct { - Foo *string - Baz *int32 - R *Regex - }{ - &stringValue, - &int32Value, - &Regex{Pattern: "WoRd", Options: "i"}, - }, - &struct { - Foo *string - Baz *int32 - R *Regex - }{}, - nil, - }, - { - "empty interface field", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(32) - 0x20, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - &struct { - Baz interface{} - }{ - int32(32), - }, - &struct { - Baz interface{} - }{}, - nil, - }, - { - "containing array", - bytes.NewBuffer( - []byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // ----- begin array ----- - - // length - 0x10, 0x0, 0x0, 0x0, - - // type string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // null terminator - 0x0, - - // ----- end array ----- - - // null terminator - 0x0, - }, - ), - &struct{ Foo interface{} }{[]interface{}{"bar"}}, - &struct{ Foo interface{} }{}, - nil, - }, - { - "nested doc", - bytes.NewBuffer([]byte{ - // length - 0x26, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - document - 0x3, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // -- begin subdocument -- - - // length - 0xf, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "bang" - 0x62, 0x61, 0x6e, 0x67, 0x0, - // value - int32(12) - 0xc, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - - // -- end subdocument - - // null terminator - 0x0, - }), - &struct { - Foo string - Baz struct { - Bang int32 - } - }{ - "bar", - struct{ Bang int32 }{12}, - }, - &struct { - Foo string - Baz struct { - Bang int32 - } - }{}, - nil, - }, - { - "nested doc pointer", - bytes.NewBuffer([]byte{ - // length - 0x26, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - document - 0x3, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // -- begin subdocument -- - - // length - 0xf, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "bang" - 0x62, 0x61, 0x6e, 0x67, 0x0, - // value - int32(12) - 0xc, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - - // -- end subdocument - - // null terminator - 0x0, - }), - &struct { - Foo *string - Baz *struct { - Bang *int32 - } - }{ - &stringValue, - &struct{ Bang *int32 }{&int32Value12}, - }, - &struct { - Foo *string - Baz *struct { - Bang *int32 - } - }{}, - nil, - }, - { - "struct tags", - bytes.NewBuffer([]byte{ - // length - 0x1b, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(32) - 0x20, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - &struct { - A string `bson:"foo"` - B int32 `bson:"baz,omitempty"` - }{ - "bar", - 32, - }, - &struct { - A string `bson:"foo"` - B int32 `bson:"baz,omitempty"` - }{}, - nil, - }, - { - "struct with pointers tags", - bytes.NewBuffer([]byte{ - // length - 0x1b, 0x0, 0x0, 0x0, - - // type - string - 0x2, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string "bar" - 0x62, 0x61, 0x72, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(32) - 0x20, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - &struct { - A *string `bson:"foo"` - B *int32 `bson:"baz,omitempty"` - }{ - &stringValue, - &int32Value, - }, - &struct { - A *string `bson:"foo"` - B *int32 `bson:"baz,omitempty"` - }{}, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, cmp.Equal(tc.expected, tc.actual)) - }) - } - }) - - t.Run("numbers", func(t *testing.T) { - t.Run("decode int32", func(t *testing.T) { - uint8Val := uint8(1) - uint16Val := uint16(2) - uint32Val := uint32(3) - uint64Val := uint64(4) - uintVal := uint(5) - int8Val := int8(6) - int16Val := int16(7) - int32Val := int32(8) - int64Val := int64(9) - intVal := int(10) - float32Val := float32(11.0) - float64Val := float64(12.0) - testCases := []struct { - name string - reader *bytes.Buffer - expected interface{} - actual interface{} - err error - }{ - { - "negative into uint8", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint8 - }{ - 0, - }, - &struct { - Baz uint8 - }{}, - nil, - }, - { - "negative into uint8 pointer", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint8 - }{ - nil, - }, - &struct { - Baz *uint8 - }{}, - nil, - }, - { - "negative into uint16", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint16 - }{ - 0, - }, - &struct { - Baz uint16 - }{}, - nil, - }, - { - "negative into uint16 pointer", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint16 - }{ - nil, - }, - &struct { - Baz *uint16 - }{}, - nil, - }, - { - "negative into uint32", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint32 - }{ - 0, - }, - &struct { - Baz uint32 - }{}, - nil, - }, - { - "negative into uint32 pointer", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint32 - }{ - nil, - }, - &struct { - Baz *uint32 - }{}, - nil, - }, - { - "negative into uint64", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint64 - }{ - 0, - }, - &struct { - Baz uint64 - }{}, - nil, - }, - { - "negative into uint64 pointer", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint64 - }{ - nil, - }, - &struct { - Baz *uint64 - }{}, - nil, - }, - { - "negative into uint", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint - }{ - 0, - }, - &struct { - Baz uint - }{}, - nil, - }, - { - "negative into uint pointer", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(-27) - 0xe5, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint - }{ - nil, - }, - &struct { - Baz *uint - }{}, - nil, - }, - { - "too high for int8", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(2^24) - 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz int8 - }{ - 0, - }, - &struct { - Baz int8 - }{}, - nil, - }, - { - "too high for int8 pointer", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(2^24) - 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz *int8 - }{ - nil, - }, - &struct { - Baz *int8 - }{}, - nil, - }, - { - "too high for int16", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(2^24) - 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz int16 - }{ - 0, - }, - &struct { - Baz int16 - }{}, - nil, - }, - { - "too high for int16 pointer", - bytes.NewBuffer([]byte{ - // length - 0xe, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int32(2^24) - 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz *int16 - }{ - nil, - }, - &struct { - Baz *int16 - }{}, - nil, - }, - { - "success", - bytes.NewBuffer([]byte{ - // length - 0x59, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "a" - 0x61, 0x0, - // value - int32(1) - 0x1, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "b" - 0x62, 0x0, - // value - int32(2) - 0x2, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "c" - 0x63, 0x0, - // value - int32(3) - 0x3, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "d" - 0x64, 0x0, - // value - int32(4) - 0x4, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "e" - 0x65, 0x0, - // value - int32(5) - 0x5, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "f" - 0x66, 0x0, - // value - int32(6) - 0x6, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "g" - 0x67, 0x0, - // value - int32(7) - 0x7, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "h" - 0x68, 0x0, - // value - int32(8) - 0x8, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "i" - 0x69, 0x0, - // value - int32(9) - 0x9, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "j" - 0x6a, 0x0, - // value - int32(10) - 0xa, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "k" - 0x6b, 0x0, - // value - int32(11) - 0xb, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "l" - 0x6c, 0x0, - // value - int32(12) - 0xc, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - &struct { - A uint8 - B uint16 - C uint32 - D uint64 - E uint - F int8 - G int16 - H int32 - I int64 - J int - K float32 - L float64 - }{ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - }, - &struct { - A uint8 - B uint16 - C uint32 - D uint64 - E uint - F int8 - G int16 - H int32 - I int64 - J int - K float32 - L float64 - }{}, - nil, - }, - { - "success pointer", - bytes.NewBuffer([]byte{ - // length - 0x59, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "a" - 0x61, 0x0, - // value - int32(1) - 0x1, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "b" - 0x62, 0x0, - // value - int32(2) - 0x2, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "c" - 0x63, 0x0, - // value - int32(3) - 0x3, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "d" - 0x64, 0x0, - // value - int32(4) - 0x4, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "e" - 0x65, 0x0, - // value - int32(5) - 0x5, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "f" - 0x66, 0x0, - // value - int32(6) - 0x6, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "g" - 0x67, 0x0, - // value - int32(7) - 0x7, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "h" - 0x68, 0x0, - // value - int32(8) - 0x8, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "i" - 0x69, 0x0, - // value - int32(9) - 0x9, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "j" - 0x6a, 0x0, - // value - int32(10) - 0xa, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "k" - 0x6b, 0x0, - // value - int32(11) - 0xb, 0x0, 0x0, 0x0, - - // type - int32 - 0x10, - // key - "l" - 0x6c, 0x0, - // value - int32(12) - 0xc, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - &struct { - A *uint8 - B *uint16 - C *uint32 - D *uint64 - E *uint - F *int8 - G *int16 - H *int32 - I *int64 - J *int - K *float32 - L *float64 - }{ - &uint8Val, - &uint16Val, - &uint32Val, - &uint64Val, - &uintVal, - &int8Val, - &int16Val, - &int32Val, - &int64Val, - &intVal, - &float32Val, - &float64Val, - }, - &struct { - A *uint8 - B *uint16 - C *uint32 - D *uint64 - E *uint - F *int8 - G *int16 - H *int32 - I *int64 - J *int - K *float32 - L *float64 - }{}, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, reflect.DeepEqual(tc.expected, tc.actual)) - }) - } - }) - - t.Run("decode int64", func(t *testing.T) { - uint8Val := uint8(1) - uint16Val := uint16(2) - uint32Val := uint32(3) - uint64Val := uint64(4) - uintVal := uint(5) - int8Val := int8(6) - int16Val := int16(7) - int32Val := int32(8) - int64Val := int64(9) - intVal := int(10) - float32Val := float32(11.0) - float64Val := float64(12.0) - testCases := []struct { - name string - reader *bytes.Buffer - expected interface{} - actual interface{} - err error - }{ - { - "negative into uint8", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint8 - }{ - 0, - }, - &struct { - Baz uint8 - }{}, - nil, - }, - { - "negative into uint8 pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint8 - }{ - nil, - }, - &struct { - Baz *uint8 - }{}, - nil, - }, - { - "negative into uint16", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint16 - }{ - 0, - }, - &struct { - Baz uint16 - }{}, - nil, - }, - { - "negative into uint16 pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint16 - }{ - nil, - }, - &struct { - Baz *uint16 - }{}, - nil, - }, - { - "negative into uint32", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint32 - }{ - 0, - }, - &struct { - Baz uint32 - }{}, - nil, - }, - { - "negative into uint32 pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint32 - }{ - nil, - }, - &struct { - Baz *uint32 - }{}, - nil, - }, - { - "negative into uint64", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint64 - }{ - 0, - }, - &struct { - Baz uint64 - }{}, - nil, - }, - { - "negative into uint64 pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint64 - }{ - nil, - }, - &struct { - Baz *uint64 - }{}, - nil, - }, - { - "negative into uint", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz uint - }{ - 0, - }, - &struct { - Baz uint - }{}, - nil, - }, - { - "negative into uint pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(-27) - 0xe5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - - // null terminator - 0x0, - }), - &struct { - Baz *uint - }{ - nil, - }, - &struct { - Baz *uint - }{}, - nil, - }, - { - "too high for int8", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(2^56) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz int8 - }{ - 0, - }, - &struct { - Baz int8 - }{}, - nil, - }, - { - "too high for int8 pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(2^56) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz *int8 - }{ - nil, - }, - &struct { - Baz *int8 - }{}, - nil, - }, - { - "too high for int16", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(2^56) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz int16 - }{ - 0, - }, - &struct { - Baz int16 - }{}, - nil, - }, - { - "too high for int16 pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(2^56) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz *int16 - }{ - nil, - }, - &struct { - Baz *int16 - }{}, - nil, - }, - { - "too high for int32", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(2^56) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz int32 - }{ - 0, - }, - &struct { - Baz int32 - }{}, - nil, - }, - { - "too high for int32 pointer", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - int64(2^56) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, - - // null terminator - 0x0, - }), - &struct { - Baz *int32 - }{ - nil, - }, - &struct { - Baz *int32 - }{}, - nil, - }, - { - "success", - bytes.NewBuffer([]byte{ - // length - 0x89, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "a" - 0x61, 0x0, - // value - int64(1) - 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "b" - 0x62, 0x0, - // value - int64(2) - 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "c" - 0x63, 0x0, - // value - int64(3) - 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "d" - 0x64, 0x0, - // value - int64(4) - 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "e" - 0x65, 0x0, - // value - int64(5) - 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "f" - 0x66, 0x0, - // value - int64(6) - 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "g" - 0x67, 0x0, - // value - int64(7) - 0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "h" - 0x68, 0x0, - // value - int64(8) - 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "i" - 0x69, 0x0, - // value - int64(9) - 0x9, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "j" - 0x6a, 0x0, - // value - int64(10) - 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "k" - 0x6b, 0x0, - // value - int64(11) - 0xb, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "l" - 0x6c, 0x0, - // value - int64(12) - 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - &struct { - A uint8 - B uint16 - C uint32 - D uint64 - E uint - F int8 - G int16 - H int32 - I int64 - J int - K float32 - L float64 - }{ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - }, - &struct { - A uint8 - B uint16 - C uint32 - D uint64 - E uint - F int8 - G int16 - H int32 - I int64 - J int - K float32 - L float64 - }{}, - nil, - }, - { - "success pointers", - bytes.NewBuffer([]byte{ - // length - 0x89, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "a" - 0x61, 0x0, - // value - int64(1) - 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "b" - 0x62, 0x0, - // value - int64(2) - 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "c" - 0x63, 0x0, - // value - int64(3) - 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "d" - 0x64, 0x0, - // value - int64(4) - 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "e" - 0x65, 0x0, - // value - int64(5) - 0x5, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "f" - 0x66, 0x0, - // value - int64(6) - 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "g" - 0x67, 0x0, - // value - int64(7) - 0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "h" - 0x68, 0x0, - // value - int64(8) - 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "i" - 0x69, 0x0, - // value - int64(9) - 0x9, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "j" - 0x6a, 0x0, - // value - int64(10) - 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "k" - 0x6b, 0x0, - // value - int64(11) - 0xb, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // type - int64 - 0x12, - // key - "l" - 0x6c, 0x0, - // value - int64(12) - 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - - // null terminator - 0x0, - }), - &struct { - A *uint8 - B *uint16 - C *uint32 - D *uint64 - E *uint - F *int8 - G *int16 - H *int32 - I *int64 - J *int - K *float32 - L *float64 - }{ - &uint8Val, - &uint16Val, - &uint32Val, - &uint64Val, - &uintVal, - &int8Val, - &int16Val, - &int32Val, - &int64Val, - &intVal, - &float32Val, - &float64Val, - }, - &struct { - A *uint8 - B *uint16 - C *uint32 - D *uint64 - E *uint - F *int8 - G *int16 - H *int32 - I *int64 - J *int - K *float32 - L *float64 - }{}, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, reflect.DeepEqual(tc.expected, tc.actual)) - }) - } - }) - - t.Run("decode double", func(t *testing.T) { - /*var uint8Value uint8 - var uint16Value uint16*/ - dataToDecode := []byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - double - 0x1, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - double(0.5) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe0, 0x3f, - - // null terminator - 0x0, - } - - uint8Val := uint8(1) - uint16Val := uint16(2) - uint32Val := uint32(3) - uint64Val := uint64(4) - uintVal := uint(5) - int8Val := int8(6) - int16Val := int16(7) - int32Val := int32(8) - int64Val := int64(9) - intVal := int(10) - float32Val := float32(11.0) - float64Val := float64(12.0) - - testCases := []struct { - name string - reader *bytes.Buffer - expected interface{} - actual interface{} - err error - }{ - { - "fraction into uint8", - bytes.NewBuffer(dataToDecode), - &struct { - Baz uint8 - }{ - 0, - }, - &struct { - Baz uint8 - }{}, - nil, - }, - { - "fraction into uint8 pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *uint8 - }{ - nil, - }, - &struct { - Baz *uint8 - }{}, - nil, - }, - { - "fraction into uint16", - bytes.NewBuffer(dataToDecode), - &struct { - Baz uint16 - }{ - 0, - }, - &struct { - Baz uint16 - }{}, - nil, - }, - { - "fraction into uint16 pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *uint16 - }{ - nil, - }, - &struct { - Baz *uint16 - }{}, - nil, - }, - { - "fraction into uint32", - bytes.NewBuffer(dataToDecode), - &struct { - Baz uint32 - }{ - 0, - }, - &struct { - Baz uint32 - }{}, - nil, - }, - { - "fraction into uint32 pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *uint32 - }{ - nil, - }, - &struct { - Baz *uint32 - }{}, - nil, - }, - { - "fraction into uint64", - bytes.NewBuffer(dataToDecode), - &struct { - Baz uint64 - }{ - 0, - }, - &struct { - Baz uint64 - }{}, - nil, - }, - { - "fraction into uint64 pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *uint64 - }{ - nil, - }, - &struct { - Baz *uint64 - }{}, - nil, - }, - { - "fraction into uint", - bytes.NewBuffer(dataToDecode), - &struct { - Baz uint - }{ - 0, - }, - &struct { - Baz uint - }{}, - nil, - }, - { - "fraction into uint pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *uint - }{ - nil, - }, - &struct { - Baz *uint - }{}, - nil, - }, - { - "fraction into int32", - bytes.NewBuffer(dataToDecode), - &struct { - Baz int32 - }{ - 0, - }, - &struct { - Baz int32 - }{}, - nil, - }, - { - "fraction into int32 pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *int32 - }{ - nil, - }, - &struct { - Baz *int32 - }{}, - nil, - }, - { - "fraction into int64", - bytes.NewBuffer(dataToDecode), - &struct { - Baz int64 - }{ - 0, - }, - &struct { - Baz int64 - }{}, - nil, - }, - { - "fraction into int64 pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *int64 - }{ - nil, - }, - &struct { - Baz *int64 - }{}, - nil, - }, - { - "fraction into int", - bytes.NewBuffer(dataToDecode), - &struct { - Baz int - }{ - 0, - }, - &struct { - Baz int - }{}, - nil, - }, - { - "fraction into int pointer", - bytes.NewBuffer(dataToDecode), - &struct { - Baz *int - }{ - nil, - }, - &struct { - Baz *int - }{}, - nil, - }, - { - "too precise for float32", - bytes.NewBuffer([]byte{ - // length - 0x12, 0x0, 0x0, 0x0, - - // type - double - 0x1, - // key - "baz" - 0x62, 0x61, 0x7a, 0x0, - // value - double(3.00000000001) - 0xf6, 0x57, 0x0, 0x0, 0x0, 0x0, 0x8, 0x40, - - // null terminator - 0x0, - }), - &struct { - Baz float32 - }{ - 3, - }, - &struct { - Baz float32 - }{}, - nil, - }, - { - "success", - bytes.NewBuffer([]byte{ - // length - 0x89, 0x0, 0x0, 0x0, - - // type - double - 0x1, - // key - "a" - 0x61, 0x0, - // value - double(1.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf0, 0x3f, - - // type - double - 0x1, - // key - "b" - 0x62, 0x0, - // value - double(2.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, - - // type - double - 0x1, - // key - "c" - 0x63, 0x0, - // value - double(3.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0x40, - - // type - double - 0x1, - // key - "d" - 0x64, 0x0, - // value - double(4.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x40, - - // type - double - 0x1, - // key - "e" - 0x65, 0x0, - // value - double(5.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x14, 0x40, - - // type - double - 0x1, - // key - "f" - 0x66, 0x0, - // value - double(6.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x18, 0x40, - - // type - double - 0x1, - // key - "g" - 0x67, 0x0, - // value - double(7.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1c, 0x40, - - // type - double - 0x1, - // key - "h" - 0x68, 0x0, - // value - double(8.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x20, 0x40, - - // type - double - 0x1, - // key - "i" - 0x69, 0x0, - // value - double(9.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x22, 0x40, - - // type - double - 0x1, - // key - "j" - 0x6a, 0x0, - // value - double(10.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x24, 0x40, - - // type - double - 0x1, - // key - "k" - 0x6b, 0x0, - // value - double(11.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x26, 0x40, - - // type - double - 0x1, - // key - "j" - 0x6c, 0x0, - // value - double(12.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x28, 0x40, - - // null terminator - 0x0, - }), - &struct { - A uint8 - B uint16 - C uint32 - D uint64 - E uint - F int8 - G int16 - H int32 - I int64 - J int - K float32 - L float64 - }{ - 1.0, - 2.0, - 3.0, - 4.0, - 5.0, - 6.0, - 7.0, - 8.0, - 9.0, - 10.0, - 11.0, - 12.0, - }, - &struct { - A uint8 - B uint16 - C uint32 - D uint64 - E uint - F int8 - G int16 - H int32 - I int64 - J int - K float32 - L float64 - }{}, - nil, - }, - { - "success pointers", - bytes.NewBuffer([]byte{ - // length - 0x89, 0x0, 0x0, 0x0, - - // type - double - 0x1, - // key - "a" - 0x61, 0x0, - // value - double(1.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf0, 0x3f, - - // type - double - 0x1, - // key - "b" - 0x62, 0x0, - // value - double(2.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, - - // type - double - 0x1, - // key - "c" - 0x63, 0x0, - // value - double(3.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0x40, - - // type - double - 0x1, - // key - "d" - 0x64, 0x0, - // value - double(4.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x40, - - // type - double - 0x1, - // key - "e" - 0x65, 0x0, - // value - double(5.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x14, 0x40, - - // type - double - 0x1, - // key - "f" - 0x66, 0x0, - // value - double(6.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x18, 0x40, - - // type - double - 0x1, - // key - "g" - 0x67, 0x0, - // value - double(7.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1c, 0x40, - - // type - double - 0x1, - // key - "h" - 0x68, 0x0, - // value - double(8.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x20, 0x40, - - // type - double - 0x1, - // key - "i" - 0x69, 0x0, - // value - double(9.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x22, 0x40, - - // type - double - 0x1, - // key - "j" - 0x6a, 0x0, - // value - double(10.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x24, 0x40, - - // type - double - 0x1, - // key - "k" - 0x6b, 0x0, - // value - double(11.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x26, 0x40, - - // type - double - 0x1, - // key - "j" - 0x6c, 0x0, - // value - double(12.0) - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x28, 0x40, - - // null terminator - 0x0, - }), - &struct { - A *uint8 - B *uint16 - C *uint32 - D *uint64 - E *uint - F *int8 - G *int16 - H *int32 - I *int64 - J *int - K *float32 - L *float64 - }{ - &uint8Val, - &uint16Val, - &uint32Val, - &uint64Val, - &uintVal, - &int8Val, - &int16Val, - &int32Val, - &int64Val, - &intVal, - &float32Val, - &float64Val, - }, - &struct { - A *uint8 - B *uint16 - C *uint32 - D *uint64 - E *uint - F *int8 - G *int16 - H *int32 - I *int64 - J *int - K *float32 - L *float64 - }{}, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, reflect.DeepEqual(tc.expected, tc.actual)) - }) - } - }) - t.Run("decimal128", func(t *testing.T) { - decimal128, err := decimal.ParseDecimal128("1.5e10") - if err != nil { - t.Errorf("Error parsing decimal128: %v", err) - t.FailNow() - } - testCases := []struct { - name string - reader []byte - expected interface{} - actual interface{} - err error - }{ - { - "decimal128", - docToBytes(NewDocument(EC.Decimal128("a", decimal128))), - &struct { - A decimal.Decimal128 - }{ - A: decimal128, - }, - &struct { - A decimal.Decimal128 - }{}, - nil, - }, - { - "decimal128 pointer", - docToBytes(NewDocument(EC.Decimal128("a", decimal128))), - &struct { - A *decimal.Decimal128 - }{ - A: &decimal128, - }, - &struct { - A *decimal.Decimal128 - }{}, - nil, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(bytes.NewBuffer(tc.reader)) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, reflect.DeepEqual(tc.expected, tc.actual)) - }) - } - }) - }) - - t.Run("mixed types", func(t *testing.T) { - stringValue := "baz" - testCases := []struct { - name string - reader *bytes.Buffer - expected interface{} - actual interface{} - err error - }{ - { - "struct containing slice", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo []string - }{ - []string{"baz"}, - }, - &struct { - Foo []string - }{}, - nil, - }, - { - "struct containing slice pointer", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - string - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo *[]string - }{ - &[]string{"baz"}, - }, - &struct { - Foo *[]string - }{}, - nil, - }, - { - "struct containing array", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo [1]string - }{ - [...]string{"baz"}, - }, - &struct { - Foo [1]string - }{}, - nil, - }, - { - "struct containing array pointer", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo *[1]string - }{ - &[...]string{"baz"}, - }, - &struct { - Foo *[1]string - }{}, - nil, - }, - { - "struct containing map", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo map[string]string - }{ - map[string]string{ - "bar": "baz", - }, - }, - &struct { - Foo map[string]string - }{}, - nil, - }, - - { - "struct containing map of pointers", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo map[string]*string - }{ - map[string]*string{ - "bar": &stringValue, - }, - }, - &struct { - Foo map[string]*string - }{}, - nil, - }, - - { - "struct containing document", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo *Document - }{ - NewDocument( - EC.String("bar", "baz"), - ), - }, - &struct { - Foo *Document - }{}, - nil, - }, - - { - "struct containing reader", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo Reader - }{ - Reader{ - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }, - }, - &struct { - Foo Reader - }{}, - nil, - }, - { - "struct containing reader pointer", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - &struct { - Foo *Reader - }{ - &Reader{ - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - }, - }, - &struct { - Foo *Reader - }{}, - nil, - }, - - { - "map containing slice", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string][]string{ - "foo": {"baz"}, - }, - make(map[string][]string), - nil, - }, - { - "map containing slice pointer", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string]*[]string{ - "foo": {"baz"}, - }, - make(map[string]*[]string), - nil, - }, - { - "map containing array", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string][1]string{ - "foo": {"baz"}, - }, - make(map[string][1]string), - nil, - }, - { - "map containing array pointer", - bytes.NewBuffer([]byte{ - // length - 0x1a, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - array - 0x4, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x10, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "0" - 0x30, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string]*[1]string{ - "foo": {"baz"}, - }, - make(map[string]*[1]string), - nil, - }, - { - "map containing struct", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string]struct{ Bar string }{ - "foo": {Bar: "baz"}, - }, - make(map[string]struct{ Bar string }), - nil, - }, - { - "map containing pointer to struct", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string]*struct{ Bar string }{ - "foo": {Bar: "baz"}, - }, - make(map[string]*struct{ Bar string }), - nil, - }, - { - "map containing struct with pointers", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string]struct{ Bar *string }{ - "foo": {Bar: &stringValue}, - }, - make(map[string]struct{ Bar *string }), - nil, - }, - { - "map containing pointer to struct with pointers", - bytes.NewBuffer([]byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - }), - map[string]*struct{ Bar *string }{ - "foo": {Bar: &stringValue}, - }, - make(map[string]*struct{ Bar *string }), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(tc.reader) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, cmp.Equal(tc.actual, tc.expected)) - }) - } - }) - t.Run("pluggable types", func(t *testing.T) { - intJSONNumber := json.Number("5") - floatJSONNumber := json.Number("10.1") - murl, err := url.Parse("https://mongodb.com/random-url?hello=world") - if err != nil { - t.Errorf("Error parsing URL: %v", err) - t.FailNow() - } - testCases := []struct { - name string - reader []byte - expected interface{} - actual interface{} - err error - }{ - { - "*url.URL", - docToBytes(NewDocument(EC.String("a", murl.String()))), - &struct { - A *url.URL - }{ - A: murl, - }, - &struct { - A *url.URL - }{}, - nil, - }, - { - "json.Number", - docToBytes(NewDocument(EC.Int64("a", 5), EC.Double("b", 10.10))), - &struct { - A json.Number - B json.Number - }{ - A: json.Number("5"), - B: json.Number("10.1"), - }, - &struct { - A json.Number - B json.Number - }{}, - nil, - }, - { - "json.Number pointer", - docToBytes(NewDocument(EC.Int64("a", 5), EC.Double("b", 10.10))), - &struct { - A *json.Number - B *json.Number - }{ - A: &intJSONNumber, - B: &floatJSONNumber, - }, - &struct { - A *json.Number - B *json.Number - }{}, - nil, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - d := NewDecoder(bytes.NewBuffer(tc.reader)) - - err := d.Decode(tc.actual) - requireErrEqual(t, tc.err, err) - if err != nil { - return - } - - require.True(t, reflect.DeepEqual(tc.expected, tc.actual)) - }) - } - }) -} - func elementSliceEqual(t *testing.T, e1 []*Element, e2 []*Element) { require.Equal(t, len(e1), len(e2)) diff --git a/bson/decoder.go b/bson/decoder.go deleted file mode 100644 index cf0b966e04..0000000000 --- a/bson/decoder.go +++ /dev/null @@ -1,53 +0,0 @@ -package bson - -import ( - "errors" - "reflect" -) - -// A Decoderv2 reads and decodes BSON documents from a stream. -type Decoderv2 struct { - r *Registry - vr ValueReader -} - -// NewDecoderv2 returns a new decoder that uses Registry reg to read from r. -func NewDecoderv2(r *Registry, vr ValueReader) (*Decoderv2, error) { - if r == nil { - return nil, errors.New("cannot create a new Decoder with a nil Registry") - } - if vr == nil { - return nil, errors.New("cannot create a new Decoder with a nil ValueReader") - } - - return &Decoderv2{ - r: r, - vr: vr, - }, nil -} - -// Decode reads the next BSON document from the stream and decodes it into the -// value pointed to by val. -// -// The documentation for Unmarshal contains details about of BSON into a Go -// value. -func (d *Decoderv2) Decode(val interface{}) error { - codec, err := d.r.Lookup(reflect.TypeOf(val)) - if err != nil { - return err - } - return codec.DecodeValue(DecodeContext{Registry: d.r}, d.vr, val) -} - -// Reset will reset the state of the decoder, using the same *Registry used in -// the original construction but using r for reading. -func (d *Decoderv2) Reset(vr ValueReader) error { - d.vr = vr - return nil -} - -// SetRegistry replaces the current registry of the decoder with r. -func (d *Decoderv2) SetRegistry(r *Registry) error { - d.r = r - return nil -} diff --git a/bson/element.go b/bson/element.go index de7380acdb..e822127062 100644 --- a/bson/element.go +++ b/bson/element.go @@ -354,7 +354,8 @@ func (e *Element) String() string { return fmt.Sprintf(`bson.Element{[%s]"%s": %v}`, e.Value().Type(), e.Key(), val) } -func (e *Element) equal(e2 *Element) bool { +// Equal compares this element to element and returns true if they are equal. +func (e *Element) Equal(e2 *Element) bool { if e == nil && e2 == nil { return true } @@ -362,7 +363,10 @@ func (e *Element) equal(e2 *Element) bool { return false } - return e.value.equal(e2.value) + if e.Key() != e2.Key() { + return false + } + return e.value.Equal(e2.value) } func elemsFromValues(values []*Value) []*Element { @@ -396,7 +400,10 @@ func convertValueToElem(key string, v *Value) *Element { elem := newElement(0, uint32(keyLen+2)) elem.value.data = d - elem.value.d = v.d + elem.value.d = nil + if v.d != nil { + elem.value.d = v.d.Copy() + } return elem } diff --git a/bson/element_test.go b/bson/element_test.go index 8a30e13c9a..663e009123 100644 --- a/bson/element_test.go +++ b/bson/element_test.go @@ -1739,7 +1739,7 @@ func testConvertValueToElem(t *testing.T) { got := convertValueToElem(tc.key, tc.val) want := tc.elem - if !got.equal(want) { + if !got.Equal(want) { t.Errorf("Expected elements to be equal but they are not. got %v; want %v", got, want) } }) diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go deleted file mode 100644 index 8914d479e3..0000000000 --- a/bson/empty_interface_codec.go +++ /dev/null @@ -1,144 +0,0 @@ -package bson - -import ( - "fmt" - "reflect" - - "github.com/mongodb/mongo-go-driver/bson/decimal" - "github.com/mongodb/mongo-go-driver/bson/objectid" -) - -var defaultEmptyInterfaceCodec = &EmptyInterfaceCodec{} - -// EmptyInterfaceCodec is the Codec used for empty interface (interface{}) -// values. -type EmptyInterfaceCodec struct{} - -var _ Codec = &EmptyInterfaceCodec{} - -// EncodeValue implements the Codec interface. -func (eic *EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - codec, err := ec.Lookup(reflect.TypeOf(i)) - if err != nil { - return err - } - - return codec.EncodeValue(ec, vw, i) -} - -// DecodeValue implements the Codec interface. -func (eic *EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - target, ok := i.(*interface{}) - if !ok || target == nil { - return fmt.Errorf("%T can only be used to decode non-nil *interface{} values, provided type if %T", eic, i) - } - - // fn is a function we call to assign val back to the target, we do this so - // we can keep down on the repeated code in this method. In all of the - // implementations this is a closure, so we don't need to provide the - // target as a parameter. - var fn func() - var val interface{} - var rtype reflect.Type - - switch vr.Type() { - case TypeDouble: - val = new(float64) - rtype = tFloat64 - fn = func() { *target = *(val.(*float64)) } - case TypeString: - val = new(string) - rtype = tString - fn = func() { *target = *(val.(*string)) } - case TypeEmbeddedDocument: - val = NewDocument() - rtype = tDocument - fn = func() { *target = val.(*Document) } - case TypeArray: - val = NewArray() - rtype = tArray - fn = func() { *target = val.(*Array) } - case TypeBinary: - val = new(Binary) - rtype = tBinary - fn = func() { *target = *(val.(*Binary)) } - case TypeUndefined: - val = new(Undefinedv2) - rtype = tUndefined - fn = func() { *target = *(val.(*Undefinedv2)) } - case TypeObjectID: - val = new(objectid.ObjectID) - rtype = tOID - fn = func() { *target = *(val.(*objectid.ObjectID)) } - case TypeBoolean: - val = new(bool) - rtype = tBool - fn = func() { *target = *(val.(*bool)) } - case TypeDateTime: - val = new(DateTime) - rtype = tDateTime - fn = func() { *target = *(val.(*DateTime)) } - case TypeNull: - val = new(Nullv2) - rtype = tNull - fn = func() { *target = *(val.(*Nullv2)) } - case TypeRegex: - val = new(Regex) - rtype = tRegex - fn = func() { *target = *(val.(*Regex)) } - case TypeDBPointer: - val = new(DBPointer) - rtype = tDBPointer - fn = func() { *target = *(val.(*DBPointer)) } - case TypeJavaScript: - val = new(JavaScriptCode) - rtype = tJavaScriptCode - fn = func() { *target = *(val.(*JavaScriptCode)) } - case TypeSymbol: - val = new(Symbol) - rtype = tSymbol - fn = func() { *target = *(val.(*Symbol)) } - case TypeCodeWithScope: - val = new(CodeWithScope) - rtype = tCodeWithScope - fn = func() { *target = *(val.(*CodeWithScope)) } - case TypeInt32: - val = new(int32) - rtype = tInt32 - fn = func() { *target = *(val.(*int32)) } - case TypeInt64: - val = new(int64) - rtype = tInt64 - fn = func() { *target = *(val.(*int64)) } - case TypeTimestamp: - val = new(Timestamp) - rtype = tTimestamp - fn = func() { *target = *(val.(*Timestamp)) } - case TypeDecimal128: - val = new(decimal.Decimal128) - rtype = tDecimal - fn = func() { *target = *(val.(*decimal.Decimal128)) } - case TypeMinKey: - val = new(MinKeyv2) - rtype = tMinKey - fn = func() { *target = *(val.(*MinKeyv2)) } - case TypeMaxKey: - val = new(MaxKeyv2) - rtype = tMaxKey - fn = func() { *target = *(val.(*MaxKeyv2)) } - default: - return fmt.Errorf("Type %s is not a valid BSON type and has no default Go type to decode into", vr.Type()) - } - - codec, err := dc.Lookup(rtype) - if err != nil { - return err - } - err = codec.DecodeValue(dc, vr, val) - if err != nil { - return err - } - - fn() - return nil -} diff --git a/bson/empty_interface_codec_test.go b/bson/empty_interface_codec_test.go deleted file mode 100644 index 0e33862607..0000000000 --- a/bson/empty_interface_codec_test.go +++ /dev/null @@ -1,310 +0,0 @@ -package bson - -import ( - "errors" - "fmt" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/mongodb/mongo-go-driver/bson/decimal" - "github.com/mongodb/mongo-go-driver/bson/objectid" -) - -func TestEmptyInterfaceCodec(t *testing.T) { - testCases := []struct { - name string - val interface{} - bsontype Type - }{ - { - "Double - float64", - float64(3.14159), - TypeDouble, - }, - { - "String - string", - string("foo bar baz"), - TypeString, - }, - { - "Embedded Document - *Document", - NewDocument(EC.Null("foo")), - TypeEmbeddedDocument, - }, - { - "Array - *Array", - NewArray(VC.Double(3.14159)), - TypeArray, - }, - { - "Binary - Binary", - Binary{Subtype: 0xFF, Data: []byte{0x01, 0x02, 0x03}}, - TypeBinary, - }, - { - "Undefined - Undefined", - Undefinedv2{}, - TypeUndefined, - }, - { - "ObjectID - objectid.ObjectID", - objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - TypeObjectID, - }, - { - "Boolean - bool", - bool(true), - TypeBoolean, - }, - { - "DateTime - DateTime", - DateTime(1234567890), - TypeDateTime, - }, - { - "Null - Null", - Nullv2{}, - TypeNull, - }, - { - "Regex - Regex", - Regex{Pattern: "foo", Options: "bar"}, - TypeRegex, - }, - { - "DBPointer - DBPointer", - DBPointer{ - DB: "foobar", - Pointer: objectid.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}, - }, - TypeDBPointer, - }, - { - "JavaScript - JavaScriptCode", - JavaScriptCode("var foo = 'bar';"), - TypeJavaScript, - }, - { - "Symbol - Symbol", - Symbol("foobarbazlolz"), - TypeSymbol, - }, - { - "CodeWithScope - CodeWithScope", - CodeWithScope{ - Code: "var foo = 'bar';", - Scope: NewDocument(EC.Double("foo", 3.14159)), - }, - TypeCodeWithScope, - }, - { - "Int32 - int32", - int32(123456), - TypeInt32, - }, - { - "Int64 - int64", - int64(1234567890), - TypeInt64, - }, - { - "Timestamp - Timestamp", - Timestamp{T: 12345, I: 67890}, - TypeTimestamp, - }, - { - "Decimal128 - decimal.Decimal128", - decimal.NewDecimal128(12345, 67890), - TypeDecimal128, - }, - { - "MinKey - MinKey", - MinKeyv2{}, - TypeMinKey, - }, - { - "MaxKey - MaxKey", - MaxKeyv2{}, - TypeMaxKey, - }, - } - - t.Run("EncodeValue", func(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - llvr := &llValueReaderWriter{bsontype: tc.bsontype} - eic := &EmptyInterfaceCodec{} - - t.Run("Lookup failure", func(t *testing.T) { - ec := EncodeContext{Registry: NewEmptyRegistryBuilder().Build()} - want := ErrNoCodec{Type: reflect.TypeOf(tc.val)} - got := eic.EncodeValue(ec, llvr, tc.val) - if !compareErrors(got, want) { - t.Errorf("Errors are not equal. got %v; want %v", got, want) - } - }) - - t.Run("Success", func(t *testing.T) { - want := tc.val - llc := &llCodec{t: t} - ec := EncodeContext{ - Registry: NewEmptyRegistryBuilder().Register(reflect.TypeOf(tc.val), llc).Build(), - } - err := eic.EncodeValue(ec, llvr, tc.val) - noerr(t, err) - got := llc.encodeval - if !cmp.Equal(got, want, cmp.Comparer(compareDecimal128)) { - t.Errorf("Did not receive expected value. got %v; want %v", got, want) - } - }) - }) - } - }) - - t.Run("DecodeValue", func(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - llvr := &llValueReaderWriter{bsontype: tc.bsontype} - eic := &EmptyInterfaceCodec{} - - t.Run("Lookup failure", func(t *testing.T) { - val := new(interface{}) - dc := DecodeContext{Registry: NewEmptyRegistryBuilder().Build()} - want := ErrNoCodec{Type: reflect.TypeOf(tc.val)} - got := eic.DecodeValue(dc, llvr, val) - if !compareErrors(got, want) { - t.Errorf("Errors are not equal. got %v; want %v", got, want) - } - }) - - t.Run("DecodeValue failure", func(t *testing.T) { - want := errors.New("DecodeValue failure error") - llc := &llCodec{t: t, err: want} - dc := DecodeContext{ - Registry: NewEmptyRegistryBuilder().Register(reflect.TypeOf(tc.val), llc).Build(), - } - got := eic.DecodeValue(dc, llvr, new(interface{})) - if !compareErrors(got, want) { - t.Errorf("Errors are not equal. got %v; want %v", got, want) - } - }) - - t.Run("Success", func(t *testing.T) { - want := tc.val - llc := &llCodec{t: t, decodeval: tc.val} - dc := DecodeContext{ - Registry: NewEmptyRegistryBuilder().Register(reflect.TypeOf(tc.val), llc).Build(), - } - got := new(interface{}) - err := eic.DecodeValue(dc, llvr, got) - noerr(t, err) - if !cmp.Equal(*got, want, cmp.Comparer(compareDecimal128)) { - t.Errorf("Did not receive expected value. got %v; want %v", *got, want) - } - }) - }) - } - - t.Run("non-*interface{}", func(t *testing.T) { - eic := &EmptyInterfaceCodec{} - val := uint64(1234567890) - want := fmt.Errorf("%T can only be used to decode non-nil *interface{} values, provided type if %T", eic, &val) - got := eic.DecodeValue(DecodeContext{}, nil, &val) - if !compareErrors(got, want) { - t.Errorf("Errors are not equal. got %v; want %v", got, want) - } - }) - - t.Run("nil *interface{}", func(t *testing.T) { - eic := &EmptyInterfaceCodec{} - var val *interface{} - want := fmt.Errorf("%T can only be used to decode non-nil *interface{} values, provided type if %T", eic, val) - got := eic.DecodeValue(DecodeContext{}, nil, val) - if !compareErrors(got, want) { - t.Errorf("Errors are not equal. got %v; want %v", got, want) - } - }) - - t.Run("unknown BSON type", func(t *testing.T) { - llvr := &llValueReaderWriter{bsontype: Type(0)} - eic := &EmptyInterfaceCodec{} - want := fmt.Errorf("Type %s is not a valid BSON type and has no default Go type to decode into", Type(0)) - got := eic.DecodeValue(DecodeContext{}, llvr, new(interface{})) - if !compareErrors(got, want) { - t.Errorf("Errors are not equal. got %v; want %v", got, want) - } - }) - }) -} - -type llCodec struct { - t *testing.T - decodeval interface{} - encodeval interface{} - err error -} - -func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) error { - if llc.err != nil { - return llc.err - } - - llc.encodeval = i - return nil -} - -func (llc *llCodec) DecodeValue(_ DecodeContext, _ ValueReader, i interface{}) error { - if llc.err != nil { - return llc.err - } - - val := reflect.ValueOf(i) - if val.Type().Kind() != reflect.Ptr { - llc.t.Errorf("Value provided to DecodeValue must be a pointer, but got %T", i) - return nil - } - - switch val.Type() { - case tDocument: - decodeval, ok := llc.decodeval.(*Document) - if !ok { - llc.t.Errorf("decodeval must be a *Document if the i is a *Document. decodeval %T", llc.decodeval) - return nil - } - - doc := i.(*Document) - doc.Reset() - err := doc.Concat(decodeval) - if err != nil { - llc.t.Errorf("could not concatenate the decoded val to doc: %v", err) - return err - } - - return nil - case tArray: - decodeval, ok := llc.decodeval.(*Array) - if !ok { - llc.t.Errorf("decodeval must be a *Array if the i is a *Array. decodeval %T", llc.decodeval) - return nil - } - - arr := i.(*Array) - arr.Reset() - err := arr.Concat(decodeval) - if err != nil { - llc.t.Errorf("could not concatenate the decoded val to array: %v", err) - return err - } - - return nil - } - - if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type().Elem()) { - llc.t.Errorf("decodeval must be assignable to i provided to DecodeValue, but is not. decodeval %T; i %T", llc.decodeval, i) - return nil - } - - val.Elem().Set(reflect.ValueOf(llc.decodeval)) - return nil -} diff --git a/bson/encode.go b/bson/encode.go index 73db05472b..72aa80d478 100644 --- a/bson/encode.go +++ b/bson/encode.go @@ -5,783 +5,3 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package bson - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "math" - "net/url" - "reflect" - "strconv" - "strings" - "time" - - "github.com/mongodb/mongo-go-driver/bson/decimal" - "github.com/mongodb/mongo-go-driver/bson/objectid" -) - -// ErrEncoderNilWriter indicates that encoder.Encode was called with a nil argument. -var ErrEncoderNilWriter = errors.New("encoder.Encode called on Encoder with nil io.Writer") - -var tElementSlice = reflect.TypeOf(([]*Element)(nil)) -var tByteSlice = reflect.TypeOf(([]byte)(nil)) -var tByte = reflect.TypeOf(byte(0x00)) -var tElement = reflect.TypeOf((*Element)(nil)) -var tURL = reflect.TypeOf(url.URL{}) -var tJSONNumber = reflect.TypeOf(json.Number("")) - -// Marshaler describes a type that can marshal a BSON representation of itself into bytes. -type Marshaler interface { - MarshalBSON() ([]byte, error) -} - -// DocumentMarshaler describes a type that can marshal itself into a bson.Document. -type DocumentMarshaler interface { - MarshalBSONDocument() (*Document, error) -} - -// ElementMarshaler describes a type that can marshal itself into a bson.Element. -type ElementMarshaler interface { - MarshalBSONElement() (*Element, error) -} - -// ValueMarshaler describes a type that can marshal itself into a bson.Value. -type ValueMarshaler interface { - MarshalBSONValue() (*Value, error) -} - -// Encoder describes a type that can encode itself into a value. -type Encoder interface { - // Encode encodes a value from an io.Writer into the given value. - // - // The value can be any one of the following types: - // - // - bson.Marshaler - // - io.Reader - // - []byte - // - bson.Reader - // - any map with string keys - // - a struct (possibly with tags) - // - // In the case of a struct, the lowercased field name is used as the key for each exported - // field but this behavior may be changed using a struct tag. The tag may also contain flags to - // adjust the marshalling behavior for the field. The tag formats accepted are: - // - // "[][,[,]]" - // - // `(...) bson:"[][,[,]]" (...)` - // - // The following flags are currently supported: - // - // omitempty Only include the field if it's not set to the zero value for the type or to - // empty slices or maps. - // - // minsize Marshal an integer of a type larger than 32 bits value as an int32, if that's - // feasible while preserving the numeric value. - // - // inline Inline the field, which must be a struct or a map, causing all of its fields - // or keys to be processed as if they were part of the outer struct. - // - // An example: - // - // type T struct { - // A bool - // B int "myb" - // C string "myc,omitempty" - // D string `bson:",omitempty" json:"jsonkey"` - // E int64 ",minsize" - // F int64 "myf,omitempty,minsize" - // } - Encode(interface{}) error -} - -// DocumentEncoder describes a type that can marshal itself into a value and return the bson.Document it represents. -type DocumentEncoder interface { - EncodeDocument(interface{}) (*Document, error) -} - -type encoder struct { - w io.Writer -} - -// NewEncoder creates an encoder that writes to w. -func NewEncoder(w io.Writer) Encoder { - return &encoder{w: w} -} - -// NewDocumentEncoder creates an encoder that encodes into a *Document. -func NewDocumentEncoder() DocumentEncoder { - return &encoder{} -} - -func convertTimeToInt64(t time.Time) int64 { - return t.Unix()*1000 + int64(t.Nanosecond()/1e6) -} - -func (e *encoder) Encode(v interface{}) error { - var err error - - if e.w == nil { - return ErrEncoderNilWriter - } - - switch t := v.(type) { - case Marshaler: - var b []byte - b, err = t.MarshalBSON() - if err != nil { - return err - } - _, err = Reader(b).Validate() - if err != nil { - return err - } - _, err = e.w.Write(b) - return err - case io.Reader: - var r Reader - r, err = NewFromIOReader(t) - if err != nil { - return err - } - - _, err = r.Validate() - if err != nil { - return err - } - _, err = e.w.Write(r) - case []byte: - _, err = Reader(t).Validate() - if err != nil { - return err - } - _, err = e.w.Write(t) - - case Reader: - _, err = t.Validate() - if err != nil { - return err - } - _, err = e.w.Write(t) - default: - var elems []*Element - rval := reflect.ValueOf(v) - elems, err = e.reflectEncode(rval) - if err != nil { - return err - } - _, err = NewDocument(elems...).WriteTo(e.w) - } - - return err -} - -// EncodeDocument encodes a value from an io.Writer into the given value and returns the document -// it represents. -// -// EncodeDocument accepts the same types as Encoder.Encode. -func (e *encoder) EncodeDocument(v interface{}) (*Document, error) { - var err error - - d := NewDocument() - - switch t := v.(type) { - case *Document: - err = d.Concat(t) - case Marshaler: - var b []byte - b, err = t.MarshalBSON() - if err != nil { - return nil, err - } - _, err = Reader(b).Validate() - if err != nil { - return nil, err - } - err = d.Concat(b) - case io.Reader: - var r Reader - r, err = NewFromIOReader(t) - if err != nil { - return nil, err - } - - _, err = r.Validate() - if err != nil { - return nil, err - } - err = d.Concat(r) - case []byte: - _, err = Reader(t).Validate() - if err != nil { - return nil, err - } - err = d.Concat(t) - case Reader: - _, err = t.Validate() - if err != nil { - return nil, err - } - err = d.Concat(t) - default: - var elems []*Element - rval := reflect.ValueOf(v) - elems, err = e.reflectEncode(rval) - if err != nil { - return nil, err - } - d.Append(elems...) - } - - if err != nil { - return nil, err - } - - return d, nil -} - -// underlyingVal will unwrap the given reflect.Value until it is not a pointer -// nor an interface. -func (e *encoder) underlyingVal(val reflect.Value) reflect.Value { - if val.Kind() != reflect.Ptr && val.Kind() != reflect.Interface { - return val - } - if val.IsNil() { - return val - } - return e.underlyingVal(val.Elem()) -} - -func (e *encoder) reflectEncode(val reflect.Value) ([]*Element, error) { - val = e.underlyingVal(val) - - var elems []*Element - var err error - switch val.Kind() { - case reflect.Map: - elems, err = e.encodeMap(val) - case reflect.Slice, reflect.Array: - elems, err = e.encodeSlice(val) - case reflect.Struct: - elems, err = e.encodeStruct(val) - default: - err = fmt.Errorf("Cannot encode type %s as a BSON Document", val.Type()) - } - - if err != nil { - return nil, err - } - - return elems, nil -} - -func (e *encoder) encodeMap(val reflect.Value) ([]*Element, error) { - mapkeys := val.MapKeys() - elems := make([]*Element, 0, val.Len()) - for _, rkey := range mapkeys { - orig := rkey - rkey = e.underlyingVal(rkey) - - var key string - switch rkey.Kind() { - case reflect.Bool: - key = strconv.FormatBool(rkey.Bool()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - key = strconv.FormatInt(rkey.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - key = strconv.FormatUint(rkey.Uint(), 10) - case reflect.Float32: - key = strconv.FormatFloat(rkey.Float(), 'g', -1, 32) - case reflect.Float64: - key = strconv.FormatFloat(rkey.Float(), 'g', -1, 64) - case reflect.Complex64, reflect.Complex128: - key = fmt.Sprintf("%g", rkey.Complex()) - case reflect.String: - key = rkey.String() - default: - switch rkey.Type() { - case tOID: - key = fmt.Sprintf("%s", rkey.Interface()) - case tURL: - rkey = orig - key = fmt.Sprintf("%s", rkey.Interface()) - case tDecimal: - key = fmt.Sprintf("%s", rkey.Interface()) - default: - return nil, fmt.Errorf("Unsupported map key type %s", rkey.Kind()) - } - } - - rval := val.MapIndex(rkey) - - switch t := rval.Interface().(type) { - case *Element: - elems = append(elems, t) - continue - case *Document: - elems = append(elems, EC.SubDocument(key, t)) - continue - case Reader: - elems = append(elems, EC.SubDocumentFromReader(key, t)) - continue - case json.Number: - // We try to do an int first - if i64, err := t.Int64(); err == nil { - elems = append(elems, EC.Int64(key, i64)) - continue - } - f64, err := t.Float64() - if err != nil { - return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err) - } - elems = append(elems, EC.Double(key, f64)) - continue - case *url.URL: - elems = append(elems, EC.String(key, t.String())) - continue - case decimal.Decimal128: - elems = append(elems, EC.Decimal128(key, t)) - continue - } - rval = e.underlyingVal(rval) - - elem, err := e.elemFromValue(key, rval, false) - if err != nil { - return nil, err - } - elems = append(elems, elem) - } - return elems, nil -} - -func (e *encoder) encodeSlice(val reflect.Value) ([]*Element, error) { - elems := make([]*Element, 0, val.Len()) - for i := 0; i < val.Len(); i++ { - sval := val.Index(i) - key := strconv.Itoa(i) - switch t := sval.Interface().(type) { - case *Element: - elems = append(elems, t) - continue - case *Document: - elems = append(elems, EC.SubDocument(key, t)) - continue - case Reader: - elems = append(elems, EC.SubDocumentFromReader(key, t)) - continue - case json.Number: - // We try to do an int first - if i64, err := t.Int64(); err == nil { - elems = append(elems, EC.Int64(key, i64)) - continue - } - f64, err := t.Float64() - if err != nil { - return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err) - } - elems = append(elems, EC.Double(key, f64)) - continue - case *url.URL: - elems = append(elems, EC.String(key, t.String())) - continue - case decimal.Decimal128: - elems = append(elems, EC.Decimal128(key, t)) - continue - } - sval = e.underlyingVal(sval) - elem, err := e.elemFromValue(key, sval, false) - if err != nil { - return nil, err - } - elems = append(elems, elem) - } - return elems, nil -} - -func (e *encoder) encodeSliceAsArray(rval reflect.Value, minsize bool) ([]*Value, error) { - vals := make([]*Value, 0, rval.Len()) - for i := 0; i < rval.Len(); i++ { - sval := rval.Index(i) - switch t := sval.Interface().(type) { - case *Element: - vals = append(vals, t.value) - continue - case *Value: - vals = append(vals, t) - continue - case *Document: - vals = append(vals, VC.Document(t)) - continue - case Reader: - vals = append(vals, VC.DocumentFromReader(t)) - continue - case json.Number: - // We try to do an int first - if i64, err := t.Int64(); err == nil { - vals = append(vals, VC.Int64(i64)) - continue - } - f64, err := t.Float64() - if err != nil { - return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err) - } - vals = append(vals, VC.Double(f64)) - continue - case *url.URL: - vals = append(vals, VC.String(t.String())) - continue - case decimal.Decimal128: - vals = append(vals, VC.Decimal128(t)) - continue - case time.Time: - vals = append(vals, VC.DateTime(convertTimeToInt64(t))) - continue - case *time.Time: - if t == nil { - vals = append(vals, VC.Null()) - } else { - vals = append(vals, VC.DateTime(convertTimeToInt64(*t))) - } - continue - } - - sval = e.underlyingVal(sval) - val, err := e.valueFromValue(sval, minsize) - if err != nil { - return nil, err - } - vals = append(vals, val) - } - return vals, nil -} - -func (e *encoder) encodeStruct(val reflect.Value) ([]*Element, error) { - elems := make([]*Element, 0, val.NumField()) - sType := val.Type() - - for i := 0; i < val.NumField(); i++ { - sf := sType.Field(i) - if sf.PkgPath != "" { - continue - } - key := strings.ToLower(sf.Name) - tag, ok := sf.Tag.Lookup("bson") - var omitempty, minsize, inline = false, false, false - switch { - case ok: - if tag == "-" { - continue - } - for idx, str := range strings.Split(tag, ",") { - if idx == 0 && str != "" { - key = str - } - switch str { - case "omitempty": - omitempty = true - case "minsize": - minsize = true - case "inline": - inline = true - } - } - case !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0: - key = string(sf.Tag) - } - - field := val.Field(i) - - if omitempty && e.isZero(field) { - continue - } - - switch t := field.Interface().(type) { - case *Element: - elems = append(elems, t) - continue - case *Document: - elems = append(elems, EC.SubDocument(key, t)) - continue - case Reader: - elems = append(elems, EC.SubDocumentFromReader(key, t)) - continue - case json.Number: - // We try to do an int first - if i64, err := t.Int64(); err == nil { - elems = append(elems, EC.Int64(key, i64)) - continue - } - f64, err := t.Float64() - if err != nil { - return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err) - } - elems = append(elems, EC.Double(key, f64)) - continue - case *url.URL: - elems = append(elems, EC.String(key, t.String())) - continue - case decimal.Decimal128: - elems = append(elems, EC.Decimal128(key, t)) - continue - case time.Time: - elems = append(elems, EC.DateTime(key, convertTimeToInt64(t))) - continue - case *time.Time: - if t == nil { - elems = append(elems, EC.Null(key)) - } else { - elems = append(elems, EC.DateTime(key, convertTimeToInt64(*t))) - } - continue - } - field = e.underlyingVal(field) - - if inline { - switch sf.Type.Kind() { - case reflect.Map: - melems, err := e.encodeMap(field) - if err != nil { - return nil, err - } - elems = append(elems, melems...) - continue - case reflect.Struct: - selems, err := e.encodeStruct(field) - if err != nil { - return nil, err - } - elems = append(elems, selems...) - continue - default: - return nil, errors.New("inline is only supported for map and struct types") - } - } - - elem, err := e.elemFromValue(key, field, minsize) - if err != nil { - return nil, err - } - elems = append(elems, elem) - } - return elems, nil -} - -func (e *encoder) isZero(v reflect.Value) bool { - switch v.Kind() { - case reflect.Array, reflect.Map, reflect.Slice, reflect.String: - return v.Len() == 0 - case reflect.Bool: - return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0 - case reflect.Interface, reflect.Ptr: - return v.IsNil() - case reflect.Struct: - if z, ok := v.Interface().(Zeroer); ok { - return z.IsZero() - } - return false - } - - return false -} - -func (e *encoder) elemFromValue(key string, val reflect.Value, minsize bool) (*Element, error) { - var elem *Element - switch val.Kind() { - case reflect.Interface, reflect.Ptr: - if !val.IsNil() { - return nil, errors.New("Values must be unwrapped when calling elemFromValue, try calling underlyingVal first") - } - elem = EC.Null(key) - case reflect.Bool: - elem = EC.Boolean(key, val.Bool()) - case reflect.Int8, reflect.Int16, reflect.Int32: - elem = EC.Int32(key, int32(val.Int())) - case reflect.Int, reflect.Int64: - i := val.Int() - if minsize && i < math.MaxInt32 { - elem = EC.Int32(key, int32(val.Int())) - break - } - elem = EC.Int64(key, val.Int()) - case reflect.Uint8, reflect.Uint16: - i := val.Uint() - elem = EC.Int32(key, int32(i)) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - i := val.Uint() - switch { - case i < math.MaxInt32 && minsize: - elem = EC.Int32(key, int32(i)) - case i < math.MaxInt64: - elem = EC.Int64(key, int64(i)) - default: - return nil, fmt.Errorf("BSON only has signed integer types and %d overflows an int64", i) - } - case reflect.Float32, reflect.Float64: - elem = EC.Double(key, val.Float()) - case reflect.String: - elem = EC.String(key, val.String()) - case reflect.Map: - // We specifically check if the value is nil so we can properly round trip. - // If we didn't do this, we couldn't differentiate between an empty map, which should - // be a document, and a nil map, which should be null. In Go, there is a difference - // between the empty value and nil, we should preserve that. - if val.IsNil() { - elem = EC.Null(key) - break - } - mapElems, err := e.encodeMap(val) - if err != nil { - return nil, err - } - elem = EC.SubDocumentFromElements(key, mapElems...) - case reflect.Slice: - // We specifically check if the value is nil so we can properly round trip. - // If we didn't do this, we couldn't differentiate between an empty slice, which should - // be an array, and a nil slice, which should be null. In Go, there is a difference - // between the empty value and nil, we should preserve that. - if val.IsNil() { - elem = EC.Null(key) - break - } - if val.Type() == tByteSlice { - elem = EC.Binary(key, val.Slice(0, val.Len()).Interface().([]byte)) - break - } - sliceElems, err := e.encodeSliceAsArray(val, minsize) - if err != nil { - return nil, err - } - elem = EC.ArrayFromElements(key, sliceElems...) - case reflect.Array: - switch { - case val.Type() == tOID: - elem = EC.ObjectID(key, val.Interface().(objectid.ObjectID)) - case val.Type().Elem() == tByte: - b := make([]byte, val.Len()) - for i := 0; i < val.Len(); i++ { - b[i] = byte(val.Index(i).Uint()) - } - elem = EC.Binary(key, b) - default: - arrayElems, err := e.encodeSliceAsArray(val, minsize) - if err != nil { - return nil, err - } - elem = EC.ArrayFromElements(key, arrayElems...) - } - case reflect.Struct: - structElems, err := e.encodeStruct(val) - if err != nil { - return nil, err - } - elem = EC.SubDocumentFromElements(key, structElems...) - default: - return nil, fmt.Errorf("Unsupported value type %s", val.Kind()) - } - return elem, nil -} - -func (e *encoder) valueFromValue(val reflect.Value, minsize bool) (*Value, error) { - var elem *Value - switch val.Kind() { - case reflect.Interface, reflect.Ptr: - if !val.IsNil() { - return nil, errors.New("Values must be unwrapped when calling elemFromValue, try calling underlyingVal first") - } - elem = VC.Null() - case reflect.Bool: - elem = VC.Boolean(val.Bool()) - case reflect.Int8, reflect.Int16, reflect.Int32: - elem = VC.Int32(int32(val.Int())) - case reflect.Int, reflect.Int64: - i := val.Int() - if minsize && i < math.MaxInt32 { - elem = VC.Int32(int32(val.Int())) - break - } - elem = VC.Int64(val.Int()) - case reflect.Uint8, reflect.Uint16: - i := val.Uint() - elem = VC.Int32(int32(i)) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - i := val.Uint() - switch { - case i < math.MaxInt32 && minsize: - elem = VC.Int32(int32(i)) - case i < math.MaxInt64: - elem = VC.Int64(int64(i)) - default: - return nil, fmt.Errorf("BSON only has signed integer types and %d overflows an int64", i) - } - case reflect.Float32, reflect.Float64: - elem = VC.Double(val.Float()) - case reflect.String: - elem = VC.String(val.String()) - case reflect.Map: - // We specifically check if the value is nil so we can properly round trip. - // If we didn't do this, we couldn't differentiate between an empty map, which should - // be a document, and a nil map, which should be null. In Go, there is a difference - // between the empty value and nil, we should preserve that. - if val.IsNil() { - elem = VC.Null() - break - } - mapElems, err := e.encodeMap(val) - if err != nil { - return nil, err - } - elem = VC.DocumentFromElements(mapElems...) - case reflect.Slice: - // We specifically check if the value is nil so we can properly round trip. - // If we didn't do this, we couldn't differentiate between an empty slice, which should - // be an array, and a nil slice, which should be null. In Go, there is a difference - // between the empty value and nil, we should preserve that. - if val.IsNil() { - elem = VC.Null() - break - } - if val.Type() == tByteSlice { - elem = VC.Binary(val.Slice(0, val.Len()).Interface().([]byte)) - break - } - sliceElems, err := e.encodeSliceAsArray(val, minsize) - if err != nil { - return nil, err - } - elem = VC.ArrayFromValues(sliceElems...) - case reflect.Array: - switch { - case val.Type() == tOID: - elem = VC.ObjectID(val.Interface().(objectid.ObjectID)) - case val.Type().Elem() == tByte: - b := make([]byte, val.Len()) - for i := 0; i < val.Len(); i++ { - b[i] = byte(val.Index(i).Uint()) - } - elem = VC.Binary(b) - default: - arrayElems, err := e.encodeSliceAsArray(val, minsize) - if err != nil { - return nil, err - } - elem = VC.ArrayFromValues(arrayElems...) - } - case reflect.Struct: - structElems, err := e.encodeStruct(val) - if err != nil { - return nil, err - } - elem = VC.DocumentFromElements(structElems...) - default: - return nil, fmt.Errorf("Unsupported value type %s", val.Kind()) - } - return elem, nil -} diff --git a/bson/encode_test.go b/bson/encode_test.go index bbaf32f81c..e50a9b6e8a 100644 --- a/bson/encode_test.go +++ b/bson/encode_test.go @@ -6,1120 +6,6 @@ package bson -import ( - "bytes" - "encoding/json" - "io" - "net/url" - "reflect" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/mongodb/mongo-go-driver/bson/decimal" - "github.com/mongodb/mongo-go-driver/bson/objectid" - "github.com/stretchr/testify/assert" -) - -func TestEncoder(t *testing.T) { - t.Run("Writer/Marshaler", func(t *testing.T) { - testCases := []struct { - name string - m Marshaler - b []byte - err error - }{ - { - "success", - NewDocument(EC.Null("foo")), - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - enc := NewEncoder(&buf) - err := enc.Encode(tc.m) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - b := buf.Bytes() - if diff := cmp.Diff(tc.b, b); diff != "" { - t.Errorf("Bytes written differ: (-got +want)\n%s", diff) - } - }) - } - }) - t.Run("Document/Document", func(t *testing.T) { - testCases := []struct { - name string - d *Document - want *Document - err error - }{ - { - "success", - NewDocument(EC.Null("foo")), - NewDocument(EC.Null("foo")), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - enc := NewDocumentEncoder() - got, err := enc.EncodeDocument(tc.d) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - if diff := cmp.Diff(got, tc.want, cmp.AllowUnexported(Document{}, Element{}, Value{})); diff != "" { - t.Errorf("Documents differ: (-got +want)\n%s", diff) - } - }) - } - }) - t.Run("Writer/io.Reader", func(t *testing.T) { - testCases := []struct { - name string - m io.Reader - b []byte - err error - }{ - { - "success", - bytes.NewReader([]byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }), - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - enc := NewEncoder(&buf) - err := enc.Encode(tc.m) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - b := buf.Bytes() - if diff := cmp.Diff(tc.b, b); diff != "" { - t.Errorf("Bytes written differ: (-got +want)\n%s", diff) - } - }) - } - }) - t.Run("Document/io.Reader", func(t *testing.T) { - testCases := []struct { - name string - m io.Reader - want *Document - err error - }{ - { - "success", - bytes.NewReader([]byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }), - NewDocument(EC.Null("foo")), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - enc := NewDocumentEncoder() - got, err := enc.EncodeDocument(tc.m) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - if !documentComparer(got, tc.want) { - t.Errorf("Documents differ. got %v; want %v", got, tc.want) - } - }) - } - }) - t.Run("Writer/[]byte", func(t *testing.T) { - testCases := []struct { - name string - m []byte - b []byte - err error - }{ - { - "success", - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - enc := NewEncoder(&buf) - err := enc.Encode(tc.m) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - b := buf.Bytes() - if diff := cmp.Diff(tc.b, b); diff != "" { - t.Errorf("Bytes written differ: (-got +want)\n%s", diff) - } - }) - } - }) - t.Run("Document/[]byte", func(t *testing.T) { - testCases := []struct { - name string - m []byte - want *Document - err error - }{ - { - "success", - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - NewDocument(EC.Null("foo")), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - enc := NewDocumentEncoder() - got, err := enc.EncodeDocument(tc.m) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - if !documentComparer(got, tc.want) { - t.Errorf("Documents differ. got %v; want %v", got, tc.want) - } - }) - } - }) - t.Run("Writer/Reader", func(t *testing.T) { - testCases := []struct { - name string - r Reader - b []byte - err error - }{ - { - "success", - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - enc := NewEncoder(&buf) - err := enc.Encode(tc.r) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - b := buf.Bytes() - if diff := cmp.Diff(tc.b, b); diff != "" { - t.Errorf("Bytes written differ: (-got +want)\n%s", diff) - } - }) - } - }) - t.Run("Document/Reader", func(t *testing.T) { - testCases := []struct { - name string - r Reader - want *Document - err error - }{ - { - "success", - []byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }, - NewDocument(EC.Null("foo")), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - enc := NewDocumentEncoder() - got, err := enc.EncodeDocument(tc.r) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - if !documentComparer(got, tc.want) { - t.Errorf("Documents differ. got %v; want %v", got, tc.want) - } - }) - } - }) - t.Run("Document/Marshaler", func(t *testing.T) { - testCases := []struct { - name string - r Marshaler - want *Document - err error - }{ - { - "success", - byteMarshaler([]byte{ - 0x0A, 0x00, 0x00, 0x00, - 0x0A, 'f', 'o', 'o', 0x00, - 0x00, - }), - NewDocument(EC.Null("foo")), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - enc := NewDocumentEncoder() - got, err := enc.EncodeDocument(tc.r) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - if !documentComparer(got, tc.want) { - t.Errorf("Documents differ. got %v; want %v", got, tc.want) - } - }) - } - }) - t.Run("Writer/Reflection", reflectionEncoderTest) - t.Run("Document/Reflection", func(t *testing.T) { - testCases := []struct { - name string - value interface{} - want *Document - err error - }{ - { - "struct", - struct { - A string - }{ - A: "foo", - }, - NewDocument(EC.String("a", "foo")), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - enc := NewDocumentEncoder() - got, err := enc.EncodeDocument(tc.value) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - if !documentComparer(got, tc.want) { - t.Errorf("Documents differ. got %v; want %v", got, tc.want) - } - }) - } - }) -} - -func reflectionEncoderTest(t *testing.T) { - oid := objectid.New() - oids := []objectid.ObjectID{objectid.New(), objectid.New(), objectid.New()} - var str = new(string) - *str = "bar" - now := time.Now() - murl, err := url.Parse("https://mongodb.com/random-url?hello=world") - if err != nil { - t.Errorf("Error parsing URL: %v", err) - t.FailNow() - } - decimal128, err := decimal.ParseDecimal128("1.5e10") - if err != nil { - t.Errorf("Error parsing decimal128: %v", err) - t.FailNow() - } - - testCases := []struct { - name string - value interface{} - b []byte - err error - }{ - { - "map[bool]int", - map[bool]int32{false: 1}, - []byte{ - 0x10, 0x00, 0x00, 0x00, - 0x10, 'f', 'a', 'l', 's', 'e', 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - }, - nil, - }, - { - "map[int]int", - map[int]int32{1: 1}, - []byte{ - 0x0C, 0x00, 0x00, 0x00, - 0x10, '1', 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - }, - nil, - }, - { - "map[uint]int", - map[uint]int32{1: 1}, - []byte{ - 0x0C, 0x00, 0x00, 0x00, - 0x10, '1', 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - }, - nil, - }, - { - "map[float32]int", - map[float32]int32{3.14: 1}, - []byte{ - 0x0F, 0x00, 0x00, 0x00, - 0x10, '3', '.', '1', '4', 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - }, - nil, - }, - { - "map[float64]int", - map[float64]int32{3.14: 1}, - []byte{ - 0x0F, 0x00, 0x00, 0x00, - 0x10, '3', '.', '1', '4', 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - }, - nil, - }, - { - "map[string]int", - map[string]int32{"foo": 1}, - []byte{ - 0x0E, 0x00, 0x00, 0x00, - 0x10, 'f', 'o', 'o', 0x00, - 0x01, 0x00, 0x00, 0x00, - 0x00, - }, - nil, - }, - { - "map[string]objectid.ObjectID", - map[string]objectid.ObjectID{"foo": oid}, - docToBytes(NewDocument(EC.ObjectID("foo", oid))), - nil, - }, - { - "map[objectid.ObjectID]string", - map[objectid.ObjectID]string{oid: "foo"}, - docToBytes(NewDocument(EC.String(oid.String(), "foo"))), - nil, - }, - { - "map[string]*string", - map[string]*string{"foo": str}, - docToBytes(NewDocument(EC.String("foo", "bar"))), - nil, - }, - { - "map[string]*string with nil", - map[string]*string{"baz": nil}, - docToBytes(NewDocument(EC.Null("baz"))), - nil, - }, - { - "map[string]_Interface", - map[string]_Interface{"foo": _impl{Foo: "bar"}}, - docToBytes(NewDocument(EC.SubDocumentFromElements("foo", EC.String("foo", "bar")))), - nil, - }, - { - "map[string]_Interface with nil", - map[string]_Interface{"baz": (*_impl)(nil)}, - docToBytes(NewDocument(EC.Null("baz"))), - nil, - }, - { - "map[json.Number]json.Number(int64)", - map[json.Number]json.Number{ - json.Number("5"): json.Number("10"), - }, - docToBytes(NewDocument(EC.Int64("5", 10))), - nil, - }, - { - "map[json.Number]json.Number(float64)", - map[json.Number]json.Number{ - json.Number("5.0"): json.Number("10.1"), - }, - docToBytes(NewDocument(EC.Double("5.0", 10.1))), - nil, - }, - { - "map[*url.URL]*url.URL", - map[*url.URL]*url.URL{ - murl: murl, - }, - docToBytes(NewDocument(EC.String(murl.String(), murl.String()))), - nil, - }, - { - "map[decimal.Decimal128]decimal.Decimal128", - map[decimal.Decimal128]decimal.Decimal128{ - decimal128: decimal128, - }, - docToBytes(NewDocument(EC.Decimal128(decimal128.String(), decimal128))), - nil, - }, - { - "[]string", - []string{"foo", "bar", "baz"}, - []byte{ - 0x26, 0x00, 0x00, 0x00, - 0x02, '0', 0x00, - 0x04, 0x00, 0x00, 0x00, - 'f', 'o', 'o', 0x00, - 0x02, '1', 0x00, - 0x04, 0x00, 0x00, 0x00, - 'b', 'a', 'r', 0x00, - 0x02, '2', 0x00, - 0x04, 0x00, 0x00, 0x00, - 'b', 'a', 'z', 0x00, - 0x00, - }, - nil, - }, - { - "[]*Element", - []*Element{EC.Null("A"), EC.Null("B"), EC.Null("C")}, - []byte{ - 0x0E, 0x00, 0x00, 0x00, - 0x0A, 'A', 0x00, - 0x0A, 'B', 0x00, - 0x0A, 'C', 0x00, - 0x00, - }, - nil, - }, - { - "[]*Document", - []*Document{NewDocument(EC.Null("A"))}, - docToBytes(NewDocument( - EC.SubDocumentFromElements("0", (EC.Null("A"))), - )), - nil, - }, - { - "[]Reader", - []Reader{{0x05, 0x00, 0x00, 0x00, 0x00}}, - docToBytes(NewDocument( - EC.SubDocumentFromElements("0"), - )), - nil, - }, - { - "[]objectid.ObjectID", - oids, - arrToBytes(NewArray( - VC.ObjectID(oids[0]), - VC.ObjectID(oids[1]), - VC.ObjectID(oids[2]), - )), - nil, - }, - { - "[]*string with nil", - []*string{str, nil}, - arrToBytes(NewArray( - VC.String(*str), - VC.Null(), - )), - nil, - }, - { - "[]_Interface with nil", - []_Interface{_impl{Foo: "bar"}, (*_impl)(nil), nil}, - arrToBytes(NewArray( - VC.DocumentFromElements(EC.String("foo", "bar")), - VC.Null(), - VC.Null(), - )), - nil, - }, - { - "[]json.Number", - []json.Number{"5", "10.1"}, - arrToBytes(NewArray( - VC.Int64(5), - VC.Double(10.1), - )), - nil, - }, - { - "[]*url.URL", - []*url.URL{murl}, - arrToBytes(NewArray( - VC.String(murl.String()), - )), - nil, - }, - { - "[]decimal.Decimal128", - []decimal.Decimal128{decimal128}, - arrToBytes(NewArray( - VC.Decimal128(decimal128), - )), - nil, - }, - { - "map[string][]*Element", - map[string][]*Element{"Z": {EC.Int32("A", 1), EC.Int32("B", 2), EC.Int32("EC", 3)}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int32(1), VC.Int32(2), VC.Int32(3)), - )), - nil, - }, - { - "map[string][]*Value", - map[string][]*Value{"Z": {VC.Int32(1), VC.Int32(2), VC.Int32(3)}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int32(1), VC.Int32(2), VC.Int32(3)), - )), - nil, - }, - { - "map[string]*Element", - map[string]*Element{"Z": EC.Int32("foo", 12345)}, - docToBytes(NewDocument( - EC.Int32("foo", 12345), - )), - nil, - }, - { - "map[string]*Document", - map[string]*Document{"Z": NewDocument(EC.Null("foo"))}, - docToBytes(NewDocument( - EC.SubDocumentFromElements("Z", EC.Null("foo")), - )), - nil, - }, - { - "map[string]Reader", - map[string]Reader{"Z": {0x05, 0x00, 0x00, 0x00, 0x00}}, - docToBytes(NewDocument( - EC.SubDocumentFromReader("Z", Reader{0x05, 0x00, 0x00, 0x00, 0x00}), - )), - nil, - }, - { - "map[string][]int32", - map[string][]int32{"Z": {1, 2, 3}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int32(1), VC.Int32(2), VC.Int32(3)), - )), - nil, - }, - { - "map[string][]objectid.ObjectID", - map[string][]objectid.ObjectID{"Z": oids}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.ObjectID(oids[0]), VC.ObjectID(oids[1]), VC.ObjectID(oids[2])), - )), - nil, - }, - { - "map[string][]*string with nil", - map[string][]*string{"Z": {str, nil}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.String("bar"), VC.Null()), - )), - nil, - }, - { - "map[string][]_Interface with nil", - map[string][]_Interface{"Z": {_impl{Foo: "bar"}, (*_impl)(nil), nil}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.DocumentFromElements(EC.String("foo", "bar")), VC.Null(), VC.Null()), - )), - nil, - }, - { - "map[string][]json.Number(int64)", - map[string][]json.Number{"Z": {json.Number("5"), json.Number("10")}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Int64(5), VC.Int64(10)), - )), - nil, - }, - { - "map[string][]json.Number(float64)", - map[string][]json.Number{"Z": {json.Number("5.0"), json.Number("10.10")}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Double(5.0), VC.Double(10.10)), - )), - nil, - }, - { - "map[string][]*url.URL", - map[string][]*url.URL{"Z": {murl}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.String(murl.String())), - )), - nil, - }, - { - "map[string][]decimal.Decimal128", - map[string][]decimal.Decimal128{"Z": {decimal128}}, - docToBytes(NewDocument( - EC.ArrayFromElements("Z", VC.Decimal128(decimal128)), - )), - nil, - }, - { - "[2]*Element", - [2]*Element{EC.Int32("A", 1), EC.Int32("B", 2)}, - docToBytes(NewDocument( - EC.Int32("A", 1), EC.Int32("B", 2), - )), - nil, - }, - { - "-", - struct { - A string `bson:"-"` - }{ - A: "", - }, - docToBytes(NewDocument()), - nil, - }, - { - "omitempty", - struct { - A string `bson:",omitempty"` - }{ - A: "", - }, - docToBytes(NewDocument()), - nil, - }, - { - "omitempty, empty time", - struct { - A time.Time `bson:",omitempty"` - }{ - A: time.Time{}, - }, - docToBytes(NewDocument()), - nil, - }, - { - "no private fields", - struct { - a string - }{ - a: "should be empty", - }, - docToBytes(NewDocument()), - nil, - }, - { - "minsize", - struct { - A int64 `bson:",minsize"` - }{ - A: 12345, - }, - docToBytes(NewDocument(EC.Int32("a", 12345))), - nil, - }, - { - "inline", - struct { - Foo struct { - A int64 `bson:",minsize"` - } `bson:",inline"` - }{ - Foo: struct { - A int64 `bson:",minsize"` - }{ - A: 12345, - }, - }, - docToBytes(NewDocument(EC.Int32("a", 12345))), - nil, - }, - { - "inline map", - struct { - Foo map[string]string `bson:",inline"` - }{ - Foo: map[string]string{"foo": "bar"}, - }, - docToBytes(NewDocument(EC.String("foo", "bar"))), - nil, - }, - { - "alternate name bson:name", - struct { - A string `bson:"foo"` - }{ - A: "bar", - }, - docToBytes(NewDocument(EC.String("foo", "bar"))), - nil, - }, - { - "alternate name", - struct { - A string `bson:"foo"` - }{ - A: "bar", - }, - docToBytes(NewDocument(EC.String("foo", "bar"))), - nil, - }, - { - "inline, omitempty", - struct { - A string - Foo zeroTest `bson:"omitempty,inline"` - }{ - A: "bar", - Foo: zeroTest{true}, - }, - docToBytes(NewDocument(EC.String("a", "bar"))), - nil, - }, - { - "struct{}", - struct { - A bool - B int32 - C int64 - D uint16 - E uint64 - F float64 - G string - H map[string]string - I []byte - J [4]byte - K [2]string - L struct { - M string - } - N *Element - O *Document - P Reader - Q objectid.ObjectID - R *string - S map[struct{}]struct{} - T []struct{} - U _Interface - V _Interface - W map[struct{}]struct{} - X map[struct{}]struct{} - Y json.Number - Z time.Time - AA json.Number - AB *url.URL - AC decimal.Decimal128 - AD *time.Time - }{ - A: true, - B: 123, - C: 456, - D: 789, - E: 101112, - F: 3.14159, - G: "Hello, world", - H: map[string]string{"foo": "bar"}, - I: []byte{0x01, 0x02, 0x03}, - J: [4]byte{0x04, 0x05, 0x06, 0x07}, - K: [2]string{"baz", "qux"}, - L: struct { - M string - }{ - M: "foobar", - }, - N: EC.Null("N"), - O: NewDocument(EC.Int64("countdown", 9876543210)), - P: Reader{0x05, 0x00, 0x00, 0x00, 0x00}, - Q: oid, - R: nil, - S: nil, - T: nil, - U: nil, - V: _Interface((*_impl)(nil)), // typed nil - W: map[struct{}]struct{}{}, - X: nil, - Y: json.Number("5"), - Z: now, - AA: json.Number("10.10"), - AB: murl, - AC: decimal128, - AD: &now, - }, - docToBytes(NewDocument( - EC.Boolean("a", true), - EC.Int32("b", 123), - EC.Int64("c", 456), - EC.Int32("d", 789), - EC.Int64("e", 101112), - EC.Double("f", 3.14159), - EC.String("g", "Hello, world"), - EC.SubDocumentFromElements("h", EC.String("foo", "bar")), - EC.Binary("i", []byte{0x01, 0x02, 0x03}), - EC.Binary("j", []byte{0x04, 0x05, 0x06, 0x07}), - EC.ArrayFromElements("k", VC.String("baz"), VC.String("qux")), - EC.SubDocumentFromElements("l", EC.String("m", "foobar")), - EC.Null("N"), - EC.SubDocumentFromElements("o", EC.Int64("countdown", 9876543210)), - EC.SubDocumentFromElements("p"), - EC.ObjectID("q", oid), - EC.Null("r"), - EC.Null("s"), - EC.Null("t"), - EC.Null("u"), - EC.Null("v"), - EC.SubDocument("w", NewDocument()), - EC.Null("x"), - EC.Int64("y", 5), - EC.DateTime("z", now.UnixNano()/int64(time.Millisecond)), - EC.Double("aa", 10.10), - EC.String("ab", murl.String()), - EC.Decimal128("ac", decimal128), - EC.DateTime("ad", now.UnixNano()/int64(time.Millisecond)), - )), - nil, - }, - { - "struct{[]interface{}}", - struct { - A []bool - B []int32 - C []int64 - D []uint16 - E []uint64 - F []float64 - G []string - H []map[string]string - I [][]byte - J [1][4]byte - K [1][2]string - L []struct { - M string - } - N [][]string - O []*Element - P []*Document - Q []Reader - R []objectid.ObjectID - S []*string - T []struct{} - U []_Interface - V []_Interface - W []map[struct{}]struct{} - X []map[struct{}]struct{} - Y []map[struct{}]struct{} - Z []time.Time - AA []json.Number - AB []*url.URL - AC []decimal.Decimal128 - AD []*time.Time - }{ - A: []bool{true}, - B: []int32{123}, - C: []int64{456}, - D: []uint16{789}, - E: []uint64{101112}, - F: []float64{3.14159}, - G: []string{"Hello, world"}, - H: []map[string]string{{"foo": "bar"}}, - I: [][]byte{{0x01, 0x02, 0x03}}, - J: [1][4]byte{{0x04, 0x05, 0x06, 0x07}}, - K: [1][2]string{{"baz", "qux"}}, - L: []struct { - M string - }{ - { - M: "foobar", - }, - }, - N: [][]string{{"foo", "bar"}}, - O: []*Element{EC.Null("N")}, - P: []*Document{NewDocument(EC.Int64("countdown", 9876543210))}, - Q: []Reader{{0x05, 0x00, 0x00, 0x00, 0x00}}, - R: oids, - S: []*string{str, nil}, - T: nil, - U: nil, - V: []_Interface{_impl{Foo: "bar"}, nil, (*_impl)(nil)}, - W: nil, - X: []map[struct{}]struct{}{}, // Should be empty BSON Array - Y: []map[struct{}]struct{}{{}}, // Should be BSON array with one element, an empty BSON SubDocument - Z: []time.Time{now, now}, - AA: []json.Number{json.Number("5"), json.Number("10.10")}, - AB: []*url.URL{murl}, - AC: []decimal.Decimal128{decimal128}, - AD: []*time.Time{&now, &now}, - }, - docToBytes(NewDocument( - EC.ArrayFromElements("a", VC.Boolean(true)), - EC.ArrayFromElements("b", VC.Int32(123)), - EC.ArrayFromElements("c", VC.Int64(456)), - EC.ArrayFromElements("d", VC.Int32(789)), - EC.ArrayFromElements("e", VC.Int64(101112)), - EC.ArrayFromElements("f", VC.Double(3.14159)), - EC.ArrayFromElements("g", VC.String("Hello, world")), - EC.ArrayFromElements("h", VC.DocumentFromElements(EC.String("foo", "bar"))), - EC.ArrayFromElements("i", VC.Binary([]byte{0x01, 0x02, 0x03})), - EC.ArrayFromElements("j", VC.Binary([]byte{0x04, 0x05, 0x06, 0x07})), - EC.ArrayFromElements("k", VC.ArrayFromValues(VC.String("baz"), VC.String("qux"))), - EC.ArrayFromElements("l", VC.DocumentFromElements(EC.String("m", "foobar"))), - EC.ArrayFromElements("n", VC.ArrayFromValues(VC.String("foo"), VC.String("bar"))), - EC.ArrayFromElements("o", VC.Null()), - EC.ArrayFromElements("p", VC.DocumentFromElements(EC.Int64("countdown", 9876543210))), - EC.ArrayFromElements("q", VC.DocumentFromElements()), - EC.ArrayFromElements("r", VC.ObjectID(oids[0]), VC.ObjectID(oids[1]), VC.ObjectID(oids[2])), - EC.ArrayFromElements("s", VC.String("bar"), VC.Null()), - EC.Null("t"), - EC.Null("u"), - EC.ArrayFromElements("v", VC.DocumentFromElements(EC.String("foo", "bar")), VC.Null(), VC.Null()), - EC.Null("w"), - EC.Array("x", NewArray()), - EC.ArrayFromElements("y", VC.Document(NewDocument())), - EC.ArrayFromElements("z", VC.DateTime(now.UnixNano()/int64(time.Millisecond)), VC.DateTime(now.UnixNano()/int64(time.Millisecond))), - EC.ArrayFromElements("aa", VC.Int64(5), VC.Double(10.10)), - EC.ArrayFromElements("ab", VC.String(murl.String())), - EC.ArrayFromElements("ac", VC.Decimal128(decimal128)), - EC.ArrayFromElements("ad", VC.DateTime(now.UnixNano()/int64(time.Millisecond)), VC.DateTime(now.UnixNano()/int64(time.Millisecond))), - )), - nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - enc := NewEncoder(&buf) - err := enc.Encode(tc.value) - if err != tc.err { - t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) - } - b := buf.Bytes() - if diff := cmp.Diff(b, tc.b); diff != "" { - t.Errorf("Bytes written differ: (-got +want)\n%s", diff) - t.Errorf("Bytes\ngot: %v\nwant:%v\n", b, tc.b) - t.Errorf("Readers\ngot: %v\nwant:%v\n", Reader(b), Reader(tc.b)) - } - }) - } -} - -type zeroTest struct { - reportZero bool -} - -func (z zeroTest) IsZero() bool { return z.reportZero } - -func compareZeroTest(_, _ zeroTest) bool { return true } - -type nonZeroer struct { - value bool -} - -func TestZeoerInterfaceUsedByDecoder(t *testing.T) { - enc := &encoder{} - - // cases that are zero, because they are known types or pointers - var st *nonZeroer - assert.True(t, enc.isZero(reflect.ValueOf(st))) - assert.True(t, enc.isZero(reflect.ValueOf(0))) - assert.True(t, enc.isZero(reflect.ValueOf(false))) - - // cases that shouldn't be zero - st = &nonZeroer{value: false} - assert.False(t, enc.isZero(reflect.ValueOf(struct{ val bool }{val: true}))) - assert.False(t, enc.isZero(reflect.ValueOf(struct{ val bool }{val: false}))) - assert.False(t, enc.isZero(reflect.ValueOf(st))) - st.value = true - assert.False(t, enc.isZero(reflect.ValueOf(st))) - - // a test to see if the interface impacts the outcome - z := zeroTest{} - assert.False(t, enc.isZero(reflect.ValueOf(z))) - - z.reportZero = true - assert.True(t, enc.isZero(reflect.ValueOf(z))) - -} - -type timePrtStruct struct{ TimePtrField *time.Time } - -func TestRegressionNoDereferenceNilTimePtr(t *testing.T) { - enc := &encoder{} - - assert.NotPanics(t, func() { - res, err := enc.encodeStruct(reflect.ValueOf(timePrtStruct{})) - assert.Len(t, res, 1) - assert.Nil(t, err) - }) - - assert.NotPanics(t, func() { - res, err := enc.encodeSliceAsArray(reflect.ValueOf([]*time.Time{nil, nil, nil}), false) - assert.Len(t, res, 3) - assert.Nil(t, err) - }) -} - func docToBytes(d *Document) []byte { b, err := d.MarshalBSON() if err != nil { diff --git a/bson/encoder_test.go b/bson/encoder_test.go deleted file mode 100644 index d9c2852ce7..0000000000 --- a/bson/encoder_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package bson - -import ( - "bytes" - "testing" -) - -func TestEncoderEncode(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - got := make(writer, 0, 1024) - vw := newValueWriter(&got) - reg := NewRegistryBuilder().Build() - enc, err := NewEncoderv2(reg, vw) - noerr(t, err) - err = enc.Encode(tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", Reader(got), Reader(tc.want)) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} diff --git a/bson/internal/llbson/llbson.go b/bson/internal/llbson/llbson.go index 0a78149106..abf287ce03 100644 --- a/bson/internal/llbson/llbson.go +++ b/bson/internal/llbson/llbson.go @@ -36,7 +36,10 @@ func AppendKey(dst []byte, key string) []byte { return append(dst, key+string(0x // AppendHeader will append Type t and key to dst and return the extended // buffer. func AppendHeader(dst []byte, t Type, key string) []byte { - return append(AppendType(dst, t), key+string(0x00)...) + dst = AppendType(dst, t) + dst = append(dst, key...) + return append(dst, 0x00) + // return append(AppendType(dst, t), key+string(0x00)...) } // ReadType will return the first byte of the provided []byte as a type. If @@ -436,6 +439,68 @@ func AppendMaxKeyElement(dst []byte, key string) []byte { return AppendHeader(ds // and return the extended buffer. func AppendMinKeyElement(dst []byte, key string) []byte { return AppendHeader(dst, TypeMinKey, key) } +// EqualValue will return true if the two values are equal. +func EqualValue(t1, t2 Type, v1, v2 []byte) bool { + if t1 != t2 { + return false + } + length1, ok := valueLength(t1, v1) + if !ok { + return false + } + length2, ok := valueLength(t2, v2) + if !ok { + return false + } + return bytes.Equal(v1[:length1], v2[:length2]) +} + +func valueLength(t Type, val []byte) (int32, bool) { + var length int32 + ok := true + switch t { + case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope: + length, ok = readLength(val) + case TypeBinary: + length, ok = readLength(val) + length += 4 + 1 // binary length + subtype byte + case TypeBoolean: + length = 1 + case TypeDBPointer: + length, ok = readLength(val) + length += 4 + 12 // string length + ObjectID length + case TypeDateTime, TypeDouble, TypeInt64, TypeTimestamp: + length = 8 + case TypeDecimal128: + length = 16 + case TypeInt32: + length = 4 + case TypeJavaScript, TypeString, TypeSymbol: + length, ok = readLength(val) + length += 4 + case TypeMaxKey, TypeMinKey, TypeNull, TypeUndefined: + length = 0 + case TypeObjectID: + length = 12 + case TypeRegex: + regex := bytes.IndexByte(val, 0x00) + if regex < 0 { + ok = false + break + } + pattern := bytes.IndexByte(val, 0x00) + if pattern < 0 { + ok = false + break + } + length = int32(int64(regex) + 1 + int64(pattern) + 1) + default: + ok = false + } + + return length, ok +} + func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) } func appendi32(dst []byte, i32 int32) []byte { diff --git a/bson/map_codec.go b/bson/map_codec.go deleted file mode 100644 index f330dd5b87..0000000000 --- a/bson/map_codec.go +++ /dev/null @@ -1,152 +0,0 @@ -package bson - -import ( - "fmt" - "reflect" -) - -var defaultMapCodec = &MapCodec{} - -// MapCodec is the Codec used for map values. -type MapCodec struct{} - -var _ Codec = &MapCodec{} - -// EncodeValue implements the Codec interface. -func (mc *MapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - val := reflect.ValueOf(i) - if val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { - return fmt.Errorf("%T can only encode maps with string keys", mc) - } - - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - return mc.encodeValue(ec, dw, val, nil) -} - -// encodeValue handles encoding of the values of a map. The collisionFn returns -// true if the provided key exists, this is mainly used for inline maps in the -// struct codec. -func (mc *MapCodec) encodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { - - var err error - var codec Codec - switch val.Type().Elem() { - case tElement: - codec = defaultElementCodec - default: - codec, err = ec.Lookup(val.Type().Elem()) - if err != nil { - return err - } - } - - keys := val.MapKeys() - for _, key := range keys { - if collisionFn != nil && collisionFn(key.String()) { - return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) - } - vw, err := dw.WriteDocumentElement(key.String()) - if err != nil { - return err - } - - err = codec.EncodeValue(ec, vw, val.MapIndex(key).Interface()) - if err != nil { - return err - } - } - - return dw.WriteDocumentEnd() -} - -// DecodeValue implements the Codec interface. -func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - val := reflect.ValueOf(i) - if !val.IsValid() || val.Kind() != reflect.Ptr || val.IsNil() { - return fmt.Errorf("%T can only be used to decode non-nil pointers to map values, got %T", mc, i) - } - - if val.Elem().Kind() != reflect.Map || val.Elem().Type().Key().Kind() != reflect.String || !val.Elem().CanSet() { - return fmt.Errorf("%T can only decode settable maps with string keys", mc) - } - - dr, err := vr.ReadDocument() - if err != nil { - return err - } - - if val.Elem().IsNil() { - val.Elem().Set(reflect.MakeMap(val.Elem().Type())) - } - - mVal := val.Elem() - - dFn, err := mc.decodeFn(dc, mVal) - if err != nil { - return err - } - - for { - var elem reflect.Value - key, vr, err := dr.ReadElement() - if err == ErrEOD { - break - } - if err != nil { - return err - } - key, elem, err = dFn(dc, vr, key) - if err != nil { - return err - } - - mVal.SetMapIndex(reflect.ValueOf(key), elem) - } - return err -} - -type decodeFn func(dc DecodeContext, vr ValueReader, key string) (updatedKey string, v reflect.Value, err error) - -// decodeFn returns a function that can be used to decode the values of a map. -// The mapVal parameter should be a map type, not a pointer to a map type. -// -// If error is nil, decodeFn will return a non-nil decodeFn. -func (mc *MapCodec) decodeFn(dc DecodeContext, mapVal reflect.Value) (decodeFn, error) { - var dFn decodeFn - switch mapVal.Type().Elem() { - case tElement: - // TODO(skriptble): We have to decide if we want to support this. We have - // information loss because we can only store either the map key or the element - // key. We could add a struct tag field that allows the user to make a decision. - dFn = func(dc DecodeContext, vr ValueReader, key string) (string, reflect.Value, error) { - var elem *Element - err := defaultElementCodec.decodeValue(dc, vr, key, &elem) - if err != nil { - return key, reflect.Value{}, err - } - return key, reflect.ValueOf(elem), nil - } - default: - eType := mapVal.Type().Elem() - codec, err := dc.Lookup(eType) - if err != nil { - return nil, err - } - - dFn = func(dc DecodeContext, vr ValueReader, key string) (string, reflect.Value, error) { - ptr := reflect.New(eType) - - err = codec.DecodeValue(dc, vr, ptr.Interface()) - if err != nil { - return key, reflect.Value{}, err - } - return key, ptr.Elem(), nil - } - } - - return dFn, nil -} diff --git a/bson/marshal.go b/bson/marshal.go index 8342adaccc..72aa80d478 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -5,183 +5,3 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package bson - -import "bytes" - -// Marshal converts a BSON type to bytes. -// -// The value can be any one of the following types: -// -// - bson.Marshaler -// - io.Reader -// - []byte -// - bson.Reader -// - any map with string keys -// - a struct (possibly with tags) -// -// The following flags are currently supported for marshaling: -// -// omitempty Only include the field if it's not set to the zero value for the type or to -// empty slices or maps. -// -// minsize Marshal an integer of a type larger than 32 bits value as an int32, if that's -// feasible while preserving the numeric value. -// -// inline Inline the field, which must be a struct or a map, causing all of its fields -// or keys to be processed as if they were part of the outer struct. For maps, -// keys must not conflict with the bson keys of other struct fields. -// -// Skip This struct field should be skipped. This is usually denoted by parsing a "-" -// for the name. -// -// See the DefaultStructTagParser declaration and the StructTags type for more -// information. -func Marshal(value interface{}) ([]byte, error) { - var out bytes.Buffer - - err := NewEncoder(&out).Encode(value) - if err != nil { - return nil, err - } - - return out.Bytes(), nil -} - -// Unmarshal converts bytes into a BSON type. -// -// The value can be any one of the following types: -// -// - bson.Unmarshaler -// - io.Writer -// - []byte -// - bson.Reader -// - any map with string keys -// - a struct (possibly with tags) -// -// In the case of struct values, only exported fields will be deserialized. The lowercased field -// name is used as the key for each exported field, but this behavior may be changed using a struct -// tag. The tag may also contain flags to adjust the unmarshaling behavior for the field. The tag -// formats accepted are: -// -// "[][,[,]]" -// -// `(...) bson:"[][,[,]]" (...)` -// -// The target field or element types of out may not necessarily match the BSON values of the -// provided data. The following conversions are made automatically: -// -// - Numeric types are converted if at least the integer part of the value would be preserved -// correctly -// -// If the value would not fit the type and cannot be converted, it is silently skipped. -// -// Pointer values are initialized when necessary. -func Unmarshal(in []byte, out interface{}) error { - return NewDecoder(bytes.NewReader(in)).Decode(out) -} - -// UnmarshalDocument converts bytes into a *bson.Document. -func UnmarshalDocument(bson []byte) (*Document, error) { - return ReadDocument(bson) -} - -// Marshalv2 returns the BSON encoding of val. -// -// Marshal will use the default registry created by NewRegistry to recursively -// marshal val into a []byte. Marshal will inspect struct tags and alter the -// marshaling process accordingly. -func Marshalv2(val interface{}) ([]byte, error) { - return MarshalWithRegistry(defaultRegistry, val) -} - -// MarshalAppend will append the BSON encoding of val to dst. If dst is not -// large enough to hold the BSON encoding of val, dst will be grown. -func MarshalAppend(dst []byte, val interface{}) ([]byte, error) { - return MarshalAppendWithRegistry(defaultRegistry, dst, val) -} - -// MarshalWithRegistry returns the BSON encoding of val using Registry r. -func MarshalWithRegistry(r *Registry, val interface{}) ([]byte, error) { - dst := make([]byte, 0, 256) // TODO: make the default cap a constant - return MarshalAppendWithRegistry(r, dst, val) -} - -// MarshalAppendWithRegistry will append the BSON encoding of val to dst using -// Registry r. If dst is not large enough to hold the BSON encoding of val, dst -// will be grown. -func MarshalAppendWithRegistry(r *Registry, dst []byte, val interface{}) ([]byte, error) { - // w := writer(dst) - // vw := newValueWriter(&w) - vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) - - vw.reset(dst) - - enc := encPool.Get().(*Encoderv2) - defer encPool.Put(enc) - - err := enc.Reset(vw) - if err != nil { - return nil, err - } - err = enc.SetRegistry(r) - if err != nil { - return nil, err - } - - err = enc.Encode(val) - if err != nil { - return nil, err - } - - return vw.buf, nil -} - -// MarshalDocument returns val encoded as a *Document. -// -// MarshalDocument will use the default registry created by NewRegistry to recursively -// marshal val into a *Document. MarshalDocument will inspect struct tags and alter the -// marshaling process accordingly. -func MarshalDocument(val interface{}) (*Document, error) { - return MarshalDocumentAppend(NewDocument(), val) -} - -// MarshalDocumentAppend will append val encoded to dst. If dst is nil, a new *Document will be -// allocated and the encoding of val will be appended to that. -func MarshalDocumentAppend(dst *Document, val interface{}) (*Document, error) { - return MarshalDocumentAppendWithRegistry(defaultRegistry, dst, val) -} - -// MarshalDocumentWithRegistry returns val encoded as a *Document using r. -func MarshalDocumentWithRegistry(r *Registry, val interface{}) (*Document, error) { - return MarshalDocumentAppendWithRegistry(r, NewDocument(), val) -} - -// MarshalDocumentAppendWithRegistry will append val encoded to dst using r. If dst is nil, a new -// *Document will be allocated and the encoding of val will be appended to that. -func MarshalDocumentAppendWithRegistry(r *Registry, dst *Document, val interface{}) (*Document, error) { - d := dst - if d == nil { - d = NewDocument() - } - dvw := newDocumentValueWriter(d) - - enc := encPool.Get().(*Encoderv2) - defer encPool.Put(enc) - - err := enc.Reset(dvw) - if err != nil { - return nil, err - } - err = enc.SetRegistry(r) - if err != nil { - return nil, err - } - - err = enc.Encode(val) - if err != nil { - return nil, err - } - - return d, nil -} diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 1389aa8874..72aa80d478 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -5,159 +5,3 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package bson - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestMarshal_roundtripFromBytes(t *testing.T) { - before := []byte{ - // length - 0x1c, 0x0, 0x0, 0x0, - - // --- begin array --- - - // type - document - 0x3, - // key - "foo" - 0x66, 0x6f, 0x6f, 0x0, - - // length - 0x12, 0x0, 0x0, 0x0, - // type - string - 0x2, - // key - "bar" - 0x62, 0x61, 0x72, 0x0, - // value - string length - 0x4, 0x0, 0x0, 0x0, - // value - "baz" - 0x62, 0x61, 0x7a, 0x0, - - // null terminator - 0x0, - - // --- end array --- - - // null terminator - 0x0, - } - - doc := NewDocument() - require.NoError(t, Unmarshal(before, doc)) - - after, err := Marshal(doc) - require.NoError(t, err) - - require.True(t, bytes.Equal(before, after)) -} - -func TestMarshal_roundtripFromDoc(t *testing.T) { - before := NewDocument( - EC.String("foo", "bar"), - EC.Int32("baz", -27), - EC.ArrayFromElements("bing", VC.Null(), VC.Regex("word", "i")), - ) - - bson, err := Marshal(before) - require.NoError(t, err) - - after := NewDocument() - require.NoError(t, Unmarshal(bson, after)) - - require.True(t, before.Equal(after)) -} - -func TestMarshal_roundtripWithUnmarshalDoc(t *testing.T) { - before := NewDocument( - EC.String("foo", "bar"), - EC.Int32("baz", -27), - EC.ArrayFromElements("bing", VC.Null(), VC.Regex("word", "i")), - ) - - bson, err := Marshal(before) - require.NoError(t, err) - - after, err := UnmarshalDocument(bson) - require.NoError(t, err) - - require.True(t, before.Equal(after)) -} - -func TestMarshalAppendWithRegistry(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - dst := make([]byte, 0, 1024) - var reg *Registry - if tc.reg != nil { - reg = tc.reg - } else { - reg = NewRegistryBuilder().Build() - } - got, err := MarshalAppendWithRegistry(reg, dst, tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", Reader(got), Reader(tc.want)) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - -func TestMarshalWithRegistry(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - var reg *Registry - if tc.reg != nil { - reg = tc.reg - } else { - reg = NewRegistryBuilder().Build() - } - got, err := MarshalWithRegistry(reg, tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", Reader(got), Reader(tc.want)) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - -func TestMarshalAppend(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - if tc.reg != nil { - t.Skip() // test requires custom registry - } - dst := make([]byte, 0, 1024) - got, err := MarshalAppend(dst, tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", Reader(got), Reader(tc.want)) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - -func TestMarshal(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - if tc.reg != nil { - t.Skip() // test requires custom registry - } - got, err := Marshal(tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", Reader(got), Reader(tc.want)) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} diff --git a/bson/objectid/objectid.go b/bson/objectid/objectid.go index 5ca5264402..2a47b54b8d 100644 --- a/bson/objectid/objectid.go +++ b/bson/objectid/objectid.go @@ -10,6 +10,7 @@ package objectid import ( + "bytes" "crypto/rand" "encoding/binary" "encoding/hex" @@ -53,6 +54,11 @@ func (id ObjectID) String() string { return fmt.Sprintf("ObjectID(%q)", id.Hex()) } +// IsZero returns true if id is the empty ObjectID. +func (id ObjectID) IsZero() bool { + return bytes.Equal(id[:], NilObjectID[:]) +} + // FromHex creates a new ObjectID from a hex string. It returns an error if the hex string is not a // valid ObjectID. func FromHex(s string) (ObjectID, error) { diff --git a/bson/reader.go b/bson/reader.go index 266518c467..200da0dff6 100644 --- a/bson/reader.go +++ b/bson/reader.go @@ -198,6 +198,17 @@ func (r Reader) String() string { return buf.String() } +// MarshalBSON implements the bsoncodec.Marshaler interface. +// +// This method does not copy the bytes from r. +func (r Reader) MarshalBSON() ([]byte, error) { + _, err := r.Validate() + if err != nil { + return nil, err + } + return r, nil +} + // recursiveKeys implements the logic for the Keys method. This is a separate // function to facilitate recursive calls. func (r Reader) recursiveKeys(recursive bool, prefix ...string) (Keys, error) { diff --git a/bson/reader_iterator.go b/bson/reader_iterator.go index 41ed2977e9..d49935c27a 100644 --- a/bson/reader_iterator.go +++ b/bson/reader_iterator.go @@ -57,6 +57,7 @@ func (itr *ReaderIterator) Next() bool { itr.elem.value.start = elemStart itr.elem.value.offset = itr.pos itr.elem.value.data = itr.r + itr.elem.value.d = nil n, err = itr.elem.value.validate(false) itr.pos += n diff --git a/bson/registry.go b/bson/registry.go deleted file mode 100644 index 6b3d2c3d2f..0000000000 --- a/bson/registry.go +++ /dev/null @@ -1,254 +0,0 @@ -package bson - -import ( - "errors" - "reflect" - "sync" -) - -// ErrNoCodec is returned when there is no codec available for a type or interface in the registry. -type ErrNoCodec struct { - Type reflect.Type -} - -func (enc ErrNoCodec) Error() string { - return "no codec found for " + enc.Type.String() -} - -// ErrNotInterface is returned when the provided type is not an interface. -var ErrNotInterface = errors.New("The provided type is not an interface") - -var defaultRegistry = NewRegistryBuilder().Build() - -// A RegistryBuilder is used to build a Registry. This type is not goroutine -// safe. -type RegistryBuilder struct { - types map[reflect.Type]Codec - interfaces []interfacePair - kinds map[reflect.Kind]Codec -} - -// A Registry is used to store and retrieve codecs for types and interfaces. This type is the main -// typed passed around and Encoders and Decoders are constructed from it. -// -// TODO: Create a RegistryBuilder type and make the Registry type immutable. -type Registry struct { - tr typeRegistry - kr kindRegistry - ir interfaceRegistry - mu sync.RWMutex -} - -// NewRegistryBuilder creates a new RegistryBuilder. -func NewRegistryBuilder() *RegistryBuilder { - types := map[reflect.Type]Codec{ - tDocument: defaultDocumentCodec, - tArray: defaultArrayCodec, - tValue: defaultValueCodec, - reflect.PtrTo(tByteSlice): defaultByteSliceCodec, - reflect.PtrTo(tElementSlice): defaultElementSliceCodec, - reflect.PtrTo(tTime): defaultTimeCodec, - reflect.PtrTo(tEmpty): defaultEmptyInterfaceCodec, - reflect.PtrTo(tBinary): defaultBinaryCodec, - reflect.PtrTo(tUndefined): defaultUndefinedCodec, - reflect.PtrTo(tOID): defaultObjectIDCodec, - reflect.PtrTo(tDateTime): defaultDateTimeCodec, - reflect.PtrTo(tNull): defaultNullCodec, - reflect.PtrTo(tRegex): defaultRegexCodec, - reflect.PtrTo(tDBPointer): defaultDBPointerCodec, - reflect.PtrTo(tCodeWithScope): defaultCodeWithScopeCodec, - reflect.PtrTo(tTimestamp): defaultTimestampCodec, - reflect.PtrTo(tDecimal): defaultDecimal128Codec, - reflect.PtrTo(tMinKey): defaultMinKeyCodec, - reflect.PtrTo(tMaxKey): defaultMaxKeyCodec, - reflect.PtrTo(tJSONNumber): defaultJSONNumberCodec, - reflect.PtrTo(tURL): defaultURLCodec, - reflect.PtrTo(tReader): defaultReaderCodec, - } - kinds := map[reflect.Kind]Codec{ - reflect.Bool: defaultBoolCodec, - reflect.Int: defaultIntCodec, - reflect.Int8: defaultIntCodec, - reflect.Int16: defaultIntCodec, - reflect.Int32: defaultIntCodec, - reflect.Int64: defaultIntCodec, - reflect.Uint: defaultUintCodec, - reflect.Uint8: defaultUintCodec, - reflect.Uint16: defaultUintCodec, - reflect.Uint32: defaultUintCodec, - reflect.Uint64: defaultUintCodec, - reflect.Float32: defaultFloatCodec, - reflect.Float64: defaultFloatCodec, - reflect.Array: defaultSliceCodec, - reflect.Map: defaultMapCodec, - reflect.Slice: defaultSliceCodec, - reflect.String: defaultStringCodec, - reflect.Struct: &StructCodec{cache: make(map[reflect.Type]*structDescription), parser: DefaultStructTagParser}, - } - - return &RegistryBuilder{ - types: types, - kinds: kinds, - interfaces: make([]interfacePair, 0), - } -} - -// NewEmptyRegistryBuilder creates a new RegistryBuilder with no default kind -// Codecs. -func NewEmptyRegistryBuilder() *RegistryBuilder { - return &RegistryBuilder{ - types: make(map[reflect.Type]Codec), - kinds: make(map[reflect.Kind]Codec), - interfaces: make([]interfacePair, 0), - } -} - -// Register will register the provided Codec to the provided type. If the type is -// an interface, it will be registered in the interface registry. If the type is -// a pointer to or a type that is not an interface, it will be registered in the type -// registry. -func (rb *RegistryBuilder) Register(t reflect.Type, codec Codec) *RegistryBuilder { - switch t.Kind() { - case reflect.Interface: - for idx, ip := range rb.interfaces { - if ip.i == t { - rb.interfaces[idx].c = codec - return rb - } - } - - rb.interfaces = append(rb.interfaces, interfacePair{i: t, c: codec}) - default: - if t.Kind() != reflect.Ptr { - t = reflect.PtrTo(t) - } - - rb.types[t] = codec - } - return rb -} - -// RegisterDefault will register the provided Codec to the provided kind. -func (rb *RegistryBuilder) RegisterDefault(kind reflect.Kind, codec Codec) *RegistryBuilder { - rb.kinds[kind] = codec - return rb -} - -// Build creates a Registry from the current state of this RegistryBuilder. -func (rb *RegistryBuilder) Build() *Registry { - tr := make(typeRegistry) - for t, c := range rb.types { - tr[t] = c - } - kr := make(kindRegistry) - for k, c := range rb.kinds { - kr[k] = c - } - - ir := make(interfaceRegistry, len(rb.interfaces)) - copy(ir, rb.interfaces) - - return &Registry{ - tr: tr, - kr: kr, - ir: ir, - } -} - -// Lookup will inspect the type registry for either the type or a pointer to the type, -// if it doesn't find a codec it will inspect the interface registry for an interface -// that the type satisfies, if it doesn't find a codec there it will attempt to -// return either the default map codec or the default struct codec. If none of those -// apply, an error will be returned. -func (r *Registry) Lookup(t reflect.Type) (Codec, error) { - // We make this year so if we strip a pointer off it won't confuse user. If - // we did it where we return this and the user provided a pointer to the - // type, the error message would be for a lookup for the non-pointer version - // of the type. - codecerr := ErrNoCodec{Type: t} - r.mu.RLock() - codec, found := r.tr.lookup(t) - r.mu.RUnlock() - if found { - if codec == nil { - return nil, ErrNoCodec{Type: t} - } - return codec, nil - } - - codec, found = r.ir.lookup(t) - if found { - r.mu.Lock() - if t.Kind() != reflect.Ptr { - t = reflect.PtrTo(t) - } - r.tr[t] = codec - r.mu.Unlock() - return codec, nil - } - - // We don't allow maps with non-string keys - if t.Kind() == reflect.Map && t.Key().Kind() != reflect.String { - return nil, ErrNoCodec{Type: t} - } - - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - codec, found = r.kr.lookup(t.Kind()) - if !found { - return nil, codecerr - } - - r.mu.Lock() - r.tr[t] = codec - r.mu.Unlock() - return codec, nil -} - -// The type registry handles codecs that are for specifics types that are not interfaces. -// This registry will handle both the types themselves and pointers to those types. -type typeRegistry map[reflect.Type]Codec - -// lookup handles finding a codec for the registered type. Will return an error if no codec -// could be found. -func (tr typeRegistry) lookup(t reflect.Type) (Codec, bool) { - if t.Kind() != reflect.Ptr { - t = reflect.PtrTo(t) - } - - codec, found := tr[t] - return codec, found -} - -type interfacePair struct { - i reflect.Type - c Codec -} - -// The kind registry handles codecs that are for base kinds. -type kindRegistry map[reflect.Kind]Codec - -// lookup handles finding a codec for the registered kind. Will return an error if no codec -// could be found. -func (kr kindRegistry) lookup(k reflect.Kind) (Codec, bool) { - codec, found := kr[k] - return codec, found -} - -// The interface registry handles codecs that are for interface types. -type interfaceRegistry []interfacePair - -// lookup handles finding a codec for the registered interface. Will return an error if no codec -// could be found. -func (ir interfaceRegistry) lookup(t reflect.Type) (Codec, bool) { - for _, ip := range ir { - if !t.Implements(ip.i) { - continue - } - - return ip.c, true - } - return nil, false -} diff --git a/bson/slice_codec.go b/bson/slice_codec.go deleted file mode 100644 index 2db3a49ce2..0000000000 --- a/bson/slice_codec.go +++ /dev/null @@ -1,182 +0,0 @@ -package bson - -import ( - "fmt" - "reflect" -) - -var defaultSliceCodec = &SliceCodec{} - -// SliceCodec is the Codec used for slice and array values. -type SliceCodec struct{} - -var _ Codec = &SliceCodec{} - -// EncodeValue implements the Codec interface. -func (sc *SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error { - val := reflect.ValueOf(i) - switch val.Kind() { - case reflect.Array: - case reflect.Slice: - if val.IsNil() { // When nil, special case to null - return vw.WriteNull() - } - default: - return fmt.Errorf("%T can only encode arrays and slices", sc) - } - - length := val.Len() - - aw, err := vw.WriteArray() - if err != nil { - return err - } - - // We do this outside of the loop because an array or a slice can only have - // one element type. If it's the empty interface, we'll use the empty - // interface codec. - var codec Codec - switch val.Type().Elem() { - case tElement: - codec = defaultElementCodec - default: - codec, err = ec.Lookup(val.Type().Elem()) - if err != nil { - return err - } - } - for idx := 0; idx < length; idx++ { - vw, err := aw.WriteArrayElement() - if err != nil { - return err - } - - err = codec.EncodeValue(ec, vw, val.Index(idx).Interface()) - if err != nil { - return err - } - } - - return aw.WriteArrayEnd() -} - -// DecodeValue implements the Codec interface. -func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error { - val := reflect.ValueOf(i) - if !val.IsValid() || val.Kind() != reflect.Ptr || val.IsNil() { - return fmt.Errorf("%T can only be used to decode non-nil pointers to slice or array values, got %T", sc, i) - } - - switch val.Elem().Kind() { - case reflect.Slice, reflect.Array: - if !val.Elem().CanSet() { - return fmt.Errorf("%T can only decode settable slice and array values", sc) - } - default: - return fmt.Errorf("%T can only decode settable slice and array values, got %T", sc, i) - } - - switch vr.Type() { - case TypeArray: - case TypeNull: - if val.Elem().Kind() != reflect.Slice { - return fmt.Errorf("cannot decode %v into an array", vr.Type()) - } - null := reflect.Zero(val.Elem().Type()) - val.Elem().Set(null) - return vr.ReadNull() - default: - return fmt.Errorf("cannot decode %v into a slice", vr.Type()) - } - - eType := val.Type().Elem().Elem() - - ar, err := vr.ReadArray() - if err != nil { - return err - } - - var elems []reflect.Value - switch eType { - case tElement: - elems, err = sc.decodeElement(dc, ar) - default: - elems, err = sc.decodeDefault(dc, ar, eType) - } - - if err != nil { - return err - } - - switch val.Elem().Kind() { - case reflect.Slice: - slc := reflect.MakeSlice(val.Elem().Type(), len(elems), len(elems)) - - for idx, elem := range elems { - slc.Index(idx).Set(elem) - } - - val.Elem().Set(slc) - case reflect.Array: - if len(elems) > val.Elem().Len() { - return fmt.Errorf("more elements returned in array than can fit inside %s", val.Elem().Type()) - } - - for idx, elem := range elems { - val.Elem().Index(idx).Set(elem) - } - } - - return nil -} - -func (sc *SliceCodec) decodeElement(dc DecodeContext, ar ArrayReader) ([]reflect.Value, error) { - elems := make([]reflect.Value, 0) - for { - vr, err := ar.ReadValue() - if err == ErrEOA { - break - } - if err != nil { - return nil, err - } - - var elem *Element - err = defaultElementCodec.decodeValue(dc, vr, "", &elem) - if err != nil { - return nil, err - } - elems = append(elems, reflect.ValueOf(elem)) - } - - return elems, nil -} - -func (sc *SliceCodec) decodeDefault(dc DecodeContext, ar ArrayReader, eType reflect.Type) ([]reflect.Value, error) { - elems := make([]reflect.Value, 0) - - codec, err := dc.Lookup(eType) - if err != nil { - return nil, err - } - - for { - vr, err := ar.ReadValue() - if err == ErrEOA { - break - } - if err != nil { - return nil, err - } - - ptr := reflect.New(eType) - - err = codec.DecodeValue(dc, vr, ptr.Interface()) - if err != nil { - return nil, err - } - elems = append(elems, ptr.Elem()) - } - - return elems, nil -} diff --git a/bson/unmarshal.go b/bson/unmarshal.go deleted file mode 100644 index 3467f8f5e2..0000000000 --- a/bson/unmarshal.go +++ /dev/null @@ -1,68 +0,0 @@ -package bson - -import "sync" - -// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal* -// methods and is not consumable from outside of this package. The Encoders retrieved from this pool -// must have both Reset and SetRegistry called on them. -var decPool = sync.Pool{ - New: func() interface{} { - return new(Decoderv2) - }, -} - -// Unmarshalv2 parses the BSON-encoded data and stores the result in the value -// pointed to by val. If val is nil or not a pointer, Unmarshal returns -// InvalidUnmarshalError. -func Unmarshalv2(data []byte, val interface{}) error { - return UnmarshalWithRegistry(defaultRegistry, data, val) -} - -// UnmarshalWithRegistry parses the BSON-encoded data using Registry r and -// stores the result in the value pointed to by val. If val is nil or not -// a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. -func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { - vr := newValueReader(data) - - dec := decPool.Get().(*Decoderv2) - defer decPool.Put(dec) - - err := dec.Reset(vr) - if err != nil { - return err - } - err = dec.SetRegistry(r) - if err != nil { - return err - } - - return dec.Decode(val) -} - -// UnmarshalDocumentv2 parses the *Document and stores the result in the value pointed to by val. If -// val is nil or not a pointer, UnmarshalDocument returns InvalidUnmarshalError. -func UnmarshalDocumentv2(d *Document, val interface{}) error { - return UnmarshalDocumentWithRegistry(defaultRegistry, d, val) -} - -// UnmarshalDocumentWithRegistry behaves the same as UnmarshalDocument but uses r as the *Registry. -func UnmarshalDocumentWithRegistry(r *Registry, d *Document, val interface{}) error { - dvr, err := NewDocumentValueReader(d) - if err != nil { - return err - } - - dec := decPool.Get().(*Decoderv2) - defer decPool.Put(dec) - - err = dec.Reset(dvr) - if err != nil { - return err - } - err = dec.SetRegistry(r) - if err != nil { - return err - } - - return dec.Decode(val) -} diff --git a/bson/value.go b/bson/value.go index 0403f442d8..3deda284fd 100644 --- a/bson/value.go +++ b/bson/value.go @@ -7,13 +7,13 @@ package bson import ( - "bytes" "encoding/binary" "fmt" "math" "time" "github.com/mongodb/mongo-go-driver/bson/decimal" + "github.com/mongodb/mongo-go-driver/bson/internal/llbson" "github.com/mongodb/mongo-go-driver/bson/objectid" ) @@ -105,6 +105,12 @@ func (v *Value) Interface() interface{} { } } +// Validate validates the value. +func (v *Value) Validate() error { + _, err := v.validate(false) + return err +} + func (v *Value) validate(sizeOnly bool) (uint32, error) { if v.data == nil { return 0, ErrUninitializedElement @@ -1035,7 +1041,8 @@ func (v *Value) Add(v2 *Value) error { return fmt.Errorf("cannot Add values of types %s and %s yet", v.Type(), v2.Type()) } -func (v *Value) equal(v2 *Value) bool { +// Equal will return true if this value is equal to val. +func (v *Value) Equal(v2 *Value) bool { if v == nil && v2 == nil { return true } @@ -1044,17 +1051,17 @@ func (v *Value) equal(v2 *Value) bool { return false } - if v.start != v2.start { + if v.data[v.start] != v2.data[v2.start] { return false } - if v.offset != v2.offset { - return false - } - - if v.d != nil && !v.d.Equal(v2.d) { - return false + if v.d != nil || v2.d != nil { + if v.d == nil || v2.d == nil { + return false + } + return v.d.Equal(v2.d) } - return bytes.Equal(v.data, v2.data) + t1, t2 := llbson.Type(v.data[v.start]), llbson.Type(v2.data[v2.start]) + return llbson.EqualValue(t1, t2, v.data[v.offset:], v2.data[v2.offset:]) } diff --git a/bson/value_read_writer_copy.go b/bson/value_read_writer_copy.go deleted file mode 100644 index a8db318152..0000000000 --- a/bson/value_read_writer_copy.go +++ /dev/null @@ -1,239 +0,0 @@ -package bson - -import ( - "fmt" - - "github.com/mongodb/mongo-go-driver/bson/decimal" - "github.com/mongodb/mongo-go-driver/bson/objectid" -) - -type copier struct{} - -// CopyDocument handles copying a document from src to dst. -func CopyDocument(dst ValueWriter, src ValueReader) error { - return copier{}.copyDocument(dst, src) -} - -func (c copier) copyDocument(dst ValueWriter, src ValueReader) error { - dr, err := src.ReadDocument() - if err != nil { - return err - } - - dw, err := dst.WriteDocument() - if err != nil { - return err - } - - return c.copyDocumentCore(dw, dr) -} - -func (c copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error { - for { - key, vr, err := dr.ReadElement() - if err == ErrEOD { - break - } - if err != nil { - return err - } - - vw, err := dw.WriteDocumentElement(key) - if err != nil { - return err - } - - err = c.copyElement(vw, vr) - if err != nil { - return err - } - } - - return dw.WriteDocumentEnd() -} - -func (c copier) copyArray(dst ValueWriter, src ValueReader) error { - ar, err := src.ReadArray() - if err != nil { - return err - } - - aw, err := dst.WriteArray() - if err != nil { - return err - } - - for { - vr, err := ar.ReadValue() - if err == ErrEOD { - break - } - if err != nil { - return err - } - - vw, err := aw.WriteArrayElement() - if err != nil { - return err - } - - err = c.copyElement(vw, vr) - if err != nil { - return err - } - } - - return aw.WriteArrayEnd() -} - -func (c copier) copyElement(dst ValueWriter, src ValueReader) error { - var err error - switch src.Type() { - case TypeDouble: - var f64 float64 - f64, err = src.ReadDouble() - if err != nil { - break - } - err = dst.WriteDouble(f64) - case TypeString: - var str string - str, err = src.ReadString() - if err != nil { - return err - } - err = dst.WriteString(str) - case TypeEmbeddedDocument: - err = c.copyDocument(dst, src) - case TypeArray: - err = c.copyArray(dst, src) - case TypeBinary: - var data []byte - var subtype byte - data, subtype, err = src.ReadBinary() - if err != nil { - break - } - err = dst.WriteBinaryWithSubtype(data, subtype) - case TypeUndefined: - err = src.ReadUndefined() - if err != nil { - break - } - err = dst.WriteUndefined() - case TypeObjectID: - var oid objectid.ObjectID - oid, err = src.ReadObjectID() - if err != nil { - break - } - err = dst.WriteObjectID(oid) - case TypeBoolean: - var b bool - b, err = src.ReadBoolean() - if err != nil { - break - } - err = dst.WriteBoolean(b) - case TypeDateTime: - var dt int64 - dt, err = src.ReadDateTime() - if err != nil { - break - } - err = dst.WriteDateTime(dt) - case TypeNull: - err = src.ReadNull() - if err != nil { - break - } - err = dst.WriteNull() - case TypeRegex: - var pattern, options string - pattern, options, err = src.ReadRegex() - if err != nil { - break - } - err = dst.WriteRegex(pattern, options) - case TypeDBPointer: - var ns string - var pointer objectid.ObjectID - ns, pointer, err = src.ReadDBPointer() - if err != nil { - break - } - err = dst.WriteDBPointer(ns, pointer) - case TypeJavaScript: - var js string - js, err = src.ReadJavascript() - if err != nil { - break - } - err = dst.WriteJavascript(js) - case TypeSymbol: - var symbol string - symbol, err = src.ReadSymbol() - if err != nil { - break - } - err = dst.WriteSymbol(symbol) - case TypeCodeWithScope: - var code string - var srcScope DocumentReader - code, srcScope, err = src.ReadCodeWithScope() - if err != nil { - break - } - - var dstScope DocumentWriter - dstScope, err = dst.WriteCodeWithScope(code) - if err != nil { - break - } - err = c.copyDocumentCore(dstScope, srcScope) - case TypeInt32: - var i32 int32 - i32, err = src.ReadInt32() - if err != nil { - break - } - err = dst.WriteInt32(i32) - case TypeTimestamp: - var t, i uint32 - t, i, err = src.ReadTimestamp() - if err != nil { - break - } - err = dst.WriteTimestamp(t, i) - case TypeInt64: - var i64 int64 - i64, err = src.ReadInt64() - if err != nil { - break - } - err = dst.WriteInt64(i64) - case TypeDecimal128: - var d128 decimal.Decimal128 - d128, err = src.ReadDecimal128() - if err != nil { - break - } - err = dst.WriteDecimal128(d128) - case TypeMinKey: - err = src.ReadMinKey() - if err != nil { - break - } - err = dst.WriteMinKey() - case TypeMaxKey: - err = src.ReadMaxKey() - if err != nil { - break - } - err = dst.WriteMaxKey() - default: - err = fmt.Errorf("Cannot copy unknown BSON type %s", src.Type()) - } - - return err -} diff --git a/core/auth/mongodbcr.go b/core/auth/mongodbcr.go index 83f0d6bb56..ecac2f029f 100644 --- a/core/auth/mongodbcr.go +++ b/core/auth/mongodbcr.go @@ -14,6 +14,7 @@ import ( "io" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/wiremessage" @@ -67,7 +68,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, desc description.Serv Nonce string `bson:"nonce"` } - err = bson.Unmarshal(rdr, &getNonceResult) + err = bsoncodec.Unmarshal(rdr, &getNonceResult) if err != nil { return newAuthError("unmarshal error", err) } diff --git a/core/auth/sasl.go b/core/auth/sasl.go index 55745c7175..8e2dbd0244 100644 --- a/core/auth/sasl.go +++ b/core/auth/sasl.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/wiremessage" @@ -72,7 +73,7 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi return newError(err, mech) } - err = bson.Unmarshal(rdr, &saslResp) + err = bsoncodec.Unmarshal(rdr, &saslResp) if err != nil { return newAuthError("unmarshall error", err) } @@ -111,7 +112,7 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi return newError(err, mech) } - err = bson.Unmarshal(rdr, &saslResp) + err = bsoncodec.Unmarshal(rdr, &saslResp) if err != nil { return newAuthError("unmarshal error", err) } diff --git a/core/command/abort_transaction.go b/core/command/abort_transaction.go index 8a8a1dc67c..7e4ccfe20a 100644 --- a/core/command/abort_transaction.go +++ b/core/command/abort_transaction.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/result" "github.com/mongodb/mongo-go-driver/core/session" @@ -52,7 +53,7 @@ func (at *AbortTransaction) Decode(desc description.SelectedServer, wm wiremessa } func (at *AbortTransaction) decode(desc description.SelectedServer, rdr bson.Reader) *AbortTransaction { - at.err = bson.Unmarshal(rdr, &at.result) + at.err = bsoncodec.Unmarshal(rdr, &at.result) if at.err == nil && at.result.WriteConcernError != nil { at.err = Error{ Code: int32(at.result.WriteConcernError.Code), diff --git a/core/command/buildinfo.go b/core/command/buildinfo.go index f8e04a6baa..7bad44e8cd 100644 --- a/core/command/buildinfo.go +++ b/core/command/buildinfo.go @@ -11,6 +11,7 @@ import ( "fmt" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/result" "github.com/mongodb/mongo-go-driver/core/wiremessage" ) @@ -55,7 +56,7 @@ func (bi *BuildInfo) Decode(wm wiremessage.WireMessage) *BuildInfo { bi.err = err return bi } - err = bson.Unmarshal(rdr, &bi.res) + err = bsoncodec.Unmarshal(rdr, &bi.res) if err != nil { bi.err = err return bi diff --git a/core/command/commit_transaction.go b/core/command/commit_transaction.go index 5b7aa95821..aa05c44b2a 100644 --- a/core/command/commit_transaction.go +++ b/core/command/commit_transaction.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/result" "github.com/mongodb/mongo-go-driver/core/session" @@ -52,7 +53,7 @@ func (ct *CommitTransaction) Decode(desc description.SelectedServer, wm wiremess } func (ct *CommitTransaction) decode(desc description.SelectedServer, rdr bson.Reader) *CommitTransaction { - ct.err = bson.Unmarshal(rdr, &ct.result) + ct.err = bsoncodec.Unmarshal(rdr, &ct.result) if ct.err == nil && ct.result.WriteConcernError != nil { ct.err = Error{ Code: int32(ct.result.WriteConcernError.Code), diff --git a/core/command/count_documents.go b/core/command/count_documents.go index 7b6404ee19..667593cce6 100644 --- a/core/command/count_documents.go +++ b/core/command/count_documents.go @@ -79,9 +79,9 @@ func (c *CountDocuments) Decode(ctx context.Context, desc description.SelectedSe return c } - var doc = bson.NewDocument() + var doc *bson.Document if cur.Next(ctx) { - err = cur.Decode(doc) + err = cur.Decode(&doc) if err != nil { c.err = err return c diff --git a/core/command/create_indexes.go b/core/command/create_indexes.go index 95202f4226..998df2e74c 100644 --- a/core/command/create_indexes.go +++ b/core/command/create_indexes.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/result" @@ -81,7 +82,7 @@ func (ci *CreateIndexes) Decode(desc description.SelectedServer, wm wiremessage. } func (ci *CreateIndexes) decode(desc description.SelectedServer, rdr bson.Reader) *CreateIndexes { - ci.err = bson.Unmarshal(rdr, &ci.result) + ci.err = bsoncodec.Unmarshal(rdr, &ci.result) return ci } diff --git a/core/command/delete.go b/core/command/delete.go index a55a8ac853..bd3e22c866 100644 --- a/core/command/delete.go +++ b/core/command/delete.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/result" @@ -97,7 +98,7 @@ func (d *Delete) Decode(desc description.SelectedServer, wm wiremessage.WireMess } func (d *Delete) decode(desc description.SelectedServer, rdr bson.Reader) *Delete { - d.err = bson.Unmarshal(rdr, &d.result) + d.err = bsoncodec.Unmarshal(rdr, &d.result) return d } diff --git a/core/command/distinct.go b/core/command/distinct.go index 6ce96eae49..2e5cae221f 100644 --- a/core/command/distinct.go +++ b/core/command/distinct.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/readconcern" @@ -92,7 +93,7 @@ func (d *Distinct) Decode(desc description.SelectedServer, wm wiremessage.WireMe } func (d *Distinct) decode(desc description.SelectedServer, rdr bson.Reader) *Distinct { - d.err = bson.Unmarshal(rdr, &d.result) + d.err = bsoncodec.Unmarshal(rdr, &d.result) return d } diff --git a/core/command/end_sessions.go b/core/command/end_sessions.go index a48c90676d..fbccf566c7 100644 --- a/core/command/end_sessions.go +++ b/core/command/end_sessions.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/result" "github.com/mongodb/mongo-go-driver/core/session" @@ -116,7 +117,7 @@ func (es *EndSessions) Decode(desc description.SelectedServer, wm wiremessage.Wi func (es *EndSessions) decode(desc description.SelectedServer, rdr bson.Reader) *EndSessions { var res result.EndSessions - es.errors = append(es.errors, bson.Unmarshal(rdr, res)) + es.errors = append(es.errors, bsoncodec.Unmarshal(rdr, &res)) es.results = append(es.results, res) return es } diff --git a/core/command/getlasterror.go b/core/command/getlasterror.go index 9141dc92ae..b46ebd7cd9 100644 --- a/core/command/getlasterror.go +++ b/core/command/getlasterror.go @@ -12,6 +12,7 @@ import ( "fmt" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/readpref" "github.com/mongodb/mongo-go-driver/core/result" @@ -73,7 +74,7 @@ func (gle *GetLastError) Decode(wm wiremessage.WireMessage) *GetLastError { } func (gle *GetLastError) decode(rdr bson.Reader) *GetLastError { - err := bson.Unmarshal(rdr, &gle.res) + err := bsoncodec.Unmarshal(rdr, &gle.res) if err != nil { gle.err = err return gle diff --git a/core/command/insert.go b/core/command/insert.go index 6a93830c94..7d62820f08 100644 --- a/core/command/insert.go +++ b/core/command/insert.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/result" @@ -174,7 +175,7 @@ func (i *Insert) Decode(desc description.SelectedServer, wm wiremessage.WireMess } func (i *Insert) decode(desc description.SelectedServer, rdr bson.Reader) *Insert { - i.err = bson.Unmarshal(rdr, &i.result) + i.err = bsoncodec.Unmarshal(rdr, &i.result) return i } diff --git a/core/command/ismaster.go b/core/command/ismaster.go index e7ee8ac26e..bd9da6ece5 100644 --- a/core/command/ismaster.go +++ b/core/command/ismaster.go @@ -11,6 +11,7 @@ import ( "fmt" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/result" "github.com/mongodb/mongo-go-driver/core/wiremessage" ) @@ -75,7 +76,7 @@ func (im *IsMaster) Decode(wm wiremessage.WireMessage) *IsMaster { im.err = err return im } - err = bson.Unmarshal(rdr, &im.res) + err = bsoncodec.Unmarshal(rdr, &im.res) if err != nil { im.err = err return im diff --git a/core/command/kill_cursors.go b/core/command/kill_cursors.go index d098cce2f1..144798903d 100644 --- a/core/command/kill_cursors.go +++ b/core/command/kill_cursors.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/result" "github.com/mongodb/mongo-go-driver/core/session" @@ -66,7 +67,7 @@ func (kc *KillCursors) Decode(desc description.SelectedServer, wm wiremessage.Wi } func (kc *KillCursors) decode(desc description.SelectedServer, rdr bson.Reader) *KillCursors { - err := bson.Unmarshal(rdr, &kc.result) + err := bsoncodec.Unmarshal(rdr, &kc.result) if err != nil { kc.err = err return kc diff --git a/core/command/list_databases.go b/core/command/list_databases.go index dd8969e4f8..97185597f4 100644 --- a/core/command/list_databases.go +++ b/core/command/list_databases.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/result" @@ -76,7 +77,7 @@ func (ld *ListDatabases) Decode(desc description.SelectedServer, wm wiremessage. } func (ld *ListDatabases) decode(desc description.SelectedServer, rdr bson.Reader) *ListDatabases { - ld.err = bson.Unmarshal(rdr, &ld.result) + ld.err = bsoncodec.Unmarshal(rdr, &ld.result) return ld } diff --git a/core/command/start_session.go b/core/command/start_session.go index 44a0e71468..731a1aab8b 100644 --- a/core/command/start_session.go +++ b/core/command/start_session.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/result" "github.com/mongodb/mongo-go-driver/core/session" @@ -51,7 +52,7 @@ func (ss *StartSession) Decode(desc description.SelectedServer, wm wiremessage.W } func (ss *StartSession) decode(desc description.SelectedServer, rdr bson.Reader) *StartSession { - ss.err = bson.Unmarshal(rdr, &ss.result) + ss.err = bsoncodec.Unmarshal(rdr, &ss.result) return ss } diff --git a/core/command/update.go b/core/command/update.go index 60cc1e3e9f..7260efdd19 100644 --- a/core/command/update.go +++ b/core/command/update.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/result" @@ -96,7 +97,7 @@ func (u *Update) Decode(desc description.SelectedServer, wm wiremessage.WireMess } func (u *Update) decode(desc description.SelectedServer, rdr bson.Reader) *Update { - u.err = bson.Unmarshal(rdr, &u.result) + u.err = bsoncodec.Unmarshal(rdr, &u.result) return u } diff --git a/core/integration/aggregate_test.go b/core/integration/aggregate_test.go index 35a69d9a98..a072eee974 100644 --- a/core/integration/aggregate_test.go +++ b/core/integration/aggregate_test.go @@ -19,6 +19,7 @@ import ( "github.com/mongodb/mongo-go-driver/core/address" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/description" + "github.com/mongodb/mongo-go-driver/core/integration/internal/israce" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/topology" "github.com/mongodb/mongo-go-driver/core/writeconcern" @@ -75,13 +76,13 @@ func TestCommandAggregate(t *testing.T) { }).RoundTrip(context.Background(), server.SelectedDescription(), server, conn) noerr(t, err) - var next = make(bson.Reader, 1024) + var next bson.Reader for i := 4; i > 1; i-- { if !cursor.Next(context.Background()) { t.Error("Cursor should have results, but does not have a next result") } - err = cursor.Decode(next) + err = cursor.Decode(&next) noerr(t, err) if !bytes.Equal(next[:len(readers[i])], readers[i]) { t.Errorf("Did not get expected document. got %v; want %v", bson.Reader(next[:len(readers[i])]), readers[i]) @@ -213,9 +214,16 @@ func TestAggregatePassesMaxAwaitTimeMSThroughToGetMore(t *testing.T) { // wait a bit between insert and getMore commands time.Sleep(time.Millisecond * 100) + if israce.Enabled { + time.Sleep(time.Millisecond * 400) // wait a little longer when race detector is enabled. + } ctx, cancel := context.WithCancel(context.Background()) - time.AfterFunc(time.Millisecond*900, cancel) + if israce.Enabled { + time.AfterFunc(time.Millisecond*2000, cancel) + } else { + time.AfterFunc(time.Millisecond*900, cancel) + } for cursor.Next(ctx) { } diff --git a/core/integration/cursor_test.go b/core/integration/cursor_test.go index 5a14fa84ee..f28c26d895 100644 --- a/core/integration/cursor_test.go +++ b/core/integration/cursor_test.go @@ -58,8 +58,8 @@ func TestTailableCursorLoopsUntilDocsAvailable(t *testing.T) { assert.True(t, cursor.Next(context.Background()), "Cursor should have a next result") // make sure it's the right document - var next = make(bson.Reader, 1024) - err = cursor.Decode(next) + var next bson.Reader + err = cursor.Decode(&next) noerr(t, err) if !bytes.Equal(next[:len(rdr)], rdr) { @@ -88,8 +88,7 @@ func TestTailableCursorLoopsUntilDocsAvailable(t *testing.T) { noerr(t, cursor.Err()) // make sure it's the right document the second time - next = make(bson.Reader, 1024) - err = cursor.Decode(next) + err = cursor.Decode(&next) noerr(t, err) if !bytes.Equal(next[:len(rdr)], rdr) { diff --git a/core/integration/internal/israce/norace.go b/core/integration/internal/israce/norace.go new file mode 100644 index 0000000000..570dc13a28 --- /dev/null +++ b/core/integration/internal/israce/norace.go @@ -0,0 +1,7 @@ +// +build !race + +// Package israce reports if the Go race detector is enabled. +package israce + +// Enabled reports if the race detector is enabled. +const Enabled = false diff --git a/core/integration/internal/israce/race.go b/core/integration/internal/israce/race.go new file mode 100644 index 0000000000..498981efef --- /dev/null +++ b/core/integration/internal/israce/race.go @@ -0,0 +1,7 @@ +// +build race + +// Package israce reports if the Go race detector is enabled. +package israce + +// Enabled reports if the race detector is enabled. +const Enabled = true diff --git a/core/integration/list_collections_test.go b/core/integration/list_collections_test.go index ddf523db51..9cf108adaa 100644 --- a/core/integration/list_collections_test.go +++ b/core/integration/list_collections_test.go @@ -73,10 +73,10 @@ func TestCommandListCollections(t *testing.T) { noerr(t, err) names := map[string]bool{} - next := bson.NewDocument() + var next *bson.Document for cursor.Next(context.Background()) { - err = cursor.Decode(next) + err = cursor.Decode(&next) noerr(t, err) val, err := next.LookupErr("name") diff --git a/core/integration/list_indexes_test.go b/core/integration/list_indexes_test.go index 848d7d835e..fbdea619b6 100644 --- a/core/integration/list_indexes_test.go +++ b/core/integration/list_indexes_test.go @@ -38,7 +38,7 @@ func TestCommandListIndexes(t *testing.T) { var next *bson.Document for cursor.Next(context.Background()) { - err = cursor.Decode(next) + err = cursor.Decode(&next) noerr(t, err) val, err := next.LookupErr("name") @@ -67,7 +67,7 @@ func TestCommandListIndexes(t *testing.T) { var next *bson.Document for cursor.Next(context.Background()) { - err = cursor.Decode(next) + err = cursor.Decode(&next) noerr(t, err) val, err := next.LookupErr("name") @@ -99,10 +99,10 @@ func TestCommandListIndexes(t *testing.T) { noerr(t, err) indexes := []string{} - next := bson.NewDocument() + var next *bson.Document for cursor.Next(context.Background()) { - err = cursor.Decode(next) + err = cursor.Decode(&next) noerr(t, err) val, err := next.LookupErr("name") @@ -139,10 +139,10 @@ func TestCommandListIndexes(t *testing.T) { noerr(t, err) indexes := []string{} - next := bson.NewDocument() + var next *bson.Document for cursor.Next(context.Background()) { - err = cursor.Decode(next) + err = cursor.Decode(&next) noerr(t, err) val, err := next.LookupErr("name") diff --git a/core/option/options.go b/core/option/options.go index 5e7ae5c4bf..a3f0429d81 100644 --- a/core/option/options.go +++ b/core/option/options.go @@ -7,14 +7,14 @@ package option import ( - "time" - "fmt" - "io" "reflect" + "time" + "strconv" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" ) // Optioner is the interface implemented by types that can be used as options @@ -314,55 +314,95 @@ func (opt OptAllowPartialResults) String() string { return "OptAllowPartialResults: " + strconv.FormatBool(bool(opt)) } -// TransformDocument handles transforming a document of an allowable type into -// a *bson.Document. This method is called directly after most methods that -// have one or more parameters that are documents. -// -// The supported types for document are: -// -// bson.Marshaler -// bson.DocumentMarshaler -// bson.Reader -// []byte (must be a valid BSON document) -// io.Reader (only 1 BSON document will be read) -// A custom struct type +// // TransformDocument handles transforming a document of an allowable type into +// // a *bson.Document. This method is called directly after most methods that +// // have one or more parameters that are documents. +// // +// // The supported types for document are: +// // +// // bson.Marshaler +// // bson.DocumentMarshaler +// // bson.Reader +// // []byte (must be a valid BSON document) +// // io.Reader (only 1 BSON document will be read) +// // A custom struct type +// // +// func TransformDocument(document interface{}) (*bson.Document, error) { +// switch d := document.(type) { +// case nil: +// return bson.NewDocument(), nil +// case *bson.Document: +// return d, nil +// case bsoncodec.Marshaler, bson.Reader, []byte, io.Reader: +// return bson.NewDocumentEncoder().EncodeDocument(document) +// case bson.DocumentMarshaler: +// return d.MarshalBSONDocument() +// default: +// var kind reflect.Kind +// if t := reflect.TypeOf(document); t.Kind() == reflect.Ptr { +// kind = t.Elem().Kind() +// } +// if reflect.ValueOf(document).Kind() == reflect.Struct || kind == reflect.Struct { +// return bson.NewDocumentEncoder().EncodeDocument(document) +// } +// if reflect.ValueOf(document).Kind() == reflect.Map && +// reflect.TypeOf(document).Key().Kind() == reflect.String { +// return bson.NewDocumentEncoder().EncodeDocument(document) +// } // -func TransformDocument(document interface{}) (*bson.Document, error) { - switch d := document.(type) { - case nil: +// return nil, fmt.Errorf("cannot transform type %s to a *bson.Document", reflect.TypeOf(document)) +// } +// } + +// MarshalError is returned when attempting to transform a value into a document +// results in an error. +type MarshalError struct { + Value interface{} + Err error +} + +// Error implements the error interface. +func (me MarshalError) Error() string { + return fmt.Sprintf("cannot transform type %s to a *bson.Document", reflect.TypeOf(me.Value)) +} + +var defaultRegistry = bsoncodec.NewRegistryBuilder().Build() + +func transformDocument(registry *bsoncodec.Registry, val interface{}) (*bson.Document, error) { + if val == nil { return bson.NewDocument(), nil - case *bson.Document: - return d, nil - case bson.Marshaler, bson.Reader, []byte, io.Reader: - return bson.NewDocumentEncoder().EncodeDocument(document) - case bson.DocumentMarshaler: - return d.MarshalBSONDocument() - default: - var kind reflect.Kind - if t := reflect.TypeOf(document); t.Kind() == reflect.Ptr { - kind = t.Elem().Kind() - } - if reflect.ValueOf(document).Kind() == reflect.Struct || kind == reflect.Struct { - return bson.NewDocumentEncoder().EncodeDocument(document) - } - if reflect.ValueOf(document).Kind() == reflect.Map && - reflect.TypeOf(document).Key().Kind() == reflect.String { - return bson.NewDocumentEncoder().EncodeDocument(document) - } + } + reg := defaultRegistry + if registry != nil { + reg = registry + } - return nil, fmt.Errorf("cannot transform type %s to a *bson.Document", reflect.TypeOf(document)) + if bs, ok := val.([]byte); ok { + // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. + val = bson.Reader(bs) } + + // TODO(skriptble): Use a pool of these instead. + buf := make([]byte, 0, 256) + b, err := bsoncodec.MarshalAppendWithRegistry(reg, buf, val) + if err != nil { + return nil, MarshalError{Value: val, Err: err} + } + return bson.ReadDocument(b) } // OptArrayFilters is for internal use. //type OptArrayFilters []*bson.Document -type OptArrayFilters []interface{} +type OptArrayFilters struct { + Registry *bsoncodec.Registry + Filters []interface{} +} // Option implements the Optioner interface. func (opt OptArrayFilters) Option(d *bson.Document) error { - docs := make([]*bson.Document, 0, len(opt)) - for _, f := range opt { - d, err := TransformDocument(f) + docs := make([]*bson.Document, 0, len(opt.Filters)) + for _, f := range opt.Filters { + d, err := transformDocument(opt.Registry, f) if err != nil { return err } @@ -556,12 +596,13 @@ func (opt OptLimit) String() string { // OptMax is for internal use. type OptMax struct { - Max interface{} + Registry *bsoncodec.Registry + Max interface{} } // Option implements the Optioner interface. func (opt OptMax) Option(d *bson.Document) error { - doc, err := TransformDocument(opt.Max) + doc, err := transformDocument(opt.Registry, opt.Max) if err != nil { return err } @@ -643,12 +684,13 @@ func (opt OptMaxTime) String() string { // OptMin is for internal use. type OptMin struct { - Min interface{} + Registry *bsoncodec.Registry + Min interface{} } // Option implements the Optioner interface. func (opt OptMin) Option(d *bson.Document) error { - doc, err := TransformDocument(opt.Min) + doc, err := transformDocument(opt.Registry, opt.Min) if err != nil { return err } @@ -718,6 +760,7 @@ func (opt OptOrdered) String() string { // OptProjection is for internal use. type OptProjection struct { + Registry *bsoncodec.Registry Projection interface{} } @@ -725,7 +768,7 @@ type OptProjection struct { func (opt OptProjection) Option(d *bson.Document) error { var key = "projection" - doc, err := TransformDocument(opt.Projection) + doc, err := transformDocument(opt.Registry, opt.Projection) if err != nil { return err } @@ -747,13 +790,14 @@ func (opt OptProjection) String() string { // OptFields is for internal use. type OptFields struct { - Fields interface{} + Registry *bsoncodec.Registry + Fields interface{} } // Option implements the Optioner interface. func (opt OptFields) Option(d *bson.Document) error { var key = "fields" - doc, err := TransformDocument(opt.Fields) + doc, err := transformDocument(opt.Registry, opt.Fields) if err != nil { return err } @@ -877,12 +921,13 @@ func (opt OptSnapshot) String() string { // OptSort is for internal use. type OptSort struct { - Sort interface{} + Registry *bsoncodec.Registry + Sort interface{} } // Option implements the Optioner interface. func (opt OptSort) Option(d *bson.Document) error { - doc, err := TransformDocument(opt.Sort) + doc, err := transformDocument(opt.Registry, opt.Sort) if err != nil { return err } diff --git a/core/topology/cursor.go b/core/topology/cursor.go index 880b3cca92..caf418e334 100644 --- a/core/topology/cursor.go +++ b/core/topology/cursor.go @@ -7,12 +7,12 @@ package topology import ( - "bytes" "context" "errors" "fmt" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/session" @@ -28,6 +28,7 @@ type cursor struct { err error server *Server opts []option.CursorOptioner + registry *bsoncodec.Registry } func newCursor(result bson.Reader, clientSession *session.Client, clock *session.ClusterClock, server *Server, opts ...option.CursorOptioner) (command.Cursor, error) { @@ -49,6 +50,7 @@ func newCursor(result bson.Reader, clientSession *session.Client, clock *session clock: clock, current: -1, server: server, + registry: server.cfg.registry, opts: opts, } var ok bool @@ -125,7 +127,8 @@ func (c *cursor) Decode(v interface{}) error { if err != nil { return err } - return bson.NewDecoder(bytes.NewReader(br)).Decode(v) + + return bsoncodec.UnmarshalWithRegistry(c.registry, br, v) } func (c *cursor) DecodeBytes() (bson.Reader, error) { diff --git a/core/topology/server_options.go b/core/topology/server_options.go index 4bc315bca0..1a234e4db9 100644 --- a/core/topology/server_options.go +++ b/core/topology/server_options.go @@ -9,10 +9,13 @@ package topology import ( "time" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/connection" "github.com/mongodb/mongo-go-driver/core/session" ) +var defaultRegistry = bsoncodec.NewRegistryBuilder().Build() + type serverConfig struct { clock *session.ClusterClock compressionOpts []string @@ -22,6 +25,7 @@ type serverConfig struct { heartbeatTimeout time.Duration maxConns uint16 maxIdleConns uint16 + registry *bsoncodec.Registry } func newServerConfig(opts ...ServerOption) (*serverConfig, error) { @@ -30,6 +34,7 @@ func newServerConfig(opts ...ServerOption) (*serverConfig, error) { heartbeatTimeout: 30 * time.Second, maxConns: 100, maxIdleConns: 100, + registry: defaultRegistry, } for _, opt := range opts { @@ -104,3 +109,12 @@ func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) Serve return nil } } + +// WithRegistry configures the registry for the server to use when creating +// cursors. +func WithRegistry(fn func(*bsoncodec.Registry) *bsoncodec.Registry) ServerOption { + return func(cfg *serverConfig) error { + cfg.registry = fn(cfg.registry) + return nil + } +} diff --git a/core/topology/topology.go b/core/topology/topology.go index 394ea4c4c3..8ab7149bff 100644 --- a/core/topology/topology.go +++ b/core/topology/topology.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/address" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/session" @@ -50,6 +51,8 @@ const ( // Topology represents a MongoDB deployment. type Topology struct { + registry *bsoncodec.Registry + connectionstate int32 cfg *config diff --git a/mongo/change_stream.go b/mongo/change_stream.go index cb78c3f4ea..2b9b11447c 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -7,12 +7,12 @@ package mongo import ( - "bytes" "context" "errors" "time" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/option" "github.com/mongodb/mongo-go-driver/core/session" @@ -41,7 +41,7 @@ const errorCodeCursorNotFound int32 = 43 func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{}, opts ...changestreamopt.ChangeStream) (*changeStream, error) { - pipelineArr, err := transformAggregatePipeline(pipeline) + pipelineArr, err := transformAggregatePipeline(coll.registry, pipeline) if err != nil { return nil, err } @@ -193,7 +193,7 @@ func (cs *changeStream) Decode(out interface{}) error { return err } - return bson.NewDecoder(bytes.NewReader(br)).Decode(out) + return bsoncodec.UnmarshalWithRegistry(cs.coll.registry, br, out) } func (cs *changeStream) DecodeBytes() (bson.Reader, error) { diff --git a/mongo/change_stream_test.go b/mongo/change_stream_test.go index 4ad871ea17..c5b582e1ff 100644 --- a/mongo/change_stream_test.go +++ b/mongo/change_stream_test.go @@ -14,6 +14,7 @@ import ( "strings" + "github.com/google/go-cmp/cmp" "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/option" @@ -125,14 +126,16 @@ func TestChangeStream_trackResumeToken(t *testing.T) { for i := 1; i <= 4; i++ { getNextChange(changes) - doc := bson.NewDocument() - err := changes.Decode(doc) + var doc *bson.Document + err := changes.Decode(&doc) require.NoError(t, err) id, err := doc.LookupErr("_id") require.NoError(t, err) - require.Equal(t, id.MutableDocument(), changes.(*changeStream).resumeToken) + if !cmp.Equal(id.MutableDocument(), changes.(*changeStream).resumeToken) { + t.Errorf("Resume tokens do not match. got %v; want %v", id.MutableDocument(), changes.(*changeStream).resumeToken) + } } } @@ -264,5 +267,6 @@ func TestChangeStream_resumeAfterKillCursors(t *testing.T) { }() getNextChange(changes) - require.NoError(t, changes.Decode(bson.NewDocument())) + var doc *bson.Document + require.NoError(t, changes.Decode(&doc)) } diff --git a/mongo/client.go b/mongo/client.go index a185e3e1ab..5f44e2c982 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -10,6 +10,7 @@ import ( "context" "time" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/connstring" "github.com/mongodb/mongo-go-driver/core/description" @@ -29,6 +30,8 @@ import ( const defaultLocalThreshold = 15 * time.Millisecond +var defaultRegistry = bsoncodec.NewRegistryBuilder().Build() + // Client performs operations on a given topology. type Client struct { id uuid.UUID @@ -41,6 +44,8 @@ type Client struct { readPreference *readpref.ReadPref readConcern *readconcern.ReadConcern writeConcern *writeconcern.WriteConcern + registry *bsoncodec.Registry + marshaller BSONAppender } // Connect creates a new Client and then initializes it using the Connect method. @@ -176,6 +181,7 @@ func newClient(cs connstring.ConnString, opts ...clientopt.Option) (*Client, err topologyOptions: clientOpt.TopologyOptions, connString: clientOpt.ConnString, localThreshold: defaultLocalThreshold, + registry: clientOpt.Registry, } uuid, err := uuid.New() @@ -223,6 +229,10 @@ func newClient(cs connstring.ConnString, opts ...clientopt.Option) (*Client, err client.readPreference = readpref.Primary() } } + + if client.registry == nil { + client.registry = defaultRegistry + } return client, nil } @@ -333,7 +343,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... return ListDatabasesResult{}, err } - f, err := TransformDocument(filter) + f, err := transformDocument(c.registry, filter) if err != nil { return ListDatabasesResult{}, err } diff --git a/mongo/client_internal_test.go b/mongo/client_internal_test.go index bc0159cd47..c969cc85b9 100644 --- a/mongo/client_internal_test.go +++ b/mongo/client_internal_test.go @@ -15,6 +15,7 @@ import ( "fmt" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/readpref" "github.com/mongodb/mongo-go-driver/core/tag" "github.com/mongodb/mongo-go-driver/internal/testutil" @@ -37,6 +38,7 @@ func createTestClient(t *testing.T) *Client { connString: testutil.ConnString(t), readPreference: readpref.Primary(), clock: &session.ClusterClock{}, + registry: defaultRegistry, } } @@ -184,7 +186,7 @@ func TestClient_X509Auth(t *testing.T) { DB string } - if err := bson.Unmarshal(rdr, &u); err != nil { + if err := bsoncodec.Unmarshal(rdr, &u); err != nil { continue } diff --git a/mongo/client_options_test.go b/mongo/client_options_test.go index 2426b2ec48..f700417668 100644 --- a/mongo/client_options_test.go +++ b/mongo/client_options_test.go @@ -14,6 +14,7 @@ import ( "time" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/core/connstring" "github.com/mongodb/mongo-go-driver/core/readconcern" "github.com/mongodb/mongo-go-driver/core/readpref" @@ -182,7 +183,7 @@ func TestClientOptions_CustomDialer(t *testing.T) { require.NoError(t, err) err = client.Connect(context.Background()) require.NoError(t, err) - _, err = client.ListDatabases(context.Background(), nil) + _, err = client.ListDatabases(context.Background(), bson.NewDocument()) require.NoError(t, err) got := atomic.LoadInt32(&td.called) if got < 1 { diff --git a/mongo/clientopt/clientopt.go b/mongo/clientopt/clientopt.go index 948a2cfd94..6d0fc08d62 100644 --- a/mongo/clientopt/clientopt.go +++ b/mongo/clientopt/clientopt.go @@ -13,6 +13,7 @@ import ( "reflect" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/connection" "github.com/mongodb/mongo-go-driver/core/connstring" "github.com/mongodb/mongo-go-driver/core/event" @@ -88,6 +89,7 @@ type Client struct { ReadPreference *readpref.ReadPref ReadConcern *readconcern.ReadConcern WriteConcern *writeconcern.WriteConcern + Registry *bsoncodec.Registry } // ClientBundle is a bundle of client options @@ -230,6 +232,14 @@ func (cb *ClientBundle) ReadPreference(rp *readpref.ReadPref) *ClientBundle { } } +// Registry specifies the bsoncodec.Registry. +func (cb *ClientBundle) Registry(registry *bsoncodec.Registry) *ClientBundle { + return &ClientBundle{ + option: Registry(registry), + next: cb, + } +} + // ReplicaSet specifies the name of the replica set of the cluster. func (cb *ClientBundle) ReplicaSet(s string) *ClientBundle { return &ClientBundle{ @@ -530,7 +540,7 @@ func ReadConcern(rc *readconcern.ReadConcern) Option { }) } -// ReadPreference specifies the read preference +// ReadPreference specifies the read preference. func ReadPreference(rp *readpref.ReadPref) Option { return optionFunc( func(c *Client) error { @@ -541,6 +551,16 @@ func ReadPreference(rp *readpref.ReadPref) Option { }) } +// Registry specifies the bsoncodec.Registry. +func Registry(registry *bsoncodec.Registry) Option { + return optionFunc(func(c *Client) error { + if c.Registry == nil { + c.Registry = registry + } + return nil + }) +} + // ReplicaSet specifies the name of the replica set of the cluster. func ReplicaSet(s string) Option { return optionFunc( diff --git a/mongo/collection.go b/mongo/collection.go index 2c9e14b0b6..4667b66317 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/dispatch" @@ -43,6 +44,7 @@ type Collection struct { readPreference *readpref.ReadPref readSelector description.ServerSelector writeSelector description.ServerSelector + registry *bsoncodec.Registry } func newCollection(db *Database, name string, opts ...collectionopt.Option) *Collection { @@ -80,6 +82,7 @@ func newCollection(db *Database, name string, opts ...collectionopt.Option) *Col writeConcern: wc, readSelector: readSelector, writeSelector: db.writeSelector, + registry: db.registry, } return coll @@ -157,7 +160,7 @@ func (coll *Collection) InsertOne(ctx context.Context, document interface{}, ctx = context.Background() } - doc, err := TransformDocument(document) + doc, err := transformDocument(coll.registry, document) if err != nil { return nil, err } @@ -230,7 +233,7 @@ func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, docs := make([]*bson.Document, len(documents)) for i, doc := range documents { - bdoc, err := TransformDocument(doc) + bdoc, err := transformDocument(coll.registry, doc) if err != nil { return nil, err } @@ -308,7 +311,7 @@ func (coll *Collection) DeleteOne(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return nil, err } @@ -373,7 +376,7 @@ func (coll *Collection) DeleteMany(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return nil, err } @@ -491,12 +494,12 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return nil, err } - u, err := TransformDocument(update) + u, err := transformDocument(coll.registry, update) if err != nil { return nil, err } @@ -531,12 +534,12 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return nil, err } - u, err := TransformDocument(update) + u, err := transformDocument(coll.registry, update) if err != nil { return nil, err } @@ -619,12 +622,12 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return nil, err } - r, err := TransformDocument(replacement) + r, err := transformDocument(coll.registry, replacement) if err != nil { return nil, err } @@ -666,7 +669,7 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, ctx = context.Background() } - pipelineArr, err := transformAggregatePipeline(pipeline) + pipelineArr, err := transformAggregatePipeline(coll.registry, pipeline) if err != nil { return nil, err } @@ -727,7 +730,7 @@ func (coll *Collection) Count(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return 0, err } @@ -779,7 +782,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, ctx = context.Background() } - pipelineArr, err := countDocumentsAggregatePipeline(filter, opts...) + pipelineArr, err := countDocumentsAggregatePipeline(coll.registry, filter, opts...) if err != nil { return 0, err } @@ -877,7 +880,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i var f *bson.Document var err error if filter != nil { - f, err = TransformDocument(filter) + f, err = transformDocument(coll.registry, filter) if err != nil { return nil, err } @@ -940,7 +943,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, var f *bson.Document var err error if filter != nil { - f, err = TransformDocument(filter) + f, err = transformDocument(coll.registry, filter) if err != nil { return nil, err } @@ -998,7 +1001,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, var f *bson.Document var err error if filter != nil { - f, err = TransformDocument(filter) + f, err = transformDocument(coll.registry, filter) if err != nil { return &DocumentResult{err: err} } @@ -1042,7 +1045,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, return &DocumentResult{err: err} } - return &DocumentResult{cur: cursor} + return &DocumentResult{cur: cursor, reg: coll.registry} } // FindOneAndDelete find a single document and deletes it, returning the @@ -1064,7 +1067,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} var f *bson.Document var err error if filter != nil { - f, err = TransformDocument(filter) + f, err = transformDocument(coll.registry, filter) if err != nil { return &DocumentResult{err: err} } @@ -1107,7 +1110,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} return &DocumentResult{err: err} } - return &DocumentResult{rdr: res.Value} + return &DocumentResult{rdr: res.Value, reg: coll.registry} } // FindOneAndReplace finds a single document and replaces it, returning either @@ -1126,12 +1129,12 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return &DocumentResult{err: err} } - r, err := TransformDocument(replacement) + r, err := transformDocument(coll.registry, replacement) if err != nil { return &DocumentResult{err: err} } @@ -1179,7 +1182,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ return &DocumentResult{err: err} } - return &DocumentResult{rdr: res.Value} + return &DocumentResult{rdr: res.Value, reg: coll.registry} } // FindOneAndUpdate finds a single document and updates it, returning either @@ -1198,12 +1201,12 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} ctx = context.Background() } - f, err := TransformDocument(filter) + f, err := transformDocument(coll.registry, filter) if err != nil { return &DocumentResult{err: err} } - u, err := TransformDocument(update) + u, err := transformDocument(coll.registry, update) if err != nil { return &DocumentResult{err: err} } @@ -1250,7 +1253,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} return &DocumentResult{err: err} } - return &DocumentResult{rdr: res.Value} + return &DocumentResult{rdr: res.Value, reg: coll.registry} } // Watch returns a change stream cursor used to receive notifications of changes to the collection. diff --git a/mongo/collection_internal_test.go b/mongo/collection_internal_test.go index 7acd3ae08c..fc58857810 100644 --- a/mongo/collection_internal_test.go +++ b/mongo/collection_internal_test.go @@ -12,6 +12,7 @@ import ( "os" "testing" + "github.com/google/go-cmp/cmp" "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/objectid" "github.com/mongodb/mongo-go-driver/core/readconcern" @@ -198,7 +199,9 @@ func TestCollection_InsertOne(t *testing.T) { result, err := coll.InsertOne(context.Background(), doc) require.Nil(t, err) - require.Equal(t, result.InsertedID, want) + if !cmp.Equal(result.InsertedID, want) { + t.Errorf("Result documents do not match. got %v; want %v", result.InsertedID, want) + } } @@ -277,9 +280,13 @@ func TestCollection_InsertMany(t *testing.T) { require.Nil(t, err) require.Len(t, result.InsertedIDs, 3) - require.Equal(t, result.InsertedIDs[0], want1) + if !cmp.Equal(result.InsertedIDs[0], want1) { + t.Errorf("Result documents do not match. got %v; want %v", result.InsertedIDs[0], want1) + } require.NotNil(t, result.InsertedIDs[1]) - require.Equal(t, result.InsertedIDs[2], want2) + if !cmp.Equal(result.InsertedIDs[2], want2) { + t.Errorf("Result documents do not match. got %v; want %v", result.InsertedIDs[2], want2) + } } @@ -435,6 +442,7 @@ func TestCollection_InsertMany_WriteConcernError(t *testing.T) { got, ok := err.(BulkWriteError) if !ok { t.Errorf("Did not receive correct type of error. got %T; want %T\nError message: %s", err, BulkWriteError{}, err) + t.Errorf("got error message %v", err) } if got.WriteConcernError.Code != want.Code { t.Errorf("Did not receive the correct error code. got %d; want %d", got.WriteConcernError.Code, want.Code) @@ -1122,9 +1130,9 @@ func TestCollection_Aggregate(t *testing.T) { require.Nil(t, err) for i := 2; i < 5; i++ { - var doc = bson.NewDocument() + var doc *bson.Document cursor.Next(context.Background()) - err = cursor.Decode(doc) + err = cursor.Decode(&doc) require.NoError(t, err) require.Equal(t, doc.Len(), 1) @@ -1184,9 +1192,9 @@ func testAggregateWithOptions(t *testing.T, createIndex bool, opts aggregateopt. } for i := 2; i < 5; i++ { - var doc = bson.NewDocument() + var doc *bson.Document cursor.Next(context.Background()) - err = cursor.Decode(doc) + err = cursor.Decode(&doc) if err != nil { return err } @@ -1445,9 +1453,9 @@ func TestCollection_Find_found(t *testing.T) { require.Nil(t, err) results := make([]int, 0, 5) - var doc = make(bson.Reader, 1024) + var doc bson.Reader for cursor.Next(context.Background()) { - err = cursor.Decode(doc) + err = cursor.Decode(&doc) require.NoError(t, err) _, err = doc.Lookup("_id") @@ -1493,10 +1501,10 @@ func TestCollection_FindOne_found(t *testing.T) { initCollection(t, coll) filter := bson.NewDocument(bson.EC.Int32("x", 1)) - var result = bson.NewDocument() + var result *bson.Document err := coll.FindOne(context.Background(), filter, - ).Decode(result) + ).Decode(&result) require.Nil(t, err) require.Equal(t, result.Len(), 2) @@ -1524,11 +1532,11 @@ func TestCollection_FindOne_found_withOption(t *testing.T) { initCollection(t, coll) filter := bson.NewDocument(bson.EC.Int32("x", 1)) - var result = bson.NewDocument() + var result *bson.Document err := coll.FindOne(context.Background(), filter, findopt.Comment("here's a query for ya"), - ).Decode(result) + ).Decode(&result) require.Nil(t, err) require.Equal(t, result.Len(), 2) @@ -1571,8 +1579,8 @@ func TestCollection_FindOneAndDelete_found(t *testing.T) { filter := bson.NewDocument(bson.EC.Int32("x", 3)) - var result = bson.NewDocument() - err := coll.FindOneAndDelete(context.Background(), filter).Decode(result) + var result *bson.Document + err := coll.FindOneAndDelete(context.Background(), filter).Decode(&result) require.NoError(t, err) elem, err := result.LookupErr("x") @@ -1642,8 +1650,8 @@ func TestCollection_FindOneAndReplace_found(t *testing.T) { filter := bson.NewDocument(bson.EC.Int32("x", 3)) replacement := bson.NewDocument(bson.EC.Int32("y", 3)) - var result = bson.NewDocument() - err := coll.FindOneAndReplace(context.Background(), filter, replacement).Decode(result) + var result *bson.Document + err := coll.FindOneAndReplace(context.Background(), filter, replacement).Decode(&result) require.NoError(t, err) elem, err := result.LookupErr("x") @@ -1717,8 +1725,8 @@ func TestCollection_FindOneAndUpdate_found(t *testing.T) { update := bson.NewDocument( bson.EC.SubDocumentFromElements("$set", bson.EC.Int32("x", 6))) - var result = bson.NewDocument() - err := coll.FindOneAndUpdate(context.Background(), filter, update).Decode(result) + var result *bson.Document + err := coll.FindOneAndUpdate(context.Background(), filter, update).Decode(&result) require.NoError(t, err) elem, err := result.LookupErr("x") diff --git a/mongo/crud_util_test.go b/mongo/crud_util_test.go index cf60583937..66b4a0ef59 100644 --- a/mongo/crud_util_test.go +++ b/mongo/crud_util_test.go @@ -536,8 +536,8 @@ func verifyCursorResult(t *testing.T, cur Cursor, result json.RawMessage) { require.NotNil(t, cur) require.True(t, cur.Next(context.Background())) - actual := bson.NewDocument() - require.NoError(t, cur.Decode(actual)) + var actual *bson.Document + require.NoError(t, cur.Decode(&actual)) compareDocs(t, expected, actual) } @@ -550,8 +550,8 @@ func verifyDocumentResult(t *testing.T, res *DocumentResult, result json.RawMess jsonBytes, err := result.MarshalJSON() require.NoError(t, err) - actual := bson.NewDocument() - err = res.Decode(actual) + var actual *bson.Document + err = res.Decode(&actual) if err == ErrNoDocuments { var expected map[string]interface{} err := json.NewDecoder(bytes.NewBuffer(jsonBytes)).Decode(&expected) diff --git a/mongo/database.go b/mongo/database.go index b36ec3e10e..1b46c9b922 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -10,6 +10,7 @@ import ( "context" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/dispatch" @@ -32,6 +33,7 @@ type Database struct { readPreference *readpref.ReadPref readSelector description.ServerSelector writeSelector description.ServerSelector + registry *bsoncodec.Registry } func newDatabase(client *Client, name string, opts ...dbopt.Option) *Database { @@ -61,6 +63,7 @@ func newDatabase(client *Client, name string, opts ...dbopt.Option) *Database { readPreference: rp, readConcern: rc, writeConcern: wc, + registry: client.registry, } db.readSelector = description.CompositeSelector([]description.ServerSelector{ @@ -110,7 +113,7 @@ func (db *Database) RunCommand(ctx context.Context, runCommand interface{}, opts } } - runCmdDoc, err := TransformDocument(runCommand) + runCmdDoc, err := transformDocument(db.registry, runCommand) if err != nil { return nil, err } diff --git a/mongo/document_result.go b/mongo/document_result.go index e73a06dd5a..76718a7617 100644 --- a/mongo/document_result.go +++ b/mongo/document_result.go @@ -11,6 +11,7 @@ import ( "errors" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" ) // ErrNoDocuments is returned by Decode when an operation that returns a @@ -24,6 +25,7 @@ type DocumentResult struct { err error cur Cursor rdr bson.Reader + reg *bsoncodec.Registry } // Decode will attempt to decode the first document into v. If there was an @@ -38,7 +40,7 @@ func (dr *DocumentResult) Decode(v interface{}) error { if v == nil { return nil } - return bson.Unmarshal(dr.rdr, v) + return bsoncodec.UnmarshalWithRegistry(dr.reg, dr.rdr, v) case dr.cur != nil: defer dr.cur.Close(context.TODO()) if !dr.cur.Next(context.TODO()) { diff --git a/mongo/findopt/findopt.go b/mongo/findopt/findopt.go index eb1f7c78c0..a516f50d2a 100644 --- a/mongo/findopt/findopt.go +++ b/mongo/findopt/findopt.go @@ -92,7 +92,7 @@ func AllowPartialResults(b bool) OptAllowPartialResults { // ArrayFilters specifies which array elements an update should apply. // UpdateOne func ArrayFilters(filters ...interface{}) OptArrayFilters { - return OptArrayFilters(filters) + return OptArrayFilters{Filters: filters} } // BatchSize specifies the number of documents to return in each batch. @@ -142,7 +142,7 @@ func Limit(i int64) OptLimit { // Max sets an exclusive upper bound for a specific index. // Find, One func Max(max interface{}) OptMax { - return OptMax{max} + return OptMax{Max: max} } // MaxAwaitTime specifies the max amount of time for the server to wait on new documents. @@ -166,7 +166,7 @@ func MaxTime(d time.Duration) OptMaxTime { // Min specifies the inclusive lower bound for a specific index. // Find, One func Min(min interface{}) OptMin { - return OptMin{min} + return OptMin{Min: min} } // NoCursorTimeout prevents cursors from timing out after an inactivity period. @@ -223,7 +223,7 @@ func Snapshot(b bool) OptSnapshot { // Sort specifies the order in which to return results. // Find, One, DeleteOne, ReplaceOne, UpdateOne func Sort(sort interface{}) OptSort { - return OptSort{sort} + return OptSort{Sort: sort} } // Upsert specifies whether a document should be inserted if no match is found. diff --git a/mongo/findopt/findopt_deleteone_test.go b/mongo/findopt/findopt_deleteone_test.go index 7e5328fc77..51a41f17a2 100644 --- a/mongo/findopt/findopt_deleteone_test.go +++ b/mongo/findopt/findopt_deleteone_test.go @@ -64,12 +64,12 @@ func TestFindAndDeleteOneOpt(t *testing.T) { bundle1 = bundle1.Projection(true).Sort(false) testhelpers.RequireNotNil(t, bundle1, "created bundle was nil") bundle1Opts := []option.Optioner{ - OptProjection{true}.ConvertDeleteOneOption(), - OptSort{false}.ConvertDeleteOneOption(), + OptProjection{Projection: true}.ConvertDeleteOneOption(), + OptSort{Sort: false}.ConvertDeleteOneOption(), } bundle1DedupOpts := []option.Optioner{ - OptProjection{true}.ConvertDeleteOneOption(), - OptSort{false}.ConvertDeleteOneOption(), + OptProjection{Projection: true}.ConvertDeleteOneOption(), + OptSort{Sort: false}.ConvertDeleteOneOption(), } bundle2 := BundleDeleteOne(MaxTime(1)) @@ -86,13 +86,13 @@ func TestFindAndDeleteOneOpt(t *testing.T) { bundle3Opts := []option.Optioner{ OptMaxTime(1).ConvertDeleteOneOption(), OptMaxTime(2).ConvertDeleteOneOption(), - OptProjection{false}.ConvertDeleteOneOption(), - OptProjection{true}.ConvertDeleteOneOption(), + OptProjection{Projection: false}.ConvertDeleteOneOption(), + OptProjection{Projection: true}.ConvertDeleteOneOption(), } bundle3DedupOpts := []option.Optioner{ OptMaxTime(2).ConvertDeleteOneOption(), - OptProjection{true}.ConvertDeleteOneOption(), + OptProjection{Projection: true}.ConvertDeleteOneOption(), } nilBundle := BundleDeleteOne() @@ -100,40 +100,40 @@ func TestFindAndDeleteOneOpt(t *testing.T) { nestedBundle1 := createNestedDeleteOneBundle1(t) nestedBundleOpts1 := []option.Optioner{ - OptProjection{true}.ConvertDeleteOneOption(), + OptProjection{Projection: true}.ConvertDeleteOneOption(), OptMaxTime(500).ConvertDeleteOneOption(), - OptProjection{false}.ConvertDeleteOneOption(), + OptProjection{Projection: false}.ConvertDeleteOneOption(), OptMaxTime(1000).ConvertDeleteOneOption(), } nestedBundleDedupOpts1 := []option.Optioner{ - OptProjection{false}.ConvertDeleteOneOption(), + OptProjection{Projection: false}.ConvertDeleteOneOption(), OptMaxTime(1000).ConvertDeleteOneOption(), } nestedBundle2 := createNestedDeleteOneBundle2(t) nestedBundleOpts2 := []option.Optioner{ - OptProjection{true}.ConvertDeleteOneOption(), + OptProjection{Projection: true}.ConvertDeleteOneOption(), OptMaxTime(500).ConvertDeleteOneOption(), OptMaxTime(100).ConvertDeleteOneOption(), - OptProjection{false}.ConvertDeleteOneOption(), + OptProjection{Projection: false}.ConvertDeleteOneOption(), OptMaxTime(1000).ConvertDeleteOneOption(), } nestedBundleDedupOpts2 := []option.Optioner{ - OptProjection{false}.ConvertDeleteOneOption(), + OptProjection{Projection: false}.ConvertDeleteOneOption(), OptMaxTime(1000).ConvertDeleteOneOption(), } nestedBundle3 := createNestedDeleteOneBundle3(t) nestedBundleOpts3 := []option.Optioner{ OptMaxTime(100).ConvertDeleteOneOption(), - OptProjection{true}.ConvertDeleteOneOption(), + OptProjection{Projection: true}.ConvertDeleteOneOption(), OptMaxTime(500).ConvertDeleteOneOption(), OptMaxTime(100).ConvertDeleteOneOption(), - OptProjection{false}.ConvertDeleteOneOption(), + OptProjection{Projection: false}.ConvertDeleteOneOption(), OptMaxTime(1000).ConvertDeleteOneOption(), } nestedBundleDedupOpts3 := []option.Optioner{ - OptProjection{false}.ConvertDeleteOneOption(), + OptProjection{Projection: false}.ConvertDeleteOneOption(), OptMaxTime(1000).ConvertDeleteOneOption(), } diff --git a/mongo/gridfs/gridfs_test.go b/mongo/gridfs/gridfs_test.go index 73a70f466e..0c8ee0e452 100644 --- a/mongo/gridfs/gridfs_test.go +++ b/mongo/gridfs/gridfs_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/bson/objectid" "github.com/mongodb/mongo-go-driver/internal/testutil" "github.com/mongodb/mongo-go-driver/internal/testutil/helpers" @@ -314,12 +315,12 @@ func compareChunks(t *testing.T, filesID objectid.ObjectID) { t.Fatalf("chunks has fewer documents than expectedChunks") } - actualChunk := bson.NewDocument() - expectedChunk := bson.NewDocument() + var actualChunk *bson.Document + var expectedChunk *bson.Document - err = actualCursor.Decode(actualChunk) + err = actualCursor.Decode(&actualChunk) testhelpers.RequireNil(t, err, "error decoding actual chunk: %s", err) - err = expectedCursor.Decode(expectedChunk) + err = expectedCursor.Decode(&expectedChunk) testhelpers.RequireNil(t, err, "error decoding expected chunk: %s", err) compareGfsDoc(t, expectedChunk, actualChunk, filesID) @@ -338,12 +339,12 @@ func compareFiles(t *testing.T) { t.Fatalf("files has fewer documents than expectedFiles") } - actualFile := bson.NewDocument() - expectedFile := bson.NewDocument() + var actualFile *bson.Document + var expectedFile *bson.Document - err = actualCursor.Decode(actualFile) + err = actualCursor.Decode(&actualFile) testhelpers.RequireNil(t, err, "error decoding actual file: %s", err) - err = expectedCursor.Decode(expectedFile) + err = expectedCursor.Decode(&expectedFile) testhelpers.RequireNil(t, err, "error decoding expected file: %s", err) compareGfsDoc(t, expectedFile, actualFile, objectid.ObjectID{}) @@ -376,8 +377,10 @@ func runUploadAssert(t *testing.T, test test, fileID objectid.ObjectID) { docs := make([]interface{}, len(assertData.Documents)) for i, docInterface := range assertData.Documents { - doc, err := mongo.TransformDocument(docInterface) - testhelpers.RequireNil(t, err, "error transforming doc: %s", err) + rdr, err := bsoncodec.Marshal(docInterface) + testhelpers.RequireNil(t, err, "error marshaling doc: %s", err) + doc, err := bson.ReadDocument(rdr) + testhelpers.RequireNil(t, err, "error reading doc: %s", err) if id, err := doc.LookupErr("_id"); err == nil { idStr := id.StringValue() diff --git a/mongo/index_view_internal_test.go b/mongo/index_view_internal_test.go index 5a02a0ecc1..466594bae3 100644 --- a/mongo/index_view_internal_test.go +++ b/mongo/index_view_internal_test.go @@ -67,9 +67,9 @@ func TestIndexView_List(t *testing.T) { require.NoError(t, err) var found bool - var idx index for cursor.Next(context.Background()) { + var idx index err := cursor.Decode(&idx) require.NoError(t, err) @@ -110,9 +110,9 @@ func TestIndexView_CreateOne(t *testing.T) { require.NoError(t, err) var found bool - var idx index for cursor.Next(context.Background()) { + var idx index err := cursor.Decode(&idx) require.NoError(t, err) @@ -155,9 +155,9 @@ func TestIndexView_CreateOneWithNameOption(t *testing.T) { require.NoError(t, err) var found bool - var idx index for cursor.Next(context.Background()) { + var idx index err := cursor.Decode(&idx) require.NoError(t, err) @@ -302,9 +302,9 @@ func TestIndexView_CreateMany(t *testing.T) { fooFound := false barBazFound := false - var idx index for cursor.Next(context.Background()) { + var idx index err := cursor.Decode(&idx) require.NoError(t, err) @@ -477,9 +477,9 @@ func TestIndexView_CreateIndexesOptioner(t *testing.T) { fooFound := false barBazFound := false - var idx index for cursor.Next(context.Background()) { + var idx index err := cursor.Decode(&idx) require.NoError(t, err) @@ -551,9 +551,8 @@ func TestIndexView_DropIndexesOptioner(t *testing.T) { cursor, err := indexView.List(context.Background()) require.NoError(t, err) - var idx index - for cursor.Next(context.Background()) { + var idx index err := cursor.Decode(&idx) require.NoError(t, err) require.Equal(t, expectedNS, idx.NS) diff --git a/mongo/mongo.go b/mongo/mongo.go index c7e7969a0f..304fcca839 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -10,12 +10,12 @@ import ( "context" "errors" "fmt" - "io" "net" "reflect" "strings" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/bson/objectid" "github.com/mongodb/mongo-go-driver/mongo/countopt" ) @@ -25,44 +25,78 @@ type Dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } -// TransformDocument handles transforming a document of an allowable type into -// a *bson.Document. This method is called directly after most methods that -// have one or more parameters that are documents. -// -// The supported types for document are: -// -// bson.Marshaler -// bson.DocumentMarshaler -// bson.Reader -// []byte (must be a valid BSON document) -// io.Reader (only 1 BSON document will be read) -// A custom struct type +// BSONAppender is an interface implemented by types that can marshal a +// provided type into BSON bytes and append those bytes to the provided []byte. +// The AppendBSON can return a non-nil error and non-nil []byte. The AppendBSON +// method may also write incomplete BSON to the []byte. +type BSONAppender interface { + AppendBSON([]byte, interface{}) ([]byte, error) +} + +// BSONAppenderFunc is an adapter function that allows any function that +// satisfies the AppendBSON method signature to be used where a BSONAppender is +// used. +type BSONAppenderFunc func([]byte, interface{}) ([]byte, error) + +// AppendBSON implements the BSONAppender interface +func (baf BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) { + return baf(dst, val) +} + +// MarshalError is returned when attempting to transform a value into a document +// results in an error. +type MarshalError struct { + Value interface{} + Err error +} + +// Error implements the error interface. +func (me MarshalError) Error() string { + return fmt.Sprintf("cannot transform type %s to a *bson.Document", reflect.TypeOf(me.Value)) +} + +// // TransformDocument handles transforming a document of an allowable type into +// // a *bson.Document. This method is called directly after most methods that +// // have one or more parameters that are documents. +// // +// // The supported types for document are: +// // +// // bson.Marshaler +// // bson.Reader +// // []byte (must be a valid BSON document) +// // A custom struct type +// // +// func TransformDocument(val interface{}) (*bson.Document, error) { +// document, err := transformDocument(BSONAppenderFunc(bsoncodec.MarshalAppend), val) +// if err != nil { +// return nil, MarshalError{Value: val, Err: err} +// } // -func TransformDocument(document interface{}) (*bson.Document, error) { - switch d := document.(type) { - case nil: +// return document, nil +// } + +func transformDocument(registry *bsoncodec.Registry, val interface{}) (*bson.Document, error) { + if registry == nil { + registry = bsoncodec.NewRegistryBuilder().Build() + } + if val == nil { return bson.NewDocument(), nil - case *bson.Document: - return d, nil - case bson.Marshaler, bson.Reader, []byte, io.Reader: - return bson.NewDocumentEncoder().EncodeDocument(document) - case bson.DocumentMarshaler: - return d.MarshalBSONDocument() - default: - var kind reflect.Kind - if t := reflect.TypeOf(document); t.Kind() == reflect.Ptr { - kind = t.Elem().Kind() - } - if reflect.ValueOf(document).Kind() == reflect.Struct || kind == reflect.Struct { - return bson.NewDocumentEncoder().EncodeDocument(document) - } - if reflect.ValueOf(document).Kind() == reflect.Map && - reflect.TypeOf(document).Key().Kind() == reflect.String { - return bson.NewDocumentEncoder().EncodeDocument(document) - } + } + if doc, ok := val.(*bson.Document); ok { + return doc.Copy(), nil + } + if bs, ok := val.([]byte); ok { + // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. + val = bson.Reader(bs) + } - return nil, fmt.Errorf("cannot transform type %s to a *bson.Document", reflect.TypeOf(document)) + // TODO(skriptble): Use a pool of these instead. + buf := make([]byte, 0, 256) + b, err := bsoncodec.MarshalAppendWithRegistry(registry, buf[:0], val) + if err != nil { + return nil, MarshalError{Value: val, Err: err} } + return bson.ReadDocument(b) } func ensureID(d *bson.Document) (interface{}, error) { @@ -89,7 +123,7 @@ func ensureDollarKey(doc *bson.Document) error { return nil } -func transformAggregatePipeline(pipeline interface{}) (*bson.Array, error) { +func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (*bson.Array, error) { var pipelineArr *bson.Array switch t := pipeline.(type) { case *bson.Array: @@ -104,7 +138,7 @@ func transformAggregatePipeline(pipeline interface{}) (*bson.Array, error) { pipelineArr = bson.NewArray() for _, val := range t { - doc, err := TransformDocument(val) + doc, err := transformDocument(registry, val) if err != nil { return nil, err } @@ -112,7 +146,7 @@ func transformAggregatePipeline(pipeline interface{}) (*bson.Array, error) { pipelineArr.Append(bson.VC.Document(doc)) } default: - p, err := TransformDocument(pipeline) + p, err := transformDocument(registry, pipeline) if err != nil { return nil, err } @@ -124,9 +158,9 @@ func transformAggregatePipeline(pipeline interface{}) (*bson.Array, error) { } // Build the aggregation pipeline for the CountDocument command. -func countDocumentsAggregatePipeline(filter interface{}, opts ...countopt.Count) (*bson.Array, error) { +func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts ...countopt.Count) (*bson.Array, error) { pipeline := bson.NewArray() - filterDoc, err := TransformDocument(filter) + filterDoc, err := transformDocument(registry, filter) if err != nil { return nil, err diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index 3ca6852746..7627a8fe14 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -7,12 +7,12 @@ package mongo import ( - "fmt" - "reflect" + "errors" "testing" "github.com/google/go-cmp/cmp" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" ) func TestTransformDocument(t *testing.T) { @@ -28,12 +28,6 @@ func TestTransformDocument(t *testing.T) { bson.NewDocument(bson.EC.String("foo", "bar")), nil, }, - { - "bson.DocumentMarshaler", - dMarsh{bson.NewDocument(bson.EC.String("foo", "bar"))}, - bson.NewDocument(bson.EC.String("foo", "bar")), - nil, - }, { "reflection", reflectStruct{Foo: "bar"}, @@ -50,13 +44,13 @@ func TestTransformDocument(t *testing.T) { "unsupported type", []string{"foo", "bar"}, nil, - fmt.Errorf("cannot transform type %s to a *bson.Document", reflect.TypeOf([]string{})), + MarshalError{Value: []string{"foo", "bar"}, Err: errors.New("invalid state transition: TopLevel -> ArrayMode")}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := TransformDocument(tc.document) + got, err := transformDocument(bsoncodec.NewRegistryBuilder().Build(), tc.document) if !cmp.Equal(err, tc.err, cmp.Comparer(compareErrors)) { t.Errorf("Error does not match expected error. got %v; want %v", err, tc.err) } @@ -84,7 +78,7 @@ func compareErrors(err1, err2 error) bool { return true } -var _ bson.Marshaler = bMarsh{} +var _ bsoncodec.Marshaler = bMarsh{} type bMarsh struct { *bson.Document @@ -94,16 +88,6 @@ func (b bMarsh) MarshalBSON() ([]byte, error) { return b.Document.MarshalBSON() } -var _ bson.DocumentMarshaler = dMarsh{} - -type dMarsh struct { - d *bson.Document -} - -func (d dMarsh) MarshalBSONDocument() (*bson.Document, error) { - return d.d, nil -} - type reflectStruct struct { Foo string } diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index 4234c756de..68a669d9de 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -7,7 +7,6 @@ package mongo import ( - "bytes" "encoding/json" "errors" "fmt" @@ -17,6 +16,7 @@ import ( "time" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/connstring" "github.com/mongodb/mongo-go-driver/core/readconcern" "github.com/mongodb/mongo-go-driver/core/writeconcern" @@ -212,8 +212,7 @@ func runDocumentTest(t *testing.T, testName string, testCase *documentTest) { rcBytes := rcDoc.Value().ReaderDocument() actual := make(map[string]interface{}) - decoder := bson.NewDecoder(bytes.NewBuffer(rcBytes)) - err = decoder.Decode(actual) + err = bsoncodec.Unmarshal(rcBytes, &actual) requireMapEqual(t, testCase.ReadConcernDocument, actual) } @@ -237,8 +236,7 @@ func runDocumentTest(t *testing.T, testName string, testCase *documentTest) { wcBytes := wcDoc.Value().ReaderDocument() actual := make(map[string]interface{}) - decoder := bson.NewDecoder(bytes.NewBuffer(wcBytes)) - err = decoder.Decode(actual) + err = bsoncodec.Unmarshal(wcBytes, &actual) require.NoError(t, err) requireMapEqual(t, testCase.WriteConcernDocument, actual) diff --git a/mongo/results.go b/mongo/results.go index a0a57bb3a7..548635232a 100644 --- a/mongo/results.go +++ b/mongo/results.go @@ -7,10 +7,10 @@ package mongo import ( - "bytes" "fmt" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/result" ) @@ -114,7 +114,7 @@ func (result *UpdateResult) UnmarshalBSON(b []byte) error { var d struct { ID interface{} `bson:"_id"` } - err = bson.NewDecoder(bytes.NewReader(e.Value().ReaderDocument())).Decode(&d) + err = bsoncodec.Unmarshal(e.Value().ReaderDocument(), &d) if err != nil { return err } diff --git a/mongo/results_test.go b/mongo/results_test.go index a044207320..ed2cc17f3b 100644 --- a/mongo/results_test.go +++ b/mongo/results_test.go @@ -7,10 +7,10 @@ package mongo import ( - "bytes" "testing" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/stretchr/testify/require" ) @@ -25,7 +25,7 @@ func TestDeleteResult_unmarshalInto(t *testing.T) { require.Nil(t, err) var result DeleteResult - err = bson.NewDecoder(bytes.NewReader(b)).Decode(&result) + err = bsoncodec.Unmarshal(b, &result) require.Nil(t, err) require.Equal(t, result.DeletedCount, int64(2)) } @@ -34,11 +34,10 @@ func TestDeleteResult_marshalFrom(t *testing.T) { t.Parallel() result := DeleteResult{DeletedCount: 1} - var buf bytes.Buffer - err := bson.NewEncoder(&buf).Encode(result) + buf, err := bsoncodec.Marshal(result) require.Nil(t, err) - doc, err := bson.ReadDocument(buf.Bytes()) + doc, err := bson.ReadDocument(buf) require.Nil(t, err) require.Equal(t, doc.Len(), 1) @@ -66,7 +65,7 @@ func TestUpdateOneResult_unmarshalInto(t *testing.T) { require.Nil(t, err) var result UpdateResult - err = bson.NewDecoder(bytes.NewReader(b)).Decode(&result) + err = bsoncodec.Unmarshal(b, &result) require.Nil(t, err) require.Equal(t, result.MatchedCount, int64(1)) require.Equal(t, result.ModifiedCount, int64(2)) diff --git a/mongo/retryable_writes_test.go b/mongo/retryable_writes_test.go index 6e53fffd45..e826d1534d 100644 --- a/mongo/retryable_writes_test.go +++ b/mongo/retryable_writes_test.go @@ -340,6 +340,7 @@ func createRetryMonitoredClient(t *testing.T, monitor *event.CommandMonitor) *Cl connString: testutil.ConnString(t), readPreference: readpref.Primary(), clock: clock, + registry: defaultRegistry, } subscription, err := c.topology.Subscribe() diff --git a/mongo/transactions_test.go b/mongo/transactions_test.go index ca4765814e..e67c699251 100644 --- a/mongo/transactions_test.go +++ b/mongo/transactions_test.go @@ -20,6 +20,7 @@ import ( "path" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/core/command" "github.com/mongodb/mongo-go-driver/core/description" "github.com/mongodb/mongo-go-driver/core/event" @@ -296,6 +297,7 @@ func createTransactionsMonitoredClient(t *testing.T, monitor *event.CommandMonit connString: testutil.ConnString(t), readPreference: readpref.Primary(), clock: clock, + registry: bsoncodec.NewRegistryBuilder().Build(), } addClientOptions(c, opts) diff --git a/mongo/updateopt/updateopt.go b/mongo/updateopt/updateopt.go index 8ad275bbf0..79954f4f1c 100644 --- a/mongo/updateopt/updateopt.go +++ b/mongo/updateopt/updateopt.go @@ -231,7 +231,7 @@ func (ub *UpdateBundle) unbundle() ([]option.UpdateOptioner, *session.Client, er // ArrayFilters specifies which array elements an update should apply. func ArrayFilters(filter ...interface{}) OptArrayFilters { - return OptArrayFilters(filter) + return OptArrayFilters{Filters: filter} } // BypassDocumentValidation allows the write to opt-out of document-level validation.