From 4534e031c1676bbb5721025a802a3bece3775001 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 20 Jun 2026 09:16:27 +0000 Subject: [PATCH] Bump go.mongodb.org/mongo-driver from 1.9.0 to 1.17.7 Bumps [go.mongodb.org/mongo-driver](https://github.com/mongodb/mongo-go-driver) from 1.9.0 to 1.17.7. - [Release notes](https://github.com/mongodb/mongo-go-driver/releases) - [Commits](https://github.com/mongodb/mongo-go-driver/compare/v1.9.0...v1.17.7) --- updated-dependencies: - dependency-name: go.mongodb.org/mongo-driver dependency-version: 1.17.7 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- go.mod | 5 +- go.sum | 119 +- vendor/github.com/go-stack/stack/.travis.yml | 15 - vendor/github.com/go-stack/stack/README.md | 38 - vendor/github.com/go-stack/stack/go.mod | 1 - vendor/github.com/go-stack/stack/stack.go | 400 - vendor/github.com/golang/snappy/AUTHORS | 3 + vendor/github.com/golang/snappy/CONTRIBUTORS | 4 + vendor/github.com/golang/snappy/decode.go | 87 +- .../github.com/golang/snappy/decode_arm64.s | 494 + .../snappy/{decode_amd64.go => decode_asm.go} | 1 + .../github.com/golang/snappy/decode_other.go | 24 +- vendor/github.com/golang/snappy/encode.go | 4 + .../github.com/golang/snappy/encode_arm64.s | 722 ++ .../snappy/{encode_amd64.go => encode_asm.go} | 1 + .../github.com/golang/snappy/encode_other.go | 2 +- .../github.com/klauspost/compress/.gitignore | 7 + .../klauspost/compress/.goreleaser.yml | 4 + .../github.com/klauspost/compress/README.md | 220 +- .../github.com/klauspost/compress/SECURITY.md | 25 + .../klauspost/compress/fse/compress.go | 31 +- .../klauspost/compress/fse/decompress.go | 4 +- vendor/github.com/klauspost/compress/go.mod | 9 +- .../klauspost/compress/huff0/bitreader.go | 130 +- .../klauspost/compress/huff0/bitwriter.go | 133 +- .../klauspost/compress/huff0/bytereader.go | 10 - .../klauspost/compress/huff0/compress.go | 127 +- .../klauspost/compress/huff0/decompress.go | 744 +- .../compress/huff0/decompress_amd64.go | 226 + .../compress/huff0/decompress_amd64.s | 830 ++ .../compress/huff0/decompress_generic.go | 299 + .../klauspost/compress/huff0/huff0.go | 2 + .../compress/internal/cpuinfo/cpuinfo.go | 34 + .../internal/cpuinfo/cpuinfo_amd64.go | 11 + .../compress/internal/cpuinfo/cpuinfo_amd64.s | 36 + .../compress/internal/snapref/encode_other.go | 38 +- .../klauspost/compress/zstd/README.md | 172 +- .../klauspost/compress/zstd/bitreader.go | 12 +- .../klauspost/compress/zstd/bitwriter.go | 98 +- .../klauspost/compress/zstd/blockdec.go | 526 +- .../klauspost/compress/zstd/blockenc.go | 117 +- .../klauspost/compress/zstd/bytebuf.go | 27 +- .../klauspost/compress/zstd/bytereader.go | 6 - .../klauspost/compress/zstd/decodeheader.go | 93 +- .../klauspost/compress/zstd/decoder.go | 685 +- .../compress/zstd/decoder_options.go | 103 +- .../klauspost/compress/zstd/dict.go | 51 +- .../klauspost/compress/zstd/enc_base.go | 17 +- .../klauspost/compress/zstd/enc_best.go | 264 +- .../klauspost/compress/zstd/enc_better.go | 43 +- .../klauspost/compress/zstd/enc_dfast.go | 35 +- .../klauspost/compress/zstd/enc_fast.go | 176 +- .../klauspost/compress/zstd/encoder.go | 151 +- .../compress/zstd/encoder_options.go | 55 +- .../klauspost/compress/zstd/framedec.go | 342 +- .../klauspost/compress/zstd/fse_decoder.go | 128 +- .../compress/zstd/fse_decoder_amd64.go | 65 + .../compress/zstd/fse_decoder_amd64.s | 126 + .../compress/zstd/fse_decoder_generic.go | 72 + .../klauspost/compress/zstd/fse_encoder.go | 28 +- .../klauspost/compress/zstd/hash.go | 6 - .../klauspost/compress/zstd/history.go | 67 +- .../compress/zstd/internal/xxhash/README.md | 49 +- .../compress/zstd/internal/xxhash/xxhash.go | 47 +- .../zstd/internal/xxhash/xxhash_amd64.go | 12 - .../zstd/internal/xxhash/xxhash_amd64.s | 337 +- .../zstd/internal/xxhash/xxhash_arm64.s | 184 + .../zstd/internal/xxhash/xxhash_asm.go | 16 + .../zstd/internal/xxhash/xxhash_other.go | 23 +- .../klauspost/compress/zstd/matchlen_amd64.go | 16 + .../klauspost/compress/zstd/matchlen_amd64.s | 68 + .../compress/zstd/matchlen_generic.go | 33 + .../klauspost/compress/zstd/seqdec.go | 328 +- .../klauspost/compress/zstd/seqdec_amd64.go | 394 + .../klauspost/compress/zstd/seqdec_amd64.s | 4175 +++++++++ .../klauspost/compress/zstd/seqdec_generic.go | 237 + .../github.com/klauspost/compress/zstd/zip.go | 69 +- .../klauspost/compress/zstd/zstd.go | 55 +- .../github.com/montanaflynn/stats/.gitignore | 7 + .../montanaflynn/stats/CHANGELOG.md | 534 ++ .../montanaflynn/stats/DOCUMENTATION.md | 1271 +++ .../LICENSE.md => montanaflynn/stats/LICENSE} | 2 +- vendor/github.com/montanaflynn/stats/Makefile | 34 + .../github.com/montanaflynn/stats/README.md | 237 + .../montanaflynn/stats/correlation.go | 60 + .../montanaflynn/stats/cumulative_sum.go | 21 + vendor/github.com/montanaflynn/stats/data.go | 169 + .../github.com/montanaflynn/stats/describe.go | 81 + .../montanaflynn/stats/deviation.go | 57 + .../montanaflynn/stats/distances.go | 91 + vendor/github.com/montanaflynn/stats/doc.go | 23 + .../github.com/montanaflynn/stats/entropy.go | 31 + .../github.com/montanaflynn/stats/errors.go | 35 + .../stats/geometric_distribution.go | 42 + vendor/github.com/montanaflynn/stats/go.mod | 3 + .../github.com/montanaflynn/stats/legacy.go | 49 + vendor/github.com/montanaflynn/stats/load.go | 199 + vendor/github.com/montanaflynn/stats/max.go | 26 + vendor/github.com/montanaflynn/stats/mean.go | 60 + .../github.com/montanaflynn/stats/median.go | 25 + vendor/github.com/montanaflynn/stats/min.go | 26 + vendor/github.com/montanaflynn/stats/mode.go | 47 + vendor/github.com/montanaflynn/stats/norm.go | 254 + .../github.com/montanaflynn/stats/outlier.go | 44 + .../montanaflynn/stats/percentile.go | 86 + .../github.com/montanaflynn/stats/quartile.go | 74 + .../github.com/montanaflynn/stats/ranksum.go | 183 + .../montanaflynn/stats/regression.go | 113 + vendor/github.com/montanaflynn/stats/round.go | 38 + .../github.com/montanaflynn/stats/sample.go | 76 + .../github.com/montanaflynn/stats/sigmoid.go | 18 + .../github.com/montanaflynn/stats/softmax.go | 25 + vendor/github.com/montanaflynn/stats/sum.go | 18 + vendor/github.com/montanaflynn/stats/util.go | 43 + .../github.com/montanaflynn/stats/variance.go | 105 + vendor/github.com/xdg-go/scram/CHANGELOG.md | 12 + vendor/github.com/xdg-go/scram/doc.go | 6 +- vendor/github.com/xdg-go/scram/go.mod | 2 +- vendor/github.com/xdg-go/scram/go.sum | 30 +- vendor/github.com/xdg-go/scram/scram.go | 5 + .../github.com/xdg-go/stringprep/CHANGELOG.md | 14 + vendor/github.com/xdg-go/stringprep/go.mod | 2 +- vendor/github.com/xdg-go/stringprep/go.sum | 26 +- vendor/github.com/youmark/pkcs8/.travis.yml | 9 - vendor/github.com/youmark/pkcs8/README.md | 5 +- vendor/github.com/youmark/pkcs8/cipher.go | 60 + .../github.com/youmark/pkcs8/cipher_3des.go | 24 + vendor/github.com/youmark/pkcs8/cipher_aes.go | 84 + vendor/github.com/youmark/pkcs8/go.mod | 5 + vendor/github.com/youmark/pkcs8/go.sum | 2 + vendor/github.com/youmark/pkcs8/kdf_pbkdf2.go | 91 + vendor/github.com/youmark/pkcs8/kdf_scrypt.go | 62 + vendor/github.com/youmark/pkcs8/pkcs8.go | 420 +- .../go.mongodb.org/mongo-driver/bson/bson.go | 6 +- .../bson/bsoncodec/array_codec.go | 9 +- .../mongo-driver/bson/bsoncodec/bsoncodec.go | 168 +- .../bson/bsoncodec/byte_slice_codec.go | 35 +- .../bson/bsoncodec/codec_cache.go | 166 + .../bson/bsoncodec/default_value_decoders.go | 190 +- .../bson/bsoncodec/default_value_encoders.go | 190 +- .../mongo-driver/bson/bsoncodec/doc.go | 79 +- .../bson/bsoncodec/empty_interface_codec.go | 47 +- .../mongo-driver/bson/bsoncodec/map_codec.go | 81 +- .../bson/bsoncodec/pointer_codec.go | 73 +- .../mongo-driver/bson/bsoncodec/registry.go | 488 +- .../bson/bsoncodec/slice_codec.go | 51 +- .../bson/bsoncodec/string_codec.go | 29 +- .../bson/bsoncodec/struct_codec.go | 258 +- .../bson/bsoncodec/struct_tag_parser.go | 53 +- .../mongo-driver/bson/bsoncodec/time_codec.go | 30 +- .../mongo-driver/bson/bsoncodec/types.go | 1 + .../mongo-driver/bson/bsoncodec/uint_codec.go | 35 +- .../bsonoptions/byte_slice_codec_options.go | 11 + .../mongo-driver/bson/bsonoptions/doc.go | 8 + .../empty_interface_codec_options.go | 11 + .../bson/bsonoptions/map_codec_options.go | 15 + .../bson/bsonoptions/slice_codec_options.go | 11 + .../bson/bsonoptions/string_codec_options.go | 11 + .../bson/bsonoptions/struct_codec_options.go | 20 + .../bson/bsonoptions/time_codec_options.go | 11 + .../bson/bsonoptions/uint_codec_options.go | 11 + .../mongo-driver/bson/bsonrw/copier.go | 56 +- .../bson/bsonrw/extjson_parser.go | 4 +- .../bson/bsonrw/extjson_reader.go | 13 +- .../bson/bsonrw/extjson_wrappers.go | 4 +- .../bson/bsonrw/extjson_writer.go | 43 +- .../mongo-driver/bson/bsonrw/json_scanner.go | 53 +- .../mongo-driver/bson/bsonrw/reader.go | 2 + .../mongo-driver/bson/bsonrw/value_reader.go | 41 +- .../mongo-driver/bson/bsonrw/value_writer.go | 84 +- .../mongo-driver/bson/bsonrw/writer.go | 9 + .../mongo-driver/bson/bsontype/bsontype.go | 21 +- .../mongo-driver/bson/decoder.go | 94 +- .../go.mongodb.org/mongo-driver/bson/doc.go | 198 +- .../mongo-driver/bson/encoder.go | 110 +- .../mongo-driver/bson/marshal.go | 225 +- .../mongo-driver/bson/primitive/decimal.go | 30 +- .../mongo-driver/bson/primitive/objectid.go | 16 +- .../mongo-driver/bson/primitive/primitive.go | 52 +- .../mongo-driver/bson/primitive_codecs.go | 40 +- .../go.mongodb.org/mongo-driver/bson/raw.go | 34 +- .../mongo-driver/bson/raw_element.go | 7 +- .../mongo-driver/bson/raw_value.go | 27 +- .../mongo-driver/bson/registry.go | 31 +- .../go.mongodb.org/mongo-driver/bson/types.go | 16 +- .../mongo-driver/bson/unmarshal.go | 95 +- .../go.mongodb.org/mongo-driver/event/doc.go | 58 +- .../mongo-driver/event/monitoring.go | 29 +- .../mongo-driver/internal/aws/awserr/error.go | 60 + .../mongo-driver/internal/aws/awserr/types.go | 144 + .../aws/credentials/chain_provider.go | 72 + .../internal/aws/credentials/credentials.go | 197 + .../internal/aws/signer/v4/header_rules.go | 51 + .../aws/signer/v4}/request.go | 4 +- .../aws/signer/v4/uri_path.go} | 23 +- .../aws/signer/v4/v4.go} | 119 +- .../mongo-driver/internal/aws/types.go | 153 + .../{string_util.go => bsonutil/bsonutil.go} | 33 +- .../internal/cancellation_listener.go | 47 - .../internal/codecutil/encoding.go | 65 + .../credproviders/assume_role_provider.go | 148 + .../internal/credproviders/ec2_provider.go | 183 + .../internal/credproviders/ecs_provider.go | 112 + .../internal/credproviders/env_provider.go | 69 + .../internal/credproviders/imds_provider.go | 103 + .../internal/credproviders/static_provider.go | 59 + .../mongo-driver/internal/csfle/csfle.go | 40 + .../mongo-driver/internal/csot/csot.go | 60 + .../mongo-driver/internal/driverutil/hello.go | 128 + .../internal/driverutil/operation.go | 31 + .../mongo-driver/internal/error.go | 119 - .../{const.go => handshake/handshake.go} | 8 +- .../internal/httputil/httputil.go | 30 + .../mongo-driver/internal/logger/component.go | 314 + .../mongo-driver/internal/logger/context.go | 48 + .../mongo-driver/internal/logger/io_sink.go | 63 + .../mongo-driver/internal/logger/level.go | 74 + .../mongo-driver/internal/logger/logger.go | 275 + .../mongo-driver/internal/ptrutil/int64.go | 39 + .../mongo-driver/internal/rand/bits.go | 38 + .../mongo-driver/internal/rand/exp.go | 223 + .../mongo-driver/internal/rand/normal.go | 158 + .../mongo-driver/internal/rand/rand.go | 374 + .../mongo-driver/internal/rand/rng.go | 93 + .../internal/randutil/randutil.go | 64 +- .../internal/uri_validation_errors.go | 22 - .../{x/mongo/driver => internal}/uuid/uuid.go | 46 +- .../mongo-driver/mongo/address/addr.go | 1 + .../{internal => mongo}/background_context.go | 6 +- .../mongo-driver/mongo/batch_cursor.go | 23 + .../mongo-driver/mongo/bulk_write.go | 170 +- .../mongo-driver/mongo/bulk_write_models.go | 6 +- .../mongo-driver/mongo/change_stream.go | 164 +- .../mongo/change_stream_deployment.go | 5 +- .../mongo-driver/mongo/client.go | 712 +- .../mongo-driver/mongo/client_encryption.go | 337 +- .../mongo-driver/mongo/collection.go | 615 +- .../mongo-driver/mongo/cursor.go | 140 +- .../mongo-driver/mongo/database.go | 323 +- .../mongo/description/description.go | 1 + .../mongo-driver/mongo/description/server.go | 117 +- .../mongo/description/server_selector.go | 208 +- .../mongo/description/topology.go | 12 +- .../go.mongodb.org/mongo-driver/mongo/doc.go | 176 +- .../mongo-driver/mongo/errors.go | 100 +- .../mongo-driver/mongo/index_view.go | 99 +- .../mongo-driver/mongo/mongo.go | 276 +- .../mongo-driver/mongo/mongocryptd.go | 32 +- .../mongo-driver/mongo/mongointernal.go | 41 + .../mongo/options/aggregateoptions.go | 15 +- .../mongo/options/autoencryptionoptions.go | 66 +- .../mongo/options/bulkwriteoptions.go | 27 +- .../mongo/options/changestreamoptions.go | 49 +- .../mongo/options/clientencryptionoptions.go | 26 +- .../mongo/options/clientoptions.go | 381 +- .../mongo/options/collectionoptions.go | 32 +- .../mongo/options/countoptions.go | 27 + .../mongo/options/createcollectionoptions.go | 103 +- .../mongo/options/datakeyoptions.go | 58 +- .../mongo-driver/mongo/options/dboptions.go | 32 +- .../mongo/options/deleteoptions.go | 16 + .../mongo/options/distinctoptions.go | 24 + .../mongo-driver/mongo/options/doc.go | 8 + .../mongo/options/encryptoptions.go | 99 +- .../mongo/options/estimatedcountoptions.go | 25 +- .../mongo-driver/mongo/options/findoptions.go | 188 +- .../mongo/options/gridfsoptions.go | 20 + .../mongo/options/indexoptions.go | 40 +- .../mongo/options/insertoptions.go | 54 +- .../mongo/options/listcollectionsoptions.go | 3 + .../mongo/options/listdatabasesoptions.go | 5 +- .../mongo/options/loggeroptions.go | 115 + .../mongo/options/mongooptions.go | 36 +- .../mongo/options/replaceoptions.go | 27 +- .../mongo/options/rewrapdatakeyoptions.go | 55 + .../mongo/options/runcmdoptions.go | 7 +- .../mongo/options/searchindexoptions.go | 48 + .../mongo/options/serverapioptions.go | 3 +- .../mongo/options/sessionoptions.go | 14 +- .../mongo/options/transactionoptions.go | 11 + .../mongo/options/updateoptions.go | 27 +- .../mongo/readconcern/readconcern.go | 75 +- .../mongo-driver/mongo/readpref/options.go | 26 +- .../mongo-driver/mongo/readpref/readpref.go | 5 +- .../mongo-driver/mongo/results.go | 24 +- .../mongo-driver/mongo/search_index_view.go | 258 + .../mongo-driver/mongo/session.go | 127 +- .../mongo-driver/mongo/single_result.go | 49 +- .../mongo/writeconcern/writeconcern.go | 325 +- vendor/go.mongodb.org/mongo-driver/tag/tag.go | 16 +- .../mongo-driver/version/version.go | 8 +- .../mongo-driver/x/bsonx/array.go | 97 - .../mongo-driver/x/bsonx/bsoncore/array.go | 10 +- .../mongo-driver/x/bsonx/bsoncore/bsoncore.go | 80 +- .../mongo-driver/x/bsonx/bsoncore/doc.go | 34 + .../mongo-driver/x/bsonx/bsoncore/document.go | 40 +- .../x/bsonx/bsoncore/document_sequence.go | 10 +- .../mongo-driver/x/bsonx/bsoncore/element.go | 10 +- .../mongo-driver/x/bsonx/bsoncore/value.go | 38 +- .../mongo-driver/x/bsonx/constructor.go | 166 - .../mongo-driver/x/bsonx/document.go | 305 - .../mongo-driver/x/bsonx/element.go | 51 - .../mongo-driver/x/bsonx/mdocument.go | 231 - .../mongo-driver/x/bsonx/primitive_codecs.go | 637 -- .../x/bsonx/reflectionfree_d_codec.go | 1025 --- .../mongo-driver/x/bsonx/registry.go | 22 - .../mongo-driver/x/bsonx/value.go | 866 -- .../mongo-driver/x/mongo/driver/DESIGN.md | 23 - .../mongo-driver/x/mongo/driver/auth/auth.go | 42 +- .../x/mongo/driver/auth/aws_conv.go | 182 +- .../mongo-driver/x/mongo/driver/auth/cred.go | 14 +- .../x/mongo/driver/auth/creds/awscreds.go | 58 + .../x/mongo/driver/auth/creds/azurecreds.go | 40 + .../x/mongo/driver/auth/creds/doc.go | 14 + .../x/mongo/driver/auth/creds/gcpcreds.go | 74 + .../x/mongo/driver/auth/default.go | 37 +- .../mongo-driver/x/mongo/driver/auth/doc.go | 21 +- .../x/mongo/driver/auth/gssapi.go | 13 +- .../x/mongo/driver/auth/gssapi_not_enabled.go | 4 +- .../mongo/driver/auth/gssapi_not_supported.go | 3 +- .../driver/auth/internal/awsv4/credentials.go | 63 - .../x/mongo/driver/auth/internal/awsv4/doc.go | 15 - .../mongo/driver/auth/internal/awsv4/rules.go | 98 - .../mongo/driver/auth/internal/gssapi/gss.go | 5 +- .../driver/auth/internal/gssapi/gss_wrapper.c | 26 +- .../driver/auth/internal/gssapi/gss_wrapper.h | 14 +- .../mongo/driver/auth/internal/gssapi/sspi.go | 7 +- .../auth/internal/gssapi/sspi_wrapper.c | 12 +- .../auth/internal/gssapi/sspi_wrapper.h | 10 +- .../x/mongo/driver/auth/mongodbaws.go | 46 +- .../x/mongo/driver/auth/mongodbcr.go | 14 +- .../mongo-driver/x/mongo/driver/auth/oidc.go | 556 ++ .../mongo-driver/x/mongo/driver/auth/plain.go | 29 +- .../mongo-driver/x/mongo/driver/auth/sasl.go | 9 +- .../mongo-driver/x/mongo/driver/auth/scram.go | 28 +- .../mongo-driver/x/mongo/driver/auth/x509.go | 26 +- .../x/mongo/driver/batch_cursor.go | 158 +- .../mongo-driver/x/mongo/driver/batches.go | 8 +- .../x/mongo/driver/compression.go | 167 +- .../x/mongo/driver/connstring/connstring.go | 1340 +-- .../mongo-driver/x/mongo/driver/crypt.go | 110 +- .../mongo-driver/x/mongo/driver/dns/dns.go | 15 +- .../mongo-driver/x/mongo/driver/driver.go | 145 +- .../mongo-driver/x/mongo/driver/errors.go | 75 +- .../mongo-driver/x/mongo/driver/legacy.go | 7 + .../driver/list_collections_batch_cursor.go | 129 - .../x/mongo/driver/mongocrypt/binary.go | 17 +- .../x/mongo/driver/mongocrypt/mongocrypt.go | 295 +- .../driver/mongocrypt/mongocrypt_context.go | 11 + .../mongocrypt_context_not_enabled.go | 7 +- .../mongocrypt_kms_context_not_enabled.go | 2 +- .../mongocrypt/mongocrypt_not_enabled.go | 54 +- .../x/mongo/driver/mongocrypt/options/doc.go | 14 + .../options/mongocrypt_context_options.go | 99 +- .../mongocrypt/options/mongocrypt_options.go | 42 +- .../x/mongo/driver/mongocrypt/state.go | 18 +- .../x/mongo/driver/ocsp/config.go | 12 +- .../mongo-driver/x/mongo/driver/ocsp/ocsp.go | 117 +- .../x/mongo/driver/ocsp/options.go | 3 + .../mongo-driver/x/mongo/driver/operation.go | 953 +- .../driver/operation/abort_transaction.go | 18 +- .../x/mongo/driver/operation/aggregate.go | 58 +- .../x/mongo/driver/operation/command.go | 56 +- .../driver/operation/commit_transaction.go | 31 +- .../x/mongo/driver/operation/count.go | 94 +- .../x/mongo/driver/operation/create.go | 104 +- .../{createIndexes.go => create_indexes.go} | 67 +- .../driver/operation/create_search_indexes.go | 251 + .../x/mongo/driver/operation/delete.go | 99 +- .../x/mongo/driver/operation/distinct.go | 61 +- .../x/mongo/driver/operation/doc.go | 14 + .../mongo/driver/operation/drop_collection.go | 53 +- .../x/mongo/driver/operation/drop_database.go | 36 +- .../x/mongo/driver/operation/drop_indexes.go | 90 +- .../driver/operation/drop_search_index.go | 225 + .../x/mongo/driver/operation/end_sessions.go | 36 +- .../x/mongo/driver/operation/find.go | 70 +- .../mongo/driver/operation/find_and_modify.go | 75 +- .../x/mongo/driver/operation/hello.go | 464 +- .../x/mongo/driver/operation/insert.go | 80 +- .../x/mongo/driver/operation/listDatabases.go | 33 +- .../driver/operation/list_collections.go | 46 +- .../x/mongo/driver/operation/list_indexes.go | 66 +- .../x/mongo/driver/operation/update.go | 83 +- .../driver/operation/update_search_index.go | 238 + .../x/mongo/driver/operation_exhaust.go | 5 +- .../x/mongo/driver/operation_legacy.go | 719 -- .../x/mongo/driver/session/client_session.go | 91 +- .../x/mongo/driver/session/doc.go | 14 + .../x/mongo/driver/session/server_session.go | 6 +- .../x/mongo/driver/session/session_pool.go | 20 +- .../x/mongo/driver/topology/DESIGN.md | 7 +- .../x/mongo/driver/topology/connection.go | 339 +- .../driver/topology/connection_legacy.go | 6 + .../driver/topology/connection_options.go | 19 + .../x/mongo/driver/topology/errors.go | 62 +- .../x/mongo/driver/topology/fsm.go | 184 +- .../driver/topology/hanging_tls_conn_1_16.go | 37 - .../driver/topology/hanging_tls_conn_1_17.go | 44 - .../x/mongo/driver/topology/pool.go | 543 +- .../topology/pool_generation_counter.go | 14 +- .../x/mongo/driver/topology/rtt_monitor.go | 206 +- .../x/mongo/driver/topology/server.go | 323 +- .../x/mongo/driver/topology/server_options.go | 54 +- .../x/mongo/driver/topology/topology.go | 394 +- .../mongo/driver/topology/topology_options.go | 710 +- .../x/mongo/driver/wiremessage/wiremessage.go | 121 +- vendor/golang.org/x/crypto/AUTHORS | 3 - vendor/golang.org/x/crypto/CONTRIBUTORS | 3 - vendor/golang.org/x/crypto/LICENSE | 4 +- vendor/golang.org/x/crypto/ocsp/ocsp.go | 20 +- vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go | 4 +- vendor/golang.org/x/crypto/scrypt/scrypt.go | 212 + vendor/golang.org/x/sync/AUTHORS | 3 - vendor/golang.org/x/sync/CONTRIBUTORS | 3 - vendor/golang.org/x/sync/LICENSE | 4 +- vendor/golang.org/x/sync/errgroup/errgroup.go | 85 +- vendor/golang.org/x/sync/errgroup/go120.go | 13 + .../golang.org/x/sync/errgroup/pre_go120.go | 14 + .../x/sync/singleflight/singleflight.go | 214 + vendor/golang.org/x/text/AUTHORS | 3 - vendor/golang.org/x/text/CONTRIBUTORS | 3 - vendor/golang.org/x/text/LICENSE | 4 +- .../x/text/unicode/norm/forminfo.go | 11 +- .../x/text/unicode/norm/normalize.go | 11 +- .../x/text/unicode/norm/tables10.0.0.go | 2 +- .../x/text/unicode/norm/tables11.0.0.go | 2 +- .../x/text/unicode/norm/tables12.0.0.go | 2 +- .../x/text/unicode/norm/tables13.0.0.go | 6 +- .../x/text/unicode/norm/tables15.0.0.go | 7907 +++++++++++++++++ .../x/text/unicode/norm/tables9.0.0.go | 2 +- vendor/golang.org/x/text/unicode/norm/trie.go | 2 +- vendor/modules.txt | 46 +- 433 files changed, 43121 insertions(+), 13642 deletions(-) delete mode 100644 vendor/github.com/go-stack/stack/.travis.yml delete mode 100644 vendor/github.com/go-stack/stack/README.md delete mode 100644 vendor/github.com/go-stack/stack/go.mod delete mode 100644 vendor/github.com/go-stack/stack/stack.go create mode 100644 vendor/github.com/golang/snappy/decode_arm64.s rename vendor/github.com/golang/snappy/{decode_amd64.go => decode_asm.go} (93%) create mode 100644 vendor/github.com/golang/snappy/encode_arm64.s rename vendor/github.com/golang/snappy/{encode_amd64.go => encode_asm.go} (97%) create mode 100644 vendor/github.com/klauspost/compress/SECURITY.md create mode 100644 vendor/github.com/klauspost/compress/huff0/decompress_amd64.go create mode 100644 vendor/github.com/klauspost/compress/huff0/decompress_amd64.s create mode 100644 vendor/github.com/klauspost/compress/huff0/decompress_generic.go create mode 100644 vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo.go create mode 100644 vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.go create mode 100644 vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.s create mode 100644 vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.go create mode 100644 vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.s create mode 100644 vendor/github.com/klauspost/compress/zstd/fse_decoder_generic.go delete mode 100644 vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.go create mode 100644 vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_arm64.s create mode 100644 vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_asm.go create mode 100644 vendor/github.com/klauspost/compress/zstd/matchlen_amd64.go create mode 100644 vendor/github.com/klauspost/compress/zstd/matchlen_amd64.s create mode 100644 vendor/github.com/klauspost/compress/zstd/matchlen_generic.go create mode 100644 vendor/github.com/klauspost/compress/zstd/seqdec_amd64.go create mode 100644 vendor/github.com/klauspost/compress/zstd/seqdec_amd64.s create mode 100644 vendor/github.com/klauspost/compress/zstd/seqdec_generic.go create mode 100644 vendor/github.com/montanaflynn/stats/.gitignore create mode 100644 vendor/github.com/montanaflynn/stats/CHANGELOG.md create mode 100644 vendor/github.com/montanaflynn/stats/DOCUMENTATION.md rename vendor/github.com/{go-stack/stack/LICENSE.md => montanaflynn/stats/LICENSE} (94%) create mode 100644 vendor/github.com/montanaflynn/stats/Makefile create mode 100644 vendor/github.com/montanaflynn/stats/README.md create mode 100644 vendor/github.com/montanaflynn/stats/correlation.go create mode 100644 vendor/github.com/montanaflynn/stats/cumulative_sum.go create mode 100644 vendor/github.com/montanaflynn/stats/data.go create mode 100644 vendor/github.com/montanaflynn/stats/describe.go create mode 100644 vendor/github.com/montanaflynn/stats/deviation.go create mode 100644 vendor/github.com/montanaflynn/stats/distances.go create mode 100644 vendor/github.com/montanaflynn/stats/doc.go create mode 100644 vendor/github.com/montanaflynn/stats/entropy.go create mode 100644 vendor/github.com/montanaflynn/stats/errors.go create mode 100644 vendor/github.com/montanaflynn/stats/geometric_distribution.go create mode 100644 vendor/github.com/montanaflynn/stats/go.mod create mode 100644 vendor/github.com/montanaflynn/stats/legacy.go create mode 100644 vendor/github.com/montanaflynn/stats/load.go create mode 100644 vendor/github.com/montanaflynn/stats/max.go create mode 100644 vendor/github.com/montanaflynn/stats/mean.go create mode 100644 vendor/github.com/montanaflynn/stats/median.go create mode 100644 vendor/github.com/montanaflynn/stats/min.go create mode 100644 vendor/github.com/montanaflynn/stats/mode.go create mode 100644 vendor/github.com/montanaflynn/stats/norm.go create mode 100644 vendor/github.com/montanaflynn/stats/outlier.go create mode 100644 vendor/github.com/montanaflynn/stats/percentile.go create mode 100644 vendor/github.com/montanaflynn/stats/quartile.go create mode 100644 vendor/github.com/montanaflynn/stats/ranksum.go create mode 100644 vendor/github.com/montanaflynn/stats/regression.go create mode 100644 vendor/github.com/montanaflynn/stats/round.go create mode 100644 vendor/github.com/montanaflynn/stats/sample.go create mode 100644 vendor/github.com/montanaflynn/stats/sigmoid.go create mode 100644 vendor/github.com/montanaflynn/stats/softmax.go create mode 100644 vendor/github.com/montanaflynn/stats/sum.go create mode 100644 vendor/github.com/montanaflynn/stats/util.go create mode 100644 vendor/github.com/montanaflynn/stats/variance.go delete mode 100644 vendor/github.com/youmark/pkcs8/.travis.yml create mode 100644 vendor/github.com/youmark/pkcs8/cipher.go create mode 100644 vendor/github.com/youmark/pkcs8/cipher_3des.go create mode 100644 vendor/github.com/youmark/pkcs8/cipher_aes.go create mode 100644 vendor/github.com/youmark/pkcs8/go.mod create mode 100644 vendor/github.com/youmark/pkcs8/go.sum create mode 100644 vendor/github.com/youmark/pkcs8/kdf_pbkdf2.go create mode 100644 vendor/github.com/youmark/pkcs8/kdf_scrypt.go create mode 100644 vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/codec_cache.go create mode 100644 vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/doc.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/error.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/types.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/chain_provider.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/credentials.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/header_rules.go rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver/auth/internal/awsv4 => internal/aws/signer/v4}/request.go (96%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver/auth/internal/awsv4/rest.go => internal/aws/signer/v4/uri_path.go} (72%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver/auth/internal/awsv4/signer.go => internal/aws/signer/v4/v4.go} (80%) create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/aws/types.go rename vendor/go.mongodb.org/mongo-driver/internal/{string_util.go => bsonutil/bsonutil.go} (64%) delete mode 100644 vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/codecutil/encoding.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/credproviders/assume_role_provider.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/credproviders/ec2_provider.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/credproviders/ecs_provider.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/credproviders/env_provider.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/credproviders/imds_provider.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/credproviders/static_provider.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/csfle/csfle.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/csot/csot.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/driverutil/hello.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/driverutil/operation.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/internal/error.go rename vendor/go.mongodb.org/mongo-driver/internal/{const.go => handshake/handshake.go} (64%) create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/httputil/httputil.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/logger/component.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/logger/context.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/logger/io_sink.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/logger/level.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/logger/logger.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/ptrutil/int64.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/rand/bits.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/rand/exp.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/rand/normal.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/rand/rand.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/rand/rng.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/internal/uri_validation_errors.go rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => internal}/uuid/uuid.go (54%) rename vendor/go.mongodb.org/mongo-driver/{internal => mongo}/background_context.go (87%) create mode 100644 vendor/go.mongodb.org/mongo-driver/mongo/mongointernal.go create mode 100644 vendor/go.mongodb.org/mongo-driver/mongo/options/doc.go create mode 100644 vendor/go.mongodb.org/mongo-driver/mongo/options/loggeroptions.go create mode 100644 vendor/go.mongodb.org/mongo-driver/mongo/options/rewrapdatakeyoptions.go create mode 100644 vendor/go.mongodb.org/mongo-driver/mongo/options/searchindexoptions.go create mode 100644 vendor/go.mongodb.org/mongo-driver/mongo/search_index_view.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/array.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/doc.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/constructor.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/document.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/element.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/mdocument.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/primitive_codecs.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/reflectionfree_d_codec.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/registry.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/value.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/DESIGN.md create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds/awscreds.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds/azurecreds.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds/doc.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds/gcpcreds.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/credentials.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/doc.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/rules.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/oidc.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/list_collections_batch_cursor.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/doc.go rename vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/{createIndexes.go => create_indexes.go} (83%) create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/create_search_indexes.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/doc.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_search_index.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/update_search_index.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_legacy.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/doc.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/hanging_tls_conn_1_16.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/hanging_tls_conn_1_17.go delete mode 100644 vendor/golang.org/x/crypto/AUTHORS delete mode 100644 vendor/golang.org/x/crypto/CONTRIBUTORS create mode 100644 vendor/golang.org/x/crypto/scrypt/scrypt.go delete mode 100644 vendor/golang.org/x/sync/AUTHORS delete mode 100644 vendor/golang.org/x/sync/CONTRIBUTORS create mode 100644 vendor/golang.org/x/sync/errgroup/go120.go create mode 100644 vendor/golang.org/x/sync/errgroup/pre_go120.go create mode 100644 vendor/golang.org/x/sync/singleflight/singleflight.go delete mode 100644 vendor/golang.org/x/text/AUTHORS delete mode 100644 vendor/golang.org/x/text/CONTRIBUTORS create mode 100644 vendor/golang.org/x/text/unicode/norm/tables15.0.0.go diff --git a/go.mod b/go.mod index a1ad3ec..732475d 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,9 @@ go 1.13 require ( github.com/ClickHouse/clickhouse-go v1.5.4 + github.com/kr/pretty v0.1.0 // indirect github.com/pkg/errors v0.9.1 - go.mongodb.org/mongo-driver v1.9.0 + github.com/stretchr/testify v1.6.1 // indirect + go.mongodb.org/mongo-driver v1.17.7 + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index 975a439..0e42d8e 100644 --- a/go.sum +++ b/go.sum @@ -1,60 +1,121 @@ github.com/ClickHouse/clickhouse-go v1.5.4 h1:cKjXeYLNWVJIx2J1K6H2CqyRmfwVJVY1OV1coaaFcI0= github.com/ClickHouse/clickhouse-go v1.5.4/go.mod h1:EaI/sW7Azgz9UATzd5ZdZHRUhHgv5+JMS9NSr2smCJI= +github.com/bkaradzic/go-lz4 v1.0.0 h1:RXc4wYsyz985CkXXeX04y4VnZFGG8Rd43pRaHsOXAKk= github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= -github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= -github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= +github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.0.2 h1:akYIkZ28e6A96dkWNJQu3nmCzH3YfwMPQExUYDaRv7w= -github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= -github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc= -github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= -github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= -github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= -go.mongodb.org/mongo-driver v1.9.0 h1:f3aLGJvQmBl8d9S40IL+jEyBC6hfLPbJjv9t5hEM9ck= -go.mongodb.org/mongo-driver v1.9.0/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver v1.17.7 h1:a9w+U3Vt67eYzcfq3k/OAv284/uUUkL0uP75VE5rCOU= +go.mongodb.org/mongo-driver v1.17.7/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f h1:aZp0e2vLN4MToVqnjNEYEtrEA8RH8U8FN1CU7JgqsPU= -golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/go-stack/stack/.travis.yml b/vendor/github.com/go-stack/stack/.travis.yml deleted file mode 100644 index 5c5a2b5..0000000 --- a/vendor/github.com/go-stack/stack/.travis.yml +++ /dev/null @@ -1,15 +0,0 @@ -language: go -sudo: false -go: - - 1.7.x - - 1.8.x - - 1.9.x - - 1.10.x - - 1.11.x - - tip - -before_install: - - go get github.com/mattn/goveralls - -script: - - goveralls -service=travis-ci diff --git a/vendor/github.com/go-stack/stack/README.md b/vendor/github.com/go-stack/stack/README.md deleted file mode 100644 index f11cccc..0000000 --- a/vendor/github.com/go-stack/stack/README.md +++ /dev/null @@ -1,38 +0,0 @@ -[![GoDoc](https://godoc.org/github.com/go-stack/stack?status.svg)](https://godoc.org/github.com/go-stack/stack) -[![Go Report Card](https://goreportcard.com/badge/go-stack/stack)](https://goreportcard.com/report/go-stack/stack) -[![TravisCI](https://travis-ci.org/go-stack/stack.svg?branch=master)](https://travis-ci.org/go-stack/stack) -[![Coverage Status](https://coveralls.io/repos/github/go-stack/stack/badge.svg?branch=master)](https://coveralls.io/github/go-stack/stack?branch=master) - -# stack - -Package stack implements utilities to capture, manipulate, and format call -stacks. It provides a simpler API than package runtime. - -The implementation takes care of the minutia and special cases of interpreting -the program counter (pc) values returned by runtime.Callers. - -## Versioning - -Package stack publishes releases via [semver](http://semver.org/) compatible Git -tags prefixed with a single 'v'. The master branch always contains the latest -release. The develop branch contains unreleased commits. - -## Formatting - -Package stack's types implement fmt.Formatter, which provides a simple and -flexible way to declaratively configure formatting when used with logging or -error tracking packages. - -```go -func DoTheThing() { - c := stack.Caller(0) - log.Print(c) // "source.go:10" - log.Printf("%+v", c) // "pkg/path/source.go:10" - log.Printf("%n", c) // "DoTheThing" - - s := stack.Trace().TrimRuntime() - log.Print(s) // "[source.go:15 caller.go:42 main.go:14]" -} -``` - -See the docs for all of the supported formatting options. diff --git a/vendor/github.com/go-stack/stack/go.mod b/vendor/github.com/go-stack/stack/go.mod deleted file mode 100644 index 96a53a1..0000000 --- a/vendor/github.com/go-stack/stack/go.mod +++ /dev/null @@ -1 +0,0 @@ -module github.com/go-stack/stack diff --git a/vendor/github.com/go-stack/stack/stack.go b/vendor/github.com/go-stack/stack/stack.go deleted file mode 100644 index ac3b93b..0000000 --- a/vendor/github.com/go-stack/stack/stack.go +++ /dev/null @@ -1,400 +0,0 @@ -// +build go1.7 - -// Package stack implements utilities to capture, manipulate, and format call -// stacks. It provides a simpler API than package runtime. -// -// The implementation takes care of the minutia and special cases of -// interpreting the program counter (pc) values returned by runtime.Callers. -// -// Package stack's types implement fmt.Formatter, which provides a simple and -// flexible way to declaratively configure formatting when used with logging -// or error tracking packages. -package stack - -import ( - "bytes" - "errors" - "fmt" - "io" - "runtime" - "strconv" - "strings" -) - -// Call records a single function invocation from a goroutine stack. -type Call struct { - frame runtime.Frame -} - -// Caller returns a Call from the stack of the current goroutine. The argument -// skip is the number of stack frames to ascend, with 0 identifying the -// calling function. -func Caller(skip int) Call { - // As of Go 1.9 we need room for up to three PC entries. - // - // 0. An entry for the stack frame prior to the target to check for - // special handling needed if that prior entry is runtime.sigpanic. - // 1. A possible second entry to hold metadata about skipped inlined - // functions. If inline functions were not skipped the target frame - // PC will be here. - // 2. A third entry for the target frame PC when the second entry - // is used for skipped inline functions. - var pcs [3]uintptr - n := runtime.Callers(skip+1, pcs[:]) - frames := runtime.CallersFrames(pcs[:n]) - frame, _ := frames.Next() - frame, _ = frames.Next() - - return Call{ - frame: frame, - } -} - -// String implements fmt.Stinger. It is equivalent to fmt.Sprintf("%v", c). -func (c Call) String() string { - return fmt.Sprint(c) -} - -// MarshalText implements encoding.TextMarshaler. It formats the Call the same -// as fmt.Sprintf("%v", c). -func (c Call) MarshalText() ([]byte, error) { - if c.frame == (runtime.Frame{}) { - return nil, ErrNoFunc - } - - buf := bytes.Buffer{} - fmt.Fprint(&buf, c) - return buf.Bytes(), nil -} - -// ErrNoFunc means that the Call has a nil *runtime.Func. The most likely -// cause is a Call with the zero value. -var ErrNoFunc = errors.New("no call stack information") - -// Format implements fmt.Formatter with support for the following verbs. -// -// %s source file -// %d line number -// %n function name -// %k last segment of the package path -// %v equivalent to %s:%d -// -// It accepts the '+' and '#' flags for most of the verbs as follows. -// -// %+s path of source file relative to the compile time GOPATH, -// or the module path joined to the path of source file relative -// to module root -// %#s full path of source file -// %+n import path qualified function name -// %+k full package path -// %+v equivalent to %+s:%d -// %#v equivalent to %#s:%d -func (c Call) Format(s fmt.State, verb rune) { - if c.frame == (runtime.Frame{}) { - fmt.Fprintf(s, "%%!%c(NOFUNC)", verb) - return - } - - switch verb { - case 's', 'v': - file := c.frame.File - switch { - case s.Flag('#'): - // done - case s.Flag('+'): - file = pkgFilePath(&c.frame) - default: - const sep = "/" - if i := strings.LastIndex(file, sep); i != -1 { - file = file[i+len(sep):] - } - } - io.WriteString(s, file) - if verb == 'v' { - buf := [7]byte{':'} - s.Write(strconv.AppendInt(buf[:1], int64(c.frame.Line), 10)) - } - - case 'd': - buf := [6]byte{} - s.Write(strconv.AppendInt(buf[:0], int64(c.frame.Line), 10)) - - case 'k': - name := c.frame.Function - const pathSep = "/" - start, end := 0, len(name) - if i := strings.LastIndex(name, pathSep); i != -1 { - start = i + len(pathSep) - } - const pkgSep = "." - if i := strings.Index(name[start:], pkgSep); i != -1 { - end = start + i - } - if s.Flag('+') { - start = 0 - } - io.WriteString(s, name[start:end]) - - case 'n': - name := c.frame.Function - if !s.Flag('+') { - const pathSep = "/" - if i := strings.LastIndex(name, pathSep); i != -1 { - name = name[i+len(pathSep):] - } - const pkgSep = "." - if i := strings.Index(name, pkgSep); i != -1 { - name = name[i+len(pkgSep):] - } - } - io.WriteString(s, name) - } -} - -// Frame returns the call frame infomation for the Call. -func (c Call) Frame() runtime.Frame { - return c.frame -} - -// PC returns the program counter for this call frame; multiple frames may -// have the same PC value. -// -// Deprecated: Use Call.Frame instead. -func (c Call) PC() uintptr { - return c.frame.PC -} - -// CallStack records a sequence of function invocations from a goroutine -// stack. -type CallStack []Call - -// String implements fmt.Stinger. It is equivalent to fmt.Sprintf("%v", cs). -func (cs CallStack) String() string { - return fmt.Sprint(cs) -} - -var ( - openBracketBytes = []byte("[") - closeBracketBytes = []byte("]") - spaceBytes = []byte(" ") -) - -// MarshalText implements encoding.TextMarshaler. It formats the CallStack the -// same as fmt.Sprintf("%v", cs). -func (cs CallStack) MarshalText() ([]byte, error) { - buf := bytes.Buffer{} - buf.Write(openBracketBytes) - for i, pc := range cs { - if i > 0 { - buf.Write(spaceBytes) - } - fmt.Fprint(&buf, pc) - } - buf.Write(closeBracketBytes) - return buf.Bytes(), nil -} - -// Format implements fmt.Formatter by printing the CallStack as square brackets -// ([, ]) surrounding a space separated list of Calls each formatted with the -// supplied verb and options. -func (cs CallStack) Format(s fmt.State, verb rune) { - s.Write(openBracketBytes) - for i, pc := range cs { - if i > 0 { - s.Write(spaceBytes) - } - pc.Format(s, verb) - } - s.Write(closeBracketBytes) -} - -// Trace returns a CallStack for the current goroutine with element 0 -// identifying the calling function. -func Trace() CallStack { - var pcs [512]uintptr - n := runtime.Callers(1, pcs[:]) - - frames := runtime.CallersFrames(pcs[:n]) - cs := make(CallStack, 0, n) - - // Skip extra frame retrieved just to make sure the runtime.sigpanic - // special case is handled. - frame, more := frames.Next() - - for more { - frame, more = frames.Next() - cs = append(cs, Call{frame: frame}) - } - - return cs -} - -// TrimBelow returns a slice of the CallStack with all entries below c -// removed. -func (cs CallStack) TrimBelow(c Call) CallStack { - for len(cs) > 0 && cs[0] != c { - cs = cs[1:] - } - return cs -} - -// TrimAbove returns a slice of the CallStack with all entries above c -// removed. -func (cs CallStack) TrimAbove(c Call) CallStack { - for len(cs) > 0 && cs[len(cs)-1] != c { - cs = cs[:len(cs)-1] - } - return cs -} - -// pkgIndex returns the index that results in file[index:] being the path of -// file relative to the compile time GOPATH, and file[:index] being the -// $GOPATH/src/ portion of file. funcName must be the name of a function in -// file as returned by runtime.Func.Name. -func pkgIndex(file, funcName string) int { - // As of Go 1.6.2 there is no direct way to know the compile time GOPATH - // at runtime, but we can infer the number of path segments in the GOPATH. - // We note that runtime.Func.Name() returns the function name qualified by - // the import path, which does not include the GOPATH. Thus we can trim - // segments from the beginning of the file path until the number of path - // separators remaining is one more than the number of path separators in - // the function name. For example, given: - // - // GOPATH /home/user - // file /home/user/src/pkg/sub/file.go - // fn.Name() pkg/sub.Type.Method - // - // We want to produce: - // - // file[:idx] == /home/user/src/ - // file[idx:] == pkg/sub/file.go - // - // From this we can easily see that fn.Name() has one less path separator - // than our desired result for file[idx:]. We count separators from the - // end of the file path until it finds two more than in the function name - // and then move one character forward to preserve the initial path - // segment without a leading separator. - const sep = "/" - i := len(file) - for n := strings.Count(funcName, sep) + 2; n > 0; n-- { - i = strings.LastIndex(file[:i], sep) - if i == -1 { - i = -len(sep) - break - } - } - // get back to 0 or trim the leading separator - return i + len(sep) -} - -// pkgFilePath returns the frame's filepath relative to the compile-time GOPATH, -// or its module path joined to its path relative to the module root. -// -// As of Go 1.11 there is no direct way to know the compile time GOPATH or -// module paths at runtime, but we can piece together the desired information -// from available information. We note that runtime.Frame.Function contains the -// function name qualified by the package path, which includes the module path -// but not the GOPATH. We can extract the package path from that and append the -// last segments of the file path to arrive at the desired package qualified -// file path. For example, given: -// -// GOPATH /home/user -// import path pkg/sub -// frame.File /home/user/src/pkg/sub/file.go -// frame.Function pkg/sub.Type.Method -// Desired return pkg/sub/file.go -// -// It appears that we simply need to trim ".Type.Method" from frame.Function and -// append "/" + path.Base(file). -// -// But there are other wrinkles. Although it is idiomatic to do so, the internal -// name of a package is not required to match the last segment of its import -// path. In addition, the introduction of modules in Go 1.11 allows working -// without a GOPATH. So we also must make these work right: -// -// GOPATH /home/user -// import path pkg/go-sub -// package name sub -// frame.File /home/user/src/pkg/go-sub/file.go -// frame.Function pkg/sub.Type.Method -// Desired return pkg/go-sub/file.go -// -// Module path pkg/v2 -// import path pkg/v2/go-sub -// package name sub -// frame.File /home/user/cloned-pkg/go-sub/file.go -// frame.Function pkg/v2/sub.Type.Method -// Desired return pkg/v2/go-sub/file.go -// -// We can handle all of these situations by using the package path extracted -// from frame.Function up to, but not including, the last segment as the prefix -// and the last two segments of frame.File as the suffix of the returned path. -// This preserves the existing behavior when working in a GOPATH without modules -// and a semantically equivalent behavior when used in module aware project. -func pkgFilePath(frame *runtime.Frame) string { - pre := pkgPrefix(frame.Function) - post := pathSuffix(frame.File) - if pre == "" { - return post - } - return pre + "/" + post -} - -// pkgPrefix returns the import path of the function's package with the final -// segment removed. -func pkgPrefix(funcName string) string { - const pathSep = "/" - end := strings.LastIndex(funcName, pathSep) - if end == -1 { - return "" - } - return funcName[:end] -} - -// pathSuffix returns the last two segments of path. -func pathSuffix(path string) string { - const pathSep = "/" - lastSep := strings.LastIndex(path, pathSep) - if lastSep == -1 { - return path - } - return path[strings.LastIndex(path[:lastSep], pathSep)+1:] -} - -var runtimePath string - -func init() { - var pcs [3]uintptr - runtime.Callers(0, pcs[:]) - frames := runtime.CallersFrames(pcs[:]) - frame, _ := frames.Next() - file := frame.File - - idx := pkgIndex(frame.File, frame.Function) - - runtimePath = file[:idx] - if runtime.GOOS == "windows" { - runtimePath = strings.ToLower(runtimePath) - } -} - -func inGoroot(c Call) bool { - file := c.frame.File - if len(file) == 0 || file[0] == '?' { - return true - } - if runtime.GOOS == "windows" { - file = strings.ToLower(file) - } - return strings.HasPrefix(file, runtimePath) || strings.HasSuffix(file, "/_testmain.go") -} - -// TrimRuntime returns a slice of the CallStack with the topmost entries from -// the go runtime removed. It considers any calls originating from unknown -// files, files under GOROOT, or _testmain.go as part of the runtime. -func (cs CallStack) TrimRuntime() CallStack { - for len(cs) > 0 && inGoroot(cs[len(cs)-1]) { - cs = cs[:len(cs)-1] - } - return cs -} diff --git a/vendor/github.com/golang/snappy/AUTHORS b/vendor/github.com/golang/snappy/AUTHORS index bcfa195..52ccb5a 100644 --- a/vendor/github.com/golang/snappy/AUTHORS +++ b/vendor/github.com/golang/snappy/AUTHORS @@ -8,8 +8,11 @@ # Please keep the list sorted. +Amazon.com, Inc Damian Gryski +Eric Buth Google Inc. Jan Mercl <0xjnml@gmail.com> +Klaus Post Rodolfo Carvalho Sebastien Binet diff --git a/vendor/github.com/golang/snappy/CONTRIBUTORS b/vendor/github.com/golang/snappy/CONTRIBUTORS index 931ae31..ea6524d 100644 --- a/vendor/github.com/golang/snappy/CONTRIBUTORS +++ b/vendor/github.com/golang/snappy/CONTRIBUTORS @@ -26,9 +26,13 @@ # Please keep the list sorted. +Alex Legg Damian Gryski +Eric Buth Jan Mercl <0xjnml@gmail.com> +Jonathan Swinney Kai Backman +Klaus Post Marc-Antoine Ruel Nigel Tao Rob Pike diff --git a/vendor/github.com/golang/snappy/decode.go b/vendor/github.com/golang/snappy/decode.go index 72efb03..23c6e26 100644 --- a/vendor/github.com/golang/snappy/decode.go +++ b/vendor/github.com/golang/snappy/decode.go @@ -52,6 +52,8 @@ const ( // Otherwise, a newly allocated slice will be returned. // // The dst and src must not overlap. It is valid to pass a nil dst. +// +// Decode handles the Snappy block format, not the Snappy stream format. func Decode(dst, src []byte) ([]byte, error) { dLen, s, err := decodedLen(src) if err != nil { @@ -83,6 +85,8 @@ func NewReader(r io.Reader) *Reader { } // Reader is an io.Reader that can read Snappy-compressed bytes. +// +// Reader handles the Snappy stream format, not the Snappy block format. type Reader struct { r io.Reader err error @@ -114,32 +118,23 @@ func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) { return true } -// Read satisfies the io.Reader interface. -func (r *Reader) Read(p []byte) (int, error) { - if r.err != nil { - return 0, r.err - } - for { - if r.i < r.j { - n := copy(p, r.decoded[r.i:r.j]) - r.i += n - return n, nil - } +func (r *Reader) fill() error { + for r.i >= r.j { if !r.readFull(r.buf[:4], true) { - return 0, r.err + return r.err } chunkType := r.buf[0] if !r.readHeader { if chunkType != chunkTypeStreamIdentifier { r.err = ErrCorrupt - return 0, r.err + return r.err } r.readHeader = true } chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16 if chunkLen > len(r.buf) { r.err = ErrUnsupported - return 0, r.err + return r.err } // The chunk types are specified at @@ -149,11 +144,11 @@ func (r *Reader) Read(p []byte) (int, error) { // Section 4.2. Compressed data (chunk type 0x00). if chunkLen < checksumSize { r.err = ErrCorrupt - return 0, r.err + return r.err } buf := r.buf[:chunkLen] if !r.readFull(buf, false) { - return 0, r.err + return r.err } checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 buf = buf[checksumSize:] @@ -161,19 +156,19 @@ func (r *Reader) Read(p []byte) (int, error) { n, err := DecodedLen(buf) if err != nil { r.err = err - return 0, r.err + return r.err } if n > len(r.decoded) { r.err = ErrCorrupt - return 0, r.err + return r.err } if _, err := Decode(r.decoded, buf); err != nil { r.err = err - return 0, r.err + return r.err } if crc(r.decoded[:n]) != checksum { r.err = ErrCorrupt - return 0, r.err + return r.err } r.i, r.j = 0, n continue @@ -182,25 +177,25 @@ func (r *Reader) Read(p []byte) (int, error) { // Section 4.3. Uncompressed data (chunk type 0x01). if chunkLen < checksumSize { r.err = ErrCorrupt - return 0, r.err + return r.err } buf := r.buf[:checksumSize] if !r.readFull(buf, false) { - return 0, r.err + return r.err } checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 // Read directly into r.decoded instead of via r.buf. n := chunkLen - checksumSize if n > len(r.decoded) { r.err = ErrCorrupt - return 0, r.err + return r.err } if !r.readFull(r.decoded[:n], false) { - return 0, r.err + return r.err } if crc(r.decoded[:n]) != checksum { r.err = ErrCorrupt - return 0, r.err + return r.err } r.i, r.j = 0, n continue @@ -209,15 +204,15 @@ func (r *Reader) Read(p []byte) (int, error) { // Section 4.1. Stream identifier (chunk type 0xff). if chunkLen != len(magicBody) { r.err = ErrCorrupt - return 0, r.err + return r.err } if !r.readFull(r.buf[:len(magicBody)], false) { - return 0, r.err + return r.err } for i := 0; i < len(magicBody); i++ { if r.buf[i] != magicBody[i] { r.err = ErrCorrupt - return 0, r.err + return r.err } } continue @@ -226,12 +221,44 @@ func (r *Reader) Read(p []byte) (int, error) { if chunkType <= 0x7f { // Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f). r.err = ErrUnsupported - return 0, r.err + return r.err } // Section 4.4 Padding (chunk type 0xfe). // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd). if !r.readFull(r.buf[:chunkLen], false) { - return 0, r.err + return r.err } } + + return nil +} + +// Read satisfies the io.Reader interface. +func (r *Reader) Read(p []byte) (int, error) { + if r.err != nil { + return 0, r.err + } + + if err := r.fill(); err != nil { + return 0, err + } + + n := copy(p, r.decoded[r.i:r.j]) + r.i += n + return n, nil +} + +// ReadByte satisfies the io.ByteReader interface. +func (r *Reader) ReadByte() (byte, error) { + if r.err != nil { + return 0, r.err + } + + if err := r.fill(); err != nil { + return 0, err + } + + c := r.decoded[r.i] + r.i++ + return c, nil } diff --git a/vendor/github.com/golang/snappy/decode_arm64.s b/vendor/github.com/golang/snappy/decode_arm64.s new file mode 100644 index 0000000..7a3ead1 --- /dev/null +++ b/vendor/github.com/golang/snappy/decode_arm64.s @@ -0,0 +1,494 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine +// +build gc +// +build !noasm + +#include "textflag.h" + +// The asm code generally follows the pure Go code in decode_other.go, except +// where marked with a "!!!". + +// func decode(dst, src []byte) int +// +// All local variables fit into registers. The non-zero stack size is only to +// spill registers and push args when issuing a CALL. The register allocation: +// - R2 scratch +// - R3 scratch +// - R4 length or x +// - R5 offset +// - R6 &src[s] +// - R7 &dst[d] +// + R8 dst_base +// + R9 dst_len +// + R10 dst_base + dst_len +// + R11 src_base +// + R12 src_len +// + R13 src_base + src_len +// - R14 used by doCopy +// - R15 used by doCopy +// +// The registers R8-R13 (marked with a "+") are set at the start of the +// function, and after a CALL returns, and are not otherwise modified. +// +// The d variable is implicitly R7 - R8, and len(dst)-d is R10 - R7. +// The s variable is implicitly R6 - R11, and len(src)-s is R13 - R6. +TEXT ·decode(SB), NOSPLIT, $56-56 + // Initialize R6, R7 and R8-R13. + MOVD dst_base+0(FP), R8 + MOVD dst_len+8(FP), R9 + MOVD R8, R7 + MOVD R8, R10 + ADD R9, R10, R10 + MOVD src_base+24(FP), R11 + MOVD src_len+32(FP), R12 + MOVD R11, R6 + MOVD R11, R13 + ADD R12, R13, R13 + +loop: + // for s < len(src) + CMP R13, R6 + BEQ end + + // R4 = uint32(src[s]) + // + // switch src[s] & 0x03 + MOVBU (R6), R4 + MOVW R4, R3 + ANDW $3, R3 + MOVW $1, R1 + CMPW R1, R3 + BGE tagCopy + + // ---------------------------------------- + // The code below handles literal tags. + + // case tagLiteral: + // x := uint32(src[s] >> 2) + // switch + MOVW $60, R1 + LSRW $2, R4, R4 + CMPW R4, R1 + BLS tagLit60Plus + + // case x < 60: + // s++ + ADD $1, R6, R6 + +doLit: + // This is the end of the inner "switch", when we have a literal tag. + // + // We assume that R4 == x and x fits in a uint32, where x is the variable + // used in the pure Go decode_other.go code. + + // length = int(x) + 1 + // + // Unlike the pure Go code, we don't need to check if length <= 0 because + // R4 can hold 64 bits, so the increment cannot overflow. + ADD $1, R4, R4 + + // Prepare to check if copying length bytes will run past the end of dst or + // src. + // + // R2 = len(dst) - d + // R3 = len(src) - s + MOVD R10, R2 + SUB R7, R2, R2 + MOVD R13, R3 + SUB R6, R3, R3 + + // !!! Try a faster technique for short (16 or fewer bytes) copies. + // + // if length > 16 || len(dst)-d < 16 || len(src)-s < 16 { + // goto callMemmove // Fall back on calling runtime·memmove. + // } + // + // The C++ snappy code calls this TryFastAppend. It also checks len(src)-s + // against 21 instead of 16, because it cannot assume that all of its input + // is contiguous in memory and so it needs to leave enough source bytes to + // read the next tag without refilling buffers, but Go's Decode assumes + // contiguousness (the src argument is a []byte). + CMP $16, R4 + BGT callMemmove + CMP $16, R2 + BLT callMemmove + CMP $16, R3 + BLT callMemmove + + // !!! Implement the copy from src to dst as a 16-byte load and store. + // (Decode's documentation says that dst and src must not overlap.) + // + // This always copies 16 bytes, instead of only length bytes, but that's + // OK. If the input is a valid Snappy encoding then subsequent iterations + // will fix up the overrun. Otherwise, Decode returns a nil []byte (and a + // non-nil error), so the overrun will be ignored. + // + // Note that on arm64, it is legal and cheap to issue unaligned 8-byte or + // 16-byte loads and stores. This technique probably wouldn't be as + // effective on architectures that are fussier about alignment. + LDP 0(R6), (R14, R15) + STP (R14, R15), 0(R7) + + // d += length + // s += length + ADD R4, R7, R7 + ADD R4, R6, R6 + B loop + +callMemmove: + // if length > len(dst)-d || length > len(src)-s { etc } + CMP R2, R4 + BGT errCorrupt + CMP R3, R4 + BGT errCorrupt + + // copy(dst[d:], src[s:s+length]) + // + // This means calling runtime·memmove(&dst[d], &src[s], length), so we push + // R7, R6 and R4 as arguments. Coincidentally, we also need to spill those + // three registers to the stack, to save local variables across the CALL. + MOVD R7, 8(RSP) + MOVD R6, 16(RSP) + MOVD R4, 24(RSP) + MOVD R7, 32(RSP) + MOVD R6, 40(RSP) + MOVD R4, 48(RSP) + CALL runtime·memmove(SB) + + // Restore local variables: unspill registers from the stack and + // re-calculate R8-R13. + MOVD 32(RSP), R7 + MOVD 40(RSP), R6 + MOVD 48(RSP), R4 + MOVD dst_base+0(FP), R8 + MOVD dst_len+8(FP), R9 + MOVD R8, R10 + ADD R9, R10, R10 + MOVD src_base+24(FP), R11 + MOVD src_len+32(FP), R12 + MOVD R11, R13 + ADD R12, R13, R13 + + // d += length + // s += length + ADD R4, R7, R7 + ADD R4, R6, R6 + B loop + +tagLit60Plus: + // !!! This fragment does the + // + // s += x - 58; if uint(s) > uint(len(src)) { etc } + // + // checks. In the asm version, we code it once instead of once per switch case. + ADD R4, R6, R6 + SUB $58, R6, R6 + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // case x == 60: + MOVW $61, R1 + CMPW R1, R4 + BEQ tagLit61 + BGT tagLit62Plus + + // x = uint32(src[s-1]) + MOVBU -1(R6), R4 + B doLit + +tagLit61: + // case x == 61: + // x = uint32(src[s-2]) | uint32(src[s-1])<<8 + MOVHU -2(R6), R4 + B doLit + +tagLit62Plus: + CMPW $62, R4 + BHI tagLit63 + + // case x == 62: + // x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16 + MOVHU -3(R6), R4 + MOVBU -1(R6), R3 + ORR R3<<16, R4 + B doLit + +tagLit63: + // case x == 63: + // x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24 + MOVWU -4(R6), R4 + B doLit + + // The code above handles literal tags. + // ---------------------------------------- + // The code below handles copy tags. + +tagCopy4: + // case tagCopy4: + // s += 5 + ADD $5, R6, R6 + + // if uint(s) > uint(len(src)) { etc } + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // length = 1 + int(src[s-5])>>2 + MOVD $1, R1 + ADD R4>>2, R1, R4 + + // offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24) + MOVWU -4(R6), R5 + B doCopy + +tagCopy2: + // case tagCopy2: + // s += 3 + ADD $3, R6, R6 + + // if uint(s) > uint(len(src)) { etc } + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // length = 1 + int(src[s-3])>>2 + MOVD $1, R1 + ADD R4>>2, R1, R4 + + // offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8) + MOVHU -2(R6), R5 + B doCopy + +tagCopy: + // We have a copy tag. We assume that: + // - R3 == src[s] & 0x03 + // - R4 == src[s] + CMP $2, R3 + BEQ tagCopy2 + BGT tagCopy4 + + // case tagCopy1: + // s += 2 + ADD $2, R6, R6 + + // if uint(s) > uint(len(src)) { etc } + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])) + MOVD R4, R5 + AND $0xe0, R5 + MOVBU -1(R6), R3 + ORR R5<<3, R3, R5 + + // length = 4 + int(src[s-2])>>2&0x7 + MOVD $7, R1 + AND R4>>2, R1, R4 + ADD $4, R4, R4 + +doCopy: + // This is the end of the outer "switch", when we have a copy tag. + // + // We assume that: + // - R4 == length && R4 > 0 + // - R5 == offset + + // if offset <= 0 { etc } + MOVD $0, R1 + CMP R1, R5 + BLE errCorrupt + + // if d < offset { etc } + MOVD R7, R3 + SUB R8, R3, R3 + CMP R5, R3 + BLT errCorrupt + + // if length > len(dst)-d { etc } + MOVD R10, R3 + SUB R7, R3, R3 + CMP R3, R4 + BGT errCorrupt + + // forwardCopy(dst[d:d+length], dst[d-offset:]); d += length + // + // Set: + // - R14 = len(dst)-d + // - R15 = &dst[d-offset] + MOVD R10, R14 + SUB R7, R14, R14 + MOVD R7, R15 + SUB R5, R15, R15 + + // !!! Try a faster technique for short (16 or fewer bytes) forward copies. + // + // First, try using two 8-byte load/stores, similar to the doLit technique + // above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is + // still OK if offset >= 8. Note that this has to be two 8-byte load/stores + // and not one 16-byte load/store, and the first store has to be before the + // second load, due to the overlap if offset is in the range [8, 16). + // + // if length > 16 || offset < 8 || len(dst)-d < 16 { + // goto slowForwardCopy + // } + // copy 16 bytes + // d += length + CMP $16, R4 + BGT slowForwardCopy + CMP $8, R5 + BLT slowForwardCopy + CMP $16, R14 + BLT slowForwardCopy + MOVD 0(R15), R2 + MOVD R2, 0(R7) + MOVD 8(R15), R3 + MOVD R3, 8(R7) + ADD R4, R7, R7 + B loop + +slowForwardCopy: + // !!! If the forward copy is longer than 16 bytes, or if offset < 8, we + // can still try 8-byte load stores, provided we can overrun up to 10 extra + // bytes. As above, the overrun will be fixed up by subsequent iterations + // of the outermost loop. + // + // The C++ snappy code calls this technique IncrementalCopyFastPath. Its + // commentary says: + // + // ---- + // + // The main part of this loop is a simple copy of eight bytes at a time + // until we've copied (at least) the requested amount of bytes. However, + // if d and d-offset are less than eight bytes apart (indicating a + // repeating pattern of length < 8), we first need to expand the pattern in + // order to get the correct results. For instance, if the buffer looks like + // this, with the eight-byte and patterns marked as + // intervals: + // + // abxxxxxxxxxxxx + // [------] d-offset + // [------] d + // + // a single eight-byte copy from to will repeat the pattern + // once, after which we can move two bytes without moving : + // + // ababxxxxxxxxxx + // [------] d-offset + // [------] d + // + // and repeat the exercise until the two no longer overlap. + // + // This allows us to do very well in the special case of one single byte + // repeated many times, without taking a big hit for more general cases. + // + // The worst case of extra writing past the end of the match occurs when + // offset == 1 and length == 1; the last copy will read from byte positions + // [0..7] and write to [4..11], whereas it was only supposed to write to + // position 1. Thus, ten excess bytes. + // + // ---- + // + // That "10 byte overrun" worst case is confirmed by Go's + // TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy + // and finishSlowForwardCopy algorithm. + // + // if length > len(dst)-d-10 { + // goto verySlowForwardCopy + // } + SUB $10, R14, R14 + CMP R14, R4 + BGT verySlowForwardCopy + +makeOffsetAtLeast8: + // !!! As above, expand the pattern so that offset >= 8 and we can use + // 8-byte load/stores. + // + // for offset < 8 { + // copy 8 bytes from dst[d-offset:] to dst[d:] + // length -= offset + // d += offset + // offset += offset + // // The two previous lines together means that d-offset, and therefore + // // R15, is unchanged. + // } + CMP $8, R5 + BGE fixUpSlowForwardCopy + MOVD (R15), R3 + MOVD R3, (R7) + SUB R5, R4, R4 + ADD R5, R7, R7 + ADD R5, R5, R5 + B makeOffsetAtLeast8 + +fixUpSlowForwardCopy: + // !!! Add length (which might be negative now) to d (implied by R7 being + // &dst[d]) so that d ends up at the right place when we jump back to the + // top of the loop. Before we do that, though, we save R7 to R2 so that, if + // length is positive, copying the remaining length bytes will write to the + // right place. + MOVD R7, R2 + ADD R4, R7, R7 + +finishSlowForwardCopy: + // !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative + // length means that we overrun, but as above, that will be fixed up by + // subsequent iterations of the outermost loop. + MOVD $0, R1 + CMP R1, R4 + BLE loop + MOVD (R15), R3 + MOVD R3, (R2) + ADD $8, R15, R15 + ADD $8, R2, R2 + SUB $8, R4, R4 + B finishSlowForwardCopy + +verySlowForwardCopy: + // verySlowForwardCopy is a simple implementation of forward copy. In C + // parlance, this is a do/while loop instead of a while loop, since we know + // that length > 0. In Go syntax: + // + // for { + // dst[d] = dst[d - offset] + // d++ + // length-- + // if length == 0 { + // break + // } + // } + MOVB (R15), R3 + MOVB R3, (R7) + ADD $1, R15, R15 + ADD $1, R7, R7 + SUB $1, R4, R4 + CBNZ R4, verySlowForwardCopy + B loop + + // The code above handles copy tags. + // ---------------------------------------- + +end: + // This is the end of the "for s < len(src)". + // + // if d != len(dst) { etc } + CMP R10, R7 + BNE errCorrupt + + // return 0 + MOVD $0, ret+48(FP) + RET + +errCorrupt: + // return decodeErrCodeCorrupt + MOVD $1, R2 + MOVD R2, ret+48(FP) + RET diff --git a/vendor/github.com/golang/snappy/decode_amd64.go b/vendor/github.com/golang/snappy/decode_asm.go similarity index 93% rename from vendor/github.com/golang/snappy/decode_amd64.go rename to vendor/github.com/golang/snappy/decode_asm.go index fcd192b..7082b34 100644 --- a/vendor/github.com/golang/snappy/decode_amd64.go +++ b/vendor/github.com/golang/snappy/decode_asm.go @@ -5,6 +5,7 @@ // +build !appengine // +build gc // +build !noasm +// +build amd64 arm64 package snappy diff --git a/vendor/github.com/golang/snappy/decode_other.go b/vendor/github.com/golang/snappy/decode_other.go index 8c9f204..2f672be 100644 --- a/vendor/github.com/golang/snappy/decode_other.go +++ b/vendor/github.com/golang/snappy/decode_other.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !amd64 appengine !gc noasm +// +build !amd64,!arm64 appengine !gc noasm package snappy @@ -85,14 +85,28 @@ func decode(dst, src []byte) int { if offset <= 0 || d < offset || length > len(dst)-d { return decodeErrCodeCorrupt } - // Copy from an earlier sub-slice of dst to a later sub-slice. Unlike - // the built-in copy function, this byte-by-byte copy always runs + // Copy from an earlier sub-slice of dst to a later sub-slice. + // If no overlap, use the built-in copy: + if offset >= length { + copy(dst[d:d+length], dst[d-offset:]) + d += length + continue + } + + // Unlike the built-in copy function, this byte-by-byte copy always runs // forwards, even if the slices overlap. Conceptually, this is: // // d += forwardCopy(dst[d:d+length], dst[d-offset:]) - for end := d + length; d != end; d++ { - dst[d] = dst[d-offset] + // + // We align the slices into a and b and show the compiler they are the same size. + // This allows the loop to run without bounds checks. + a := dst[d : d+length] + b := dst[d-offset:] + b = b[:len(a)] + for i := range a { + a[i] = b[i] } + d += length } if d != len(dst) { return decodeErrCodeCorrupt diff --git a/vendor/github.com/golang/snappy/encode.go b/vendor/github.com/golang/snappy/encode.go index 8d393e9..7f23657 100644 --- a/vendor/github.com/golang/snappy/encode.go +++ b/vendor/github.com/golang/snappy/encode.go @@ -15,6 +15,8 @@ import ( // Otherwise, a newly allocated slice will be returned. // // The dst and src must not overlap. It is valid to pass a nil dst. +// +// Encode handles the Snappy block format, not the Snappy stream format. func Encode(dst, src []byte) []byte { if n := MaxEncodedLen(len(src)); n < 0 { panic(ErrTooLarge) @@ -139,6 +141,8 @@ func NewBufferedWriter(w io.Writer) *Writer { } // Writer is an io.Writer that can write Snappy-compressed bytes. +// +// Writer handles the Snappy stream format, not the Snappy block format. type Writer struct { w io.Writer err error diff --git a/vendor/github.com/golang/snappy/encode_arm64.s b/vendor/github.com/golang/snappy/encode_arm64.s new file mode 100644 index 0000000..f8d54ad --- /dev/null +++ b/vendor/github.com/golang/snappy/encode_arm64.s @@ -0,0 +1,722 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine +// +build gc +// +build !noasm + +#include "textflag.h" + +// The asm code generally follows the pure Go code in encode_other.go, except +// where marked with a "!!!". + +// ---------------------------------------------------------------------------- + +// func emitLiteral(dst, lit []byte) int +// +// All local variables fit into registers. The register allocation: +// - R3 len(lit) +// - R4 n +// - R6 return value +// - R8 &dst[i] +// - R10 &lit[0] +// +// The 32 bytes of stack space is to call runtime·memmove. +// +// The unusual register allocation of local variables, such as R10 for the +// source pointer, matches the allocation used at the call site in encodeBlock, +// which makes it easier to manually inline this function. +TEXT ·emitLiteral(SB), NOSPLIT, $32-56 + MOVD dst_base+0(FP), R8 + MOVD lit_base+24(FP), R10 + MOVD lit_len+32(FP), R3 + MOVD R3, R6 + MOVW R3, R4 + SUBW $1, R4, R4 + + CMPW $60, R4 + BLT oneByte + CMPW $256, R4 + BLT twoBytes + +threeBytes: + MOVD $0xf4, R2 + MOVB R2, 0(R8) + MOVW R4, 1(R8) + ADD $3, R8, R8 + ADD $3, R6, R6 + B memmove + +twoBytes: + MOVD $0xf0, R2 + MOVB R2, 0(R8) + MOVB R4, 1(R8) + ADD $2, R8, R8 + ADD $2, R6, R6 + B memmove + +oneByte: + LSLW $2, R4, R4 + MOVB R4, 0(R8) + ADD $1, R8, R8 + ADD $1, R6, R6 + +memmove: + MOVD R6, ret+48(FP) + + // copy(dst[i:], lit) + // + // This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push + // R8, R10 and R3 as arguments. + MOVD R8, 8(RSP) + MOVD R10, 16(RSP) + MOVD R3, 24(RSP) + CALL runtime·memmove(SB) + RET + +// ---------------------------------------------------------------------------- + +// func emitCopy(dst []byte, offset, length int) int +// +// All local variables fit into registers. The register allocation: +// - R3 length +// - R7 &dst[0] +// - R8 &dst[i] +// - R11 offset +// +// The unusual register allocation of local variables, such as R11 for the +// offset, matches the allocation used at the call site in encodeBlock, which +// makes it easier to manually inline this function. +TEXT ·emitCopy(SB), NOSPLIT, $0-48 + MOVD dst_base+0(FP), R8 + MOVD R8, R7 + MOVD offset+24(FP), R11 + MOVD length+32(FP), R3 + +loop0: + // for length >= 68 { etc } + CMPW $68, R3 + BLT step1 + + // Emit a length 64 copy, encoded as 3 bytes. + MOVD $0xfe, R2 + MOVB R2, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUB $64, R3, R3 + B loop0 + +step1: + // if length > 64 { etc } + CMP $64, R3 + BLE step2 + + // Emit a length 60 copy, encoded as 3 bytes. + MOVD $0xee, R2 + MOVB R2, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUB $60, R3, R3 + +step2: + // if length >= 12 || offset >= 2048 { goto step3 } + CMP $12, R3 + BGE step3 + CMPW $2048, R11 + BGE step3 + + // Emit the remaining copy, encoded as 2 bytes. + MOVB R11, 1(R8) + LSRW $3, R11, R11 + AND $0xe0, R11, R11 + SUB $4, R3, R3 + LSLW $2, R3 + AND $0xff, R3, R3 + ORRW R3, R11, R11 + ORRW $1, R11, R11 + MOVB R11, 0(R8) + ADD $2, R8, R8 + + // Return the number of bytes written. + SUB R7, R8, R8 + MOVD R8, ret+40(FP) + RET + +step3: + // Emit the remaining copy, encoded as 3 bytes. + SUB $1, R3, R3 + AND $0xff, R3, R3 + LSLW $2, R3, R3 + ORRW $2, R3, R3 + MOVB R3, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + + // Return the number of bytes written. + SUB R7, R8, R8 + MOVD R8, ret+40(FP) + RET + +// ---------------------------------------------------------------------------- + +// func extendMatch(src []byte, i, j int) int +// +// All local variables fit into registers. The register allocation: +// - R6 &src[0] +// - R7 &src[j] +// - R13 &src[len(src) - 8] +// - R14 &src[len(src)] +// - R15 &src[i] +// +// The unusual register allocation of local variables, such as R15 for a source +// pointer, matches the allocation used at the call site in encodeBlock, which +// makes it easier to manually inline this function. +TEXT ·extendMatch(SB), NOSPLIT, $0-48 + MOVD src_base+0(FP), R6 + MOVD src_len+8(FP), R14 + MOVD i+24(FP), R15 + MOVD j+32(FP), R7 + ADD R6, R14, R14 + ADD R6, R15, R15 + ADD R6, R7, R7 + MOVD R14, R13 + SUB $8, R13, R13 + +cmp8: + // As long as we are 8 or more bytes before the end of src, we can load and + // compare 8 bytes at a time. If those 8 bytes are equal, repeat. + CMP R13, R7 + BHI cmp1 + MOVD (R15), R3 + MOVD (R7), R4 + CMP R4, R3 + BNE bsf + ADD $8, R15, R15 + ADD $8, R7, R7 + B cmp8 + +bsf: + // If those 8 bytes were not equal, XOR the two 8 byte values, and return + // the index of the first byte that differs. + // RBIT reverses the bit order, then CLZ counts the leading zeros, the + // combination of which finds the least significant bit which is set. + // The arm64 architecture is little-endian, and the shift by 3 converts + // a bit index to a byte index. + EOR R3, R4, R4 + RBIT R4, R4 + CLZ R4, R4 + ADD R4>>3, R7, R7 + + // Convert from &src[ret] to ret. + SUB R6, R7, R7 + MOVD R7, ret+40(FP) + RET + +cmp1: + // In src's tail, compare 1 byte at a time. + CMP R7, R14 + BLS extendMatchEnd + MOVB (R15), R3 + MOVB (R7), R4 + CMP R4, R3 + BNE extendMatchEnd + ADD $1, R15, R15 + ADD $1, R7, R7 + B cmp1 + +extendMatchEnd: + // Convert from &src[ret] to ret. + SUB R6, R7, R7 + MOVD R7, ret+40(FP) + RET + +// ---------------------------------------------------------------------------- + +// func encodeBlock(dst, src []byte) (d int) +// +// All local variables fit into registers, other than "var table". The register +// allocation: +// - R3 . . +// - R4 . . +// - R5 64 shift +// - R6 72 &src[0], tableSize +// - R7 80 &src[s] +// - R8 88 &dst[d] +// - R9 96 sLimit +// - R10 . &src[nextEmit] +// - R11 104 prevHash, currHash, nextHash, offset +// - R12 112 &src[base], skip +// - R13 . &src[nextS], &src[len(src) - 8] +// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x +// - R15 120 candidate +// - R16 . hash constant, 0x1e35a7bd +// - R17 . &table +// - . 128 table +// +// The second column (64, 72, etc) is the stack offset to spill the registers +// when calling other functions. We could pack this slightly tighter, but it's +// simpler to have a dedicated spill map independent of the function called. +// +// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An +// extra 64 bytes, to call other functions, and an extra 64 bytes, to spill +// local variables (registers) during calls gives 32768 + 64 + 64 = 32896. +TEXT ·encodeBlock(SB), 0, $32896-56 + MOVD dst_base+0(FP), R8 + MOVD src_base+24(FP), R7 + MOVD src_len+32(FP), R14 + + // shift, tableSize := uint32(32-8), 1<<8 + MOVD $24, R5 + MOVD $256, R6 + MOVW $0xa7bd, R16 + MOVKW $(0x1e35<<16), R16 + +calcShift: + // for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 { + // shift-- + // } + MOVD $16384, R2 + CMP R2, R6 + BGE varTable + CMP R14, R6 + BGE varTable + SUB $1, R5, R5 + LSL $1, R6, R6 + B calcShift + +varTable: + // var table [maxTableSize]uint16 + // + // In the asm code, unlike the Go code, we can zero-initialize only the + // first tableSize elements. Each uint16 element is 2 bytes and each + // iterations writes 64 bytes, so we can do only tableSize/32 writes + // instead of the 2048 writes that would zero-initialize all of table's + // 32768 bytes. This clear could overrun the first tableSize elements, but + // it won't overrun the allocated stack size. + ADD $128, RSP, R17 + MOVD R17, R4 + + // !!! R6 = &src[tableSize] + ADD R6<<1, R17, R6 + +memclr: + STP.P (ZR, ZR), 64(R4) + STP (ZR, ZR), -48(R4) + STP (ZR, ZR), -32(R4) + STP (ZR, ZR), -16(R4) + CMP R4, R6 + BHI memclr + + // !!! R6 = &src[0] + MOVD R7, R6 + + // sLimit := len(src) - inputMargin + MOVD R14, R9 + SUB $15, R9, R9 + + // !!! Pre-emptively spill R5, R6 and R9 to the stack. Their values don't + // change for the rest of the function. + MOVD R5, 64(RSP) + MOVD R6, 72(RSP) + MOVD R9, 96(RSP) + + // nextEmit := 0 + MOVD R6, R10 + + // s := 1 + ADD $1, R7, R7 + + // nextHash := hash(load32(src, s), shift) + MOVW 0(R7), R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + +outer: + // for { etc } + + // skip := 32 + MOVD $32, R12 + + // nextS := s + MOVD R7, R13 + + // candidate := 0 + MOVD $0, R15 + +inner0: + // for { etc } + + // s := nextS + MOVD R13, R7 + + // bytesBetweenHashLookups := skip >> 5 + MOVD R12, R14 + LSR $5, R14, R14 + + // nextS = s + bytesBetweenHashLookups + ADD R14, R13, R13 + + // skip += bytesBetweenHashLookups + ADD R14, R12, R12 + + // if nextS > sLimit { goto emitRemainder } + MOVD R13, R3 + SUB R6, R3, R3 + CMP R9, R3 + BHI emitRemainder + + // candidate = int(table[nextHash]) + MOVHU 0(R17)(R11<<1), R15 + + // table[nextHash] = uint16(s) + MOVD R7, R3 + SUB R6, R3, R3 + + MOVH R3, 0(R17)(R11<<1) + + // nextHash = hash(load32(src, nextS), shift) + MOVW 0(R13), R11 + MULW R16, R11 + LSRW R5, R11, R11 + + // if load32(src, s) != load32(src, candidate) { continue } break + MOVW 0(R7), R3 + MOVW (R6)(R15), R4 + CMPW R4, R3 + BNE inner0 + +fourByteMatch: + // As per the encode_other.go code: + // + // A 4-byte match has been found. We'll later see etc. + + // !!! Jump to a fast path for short (<= 16 byte) literals. See the comment + // on inputMargin in encode.go. + MOVD R7, R3 + SUB R10, R3, R3 + CMP $16, R3 + BLE emitLiteralFastPath + + // ---------------------------------------- + // Begin inline of the emitLiteral call. + // + // d += emitLiteral(dst[d:], src[nextEmit:s]) + + MOVW R3, R4 + SUBW $1, R4, R4 + + MOVW $60, R2 + CMPW R2, R4 + BLT inlineEmitLiteralOneByte + MOVW $256, R2 + CMPW R2, R4 + BLT inlineEmitLiteralTwoBytes + +inlineEmitLiteralThreeBytes: + MOVD $0xf4, R1 + MOVB R1, 0(R8) + MOVW R4, 1(R8) + ADD $3, R8, R8 + B inlineEmitLiteralMemmove + +inlineEmitLiteralTwoBytes: + MOVD $0xf0, R1 + MOVB R1, 0(R8) + MOVB R4, 1(R8) + ADD $2, R8, R8 + B inlineEmitLiteralMemmove + +inlineEmitLiteralOneByte: + LSLW $2, R4, R4 + MOVB R4, 0(R8) + ADD $1, R8, R8 + +inlineEmitLiteralMemmove: + // Spill local variables (registers) onto the stack; call; unspill. + // + // copy(dst[i:], lit) + // + // This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push + // R8, R10 and R3 as arguments. + MOVD R8, 8(RSP) + MOVD R10, 16(RSP) + MOVD R3, 24(RSP) + + // Finish the "d +=" part of "d += emitLiteral(etc)". + ADD R3, R8, R8 + MOVD R7, 80(RSP) + MOVD R8, 88(RSP) + MOVD R15, 120(RSP) + CALL runtime·memmove(SB) + MOVD 64(RSP), R5 + MOVD 72(RSP), R6 + MOVD 80(RSP), R7 + MOVD 88(RSP), R8 + MOVD 96(RSP), R9 + MOVD 120(RSP), R15 + ADD $128, RSP, R17 + MOVW $0xa7bd, R16 + MOVKW $(0x1e35<<16), R16 + B inner1 + +inlineEmitLiteralEnd: + // End inline of the emitLiteral call. + // ---------------------------------------- + +emitLiteralFastPath: + // !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2". + MOVB R3, R4 + SUBW $1, R4, R4 + AND $0xff, R4, R4 + LSLW $2, R4, R4 + MOVB R4, (R8) + ADD $1, R8, R8 + + // !!! Implement the copy from lit to dst as a 16-byte load and store. + // (Encode's documentation says that dst and src must not overlap.) + // + // This always copies 16 bytes, instead of only len(lit) bytes, but that's + // OK. Subsequent iterations will fix up the overrun. + // + // Note that on arm64, it is legal and cheap to issue unaligned 8-byte or + // 16-byte loads and stores. This technique probably wouldn't be as + // effective on architectures that are fussier about alignment. + LDP 0(R10), (R0, R1) + STP (R0, R1), 0(R8) + ADD R3, R8, R8 + +inner1: + // for { etc } + + // base := s + MOVD R7, R12 + + // !!! offset := base - candidate + MOVD R12, R11 + SUB R15, R11, R11 + SUB R6, R11, R11 + + // ---------------------------------------- + // Begin inline of the extendMatch call. + // + // s = extendMatch(src, candidate+4, s+4) + + // !!! R14 = &src[len(src)] + MOVD src_len+32(FP), R14 + ADD R6, R14, R14 + + // !!! R13 = &src[len(src) - 8] + MOVD R14, R13 + SUB $8, R13, R13 + + // !!! R15 = &src[candidate + 4] + ADD $4, R15, R15 + ADD R6, R15, R15 + + // !!! s += 4 + ADD $4, R7, R7 + +inlineExtendMatchCmp8: + // As long as we are 8 or more bytes before the end of src, we can load and + // compare 8 bytes at a time. If those 8 bytes are equal, repeat. + CMP R13, R7 + BHI inlineExtendMatchCmp1 + MOVD (R15), R3 + MOVD (R7), R4 + CMP R4, R3 + BNE inlineExtendMatchBSF + ADD $8, R15, R15 + ADD $8, R7, R7 + B inlineExtendMatchCmp8 + +inlineExtendMatchBSF: + // If those 8 bytes were not equal, XOR the two 8 byte values, and return + // the index of the first byte that differs. + // RBIT reverses the bit order, then CLZ counts the leading zeros, the + // combination of which finds the least significant bit which is set. + // The arm64 architecture is little-endian, and the shift by 3 converts + // a bit index to a byte index. + EOR R3, R4, R4 + RBIT R4, R4 + CLZ R4, R4 + ADD R4>>3, R7, R7 + B inlineExtendMatchEnd + +inlineExtendMatchCmp1: + // In src's tail, compare 1 byte at a time. + CMP R7, R14 + BLS inlineExtendMatchEnd + MOVB (R15), R3 + MOVB (R7), R4 + CMP R4, R3 + BNE inlineExtendMatchEnd + ADD $1, R15, R15 + ADD $1, R7, R7 + B inlineExtendMatchCmp1 + +inlineExtendMatchEnd: + // End inline of the extendMatch call. + // ---------------------------------------- + + // ---------------------------------------- + // Begin inline of the emitCopy call. + // + // d += emitCopy(dst[d:], base-candidate, s-base) + + // !!! length := s - base + MOVD R7, R3 + SUB R12, R3, R3 + +inlineEmitCopyLoop0: + // for length >= 68 { etc } + MOVW $68, R2 + CMPW R2, R3 + BLT inlineEmitCopyStep1 + + // Emit a length 64 copy, encoded as 3 bytes. + MOVD $0xfe, R1 + MOVB R1, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUBW $64, R3, R3 + B inlineEmitCopyLoop0 + +inlineEmitCopyStep1: + // if length > 64 { etc } + MOVW $64, R2 + CMPW R2, R3 + BLE inlineEmitCopyStep2 + + // Emit a length 60 copy, encoded as 3 bytes. + MOVD $0xee, R1 + MOVB R1, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUBW $60, R3, R3 + +inlineEmitCopyStep2: + // if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 } + MOVW $12, R2 + CMPW R2, R3 + BGE inlineEmitCopyStep3 + MOVW $2048, R2 + CMPW R2, R11 + BGE inlineEmitCopyStep3 + + // Emit the remaining copy, encoded as 2 bytes. + MOVB R11, 1(R8) + LSRW $8, R11, R11 + LSLW $5, R11, R11 + SUBW $4, R3, R3 + AND $0xff, R3, R3 + LSLW $2, R3, R3 + ORRW R3, R11, R11 + ORRW $1, R11, R11 + MOVB R11, 0(R8) + ADD $2, R8, R8 + B inlineEmitCopyEnd + +inlineEmitCopyStep3: + // Emit the remaining copy, encoded as 3 bytes. + SUBW $1, R3, R3 + LSLW $2, R3, R3 + ORRW $2, R3, R3 + MOVB R3, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + +inlineEmitCopyEnd: + // End inline of the emitCopy call. + // ---------------------------------------- + + // nextEmit = s + MOVD R7, R10 + + // if s >= sLimit { goto emitRemainder } + MOVD R7, R3 + SUB R6, R3, R3 + CMP R3, R9 + BLS emitRemainder + + // As per the encode_other.go code: + // + // We could immediately etc. + + // x := load64(src, s-1) + MOVD -1(R7), R14 + + // prevHash := hash(uint32(x>>0), shift) + MOVW R14, R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + + // table[prevHash] = uint16(s-1) + MOVD R7, R3 + SUB R6, R3, R3 + SUB $1, R3, R3 + + MOVHU R3, 0(R17)(R11<<1) + + // currHash := hash(uint32(x>>8), shift) + LSR $8, R14, R14 + MOVW R14, R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + + // candidate = int(table[currHash]) + MOVHU 0(R17)(R11<<1), R15 + + // table[currHash] = uint16(s) + ADD $1, R3, R3 + MOVHU R3, 0(R17)(R11<<1) + + // if uint32(x>>8) == load32(src, candidate) { continue } + MOVW (R6)(R15), R4 + CMPW R4, R14 + BEQ inner1 + + // nextHash = hash(uint32(x>>16), shift) + LSR $8, R14, R14 + MOVW R14, R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + + // s++ + ADD $1, R7, R7 + + // break out of the inner1 for loop, i.e. continue the outer loop. + B outer + +emitRemainder: + // if nextEmit < len(src) { etc } + MOVD src_len+32(FP), R3 + ADD R6, R3, R3 + CMP R3, R10 + BEQ encodeBlockEnd + + // d += emitLiteral(dst[d:], src[nextEmit:]) + // + // Push args. + MOVD R8, 8(RSP) + MOVD $0, 16(RSP) // Unnecessary, as the callee ignores it, but conservative. + MOVD $0, 24(RSP) // Unnecessary, as the callee ignores it, but conservative. + MOVD R10, 32(RSP) + SUB R10, R3, R3 + MOVD R3, 40(RSP) + MOVD R3, 48(RSP) // Unnecessary, as the callee ignores it, but conservative. + + // Spill local variables (registers) onto the stack; call; unspill. + MOVD R8, 88(RSP) + CALL ·emitLiteral(SB) + MOVD 88(RSP), R8 + + // Finish the "d +=" part of "d += emitLiteral(etc)". + MOVD 56(RSP), R1 + ADD R1, R8, R8 + +encodeBlockEnd: + MOVD dst_base+0(FP), R3 + SUB R3, R8, R8 + MOVD R8, d+48(FP) + RET diff --git a/vendor/github.com/golang/snappy/encode_amd64.go b/vendor/github.com/golang/snappy/encode_asm.go similarity index 97% rename from vendor/github.com/golang/snappy/encode_amd64.go rename to vendor/github.com/golang/snappy/encode_asm.go index 150d91b..107c1e7 100644 --- a/vendor/github.com/golang/snappy/encode_amd64.go +++ b/vendor/github.com/golang/snappy/encode_asm.go @@ -5,6 +5,7 @@ // +build !appengine // +build gc // +build !noasm +// +build amd64 arm64 package snappy diff --git a/vendor/github.com/golang/snappy/encode_other.go b/vendor/github.com/golang/snappy/encode_other.go index dbcae90..296d7f0 100644 --- a/vendor/github.com/golang/snappy/encode_other.go +++ b/vendor/github.com/golang/snappy/encode_other.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !amd64 appengine !gc noasm +// +build !amd64,!arm64 appengine !gc noasm package snappy diff --git a/vendor/github.com/klauspost/compress/.gitignore b/vendor/github.com/klauspost/compress/.gitignore index b35f844..d31b378 100644 --- a/vendor/github.com/klauspost/compress/.gitignore +++ b/vendor/github.com/klauspost/compress/.gitignore @@ -23,3 +23,10 @@ _testmain.go *.test *.prof /s2/cmd/_s2sx/sfx-exe + +# Linux perf files +perf.data +perf.data.old + +# gdb history +.gdb_history diff --git a/vendor/github.com/klauspost/compress/.goreleaser.yml b/vendor/github.com/klauspost/compress/.goreleaser.yml index c9014ce..7a008a4 100644 --- a/vendor/github.com/klauspost/compress/.goreleaser.yml +++ b/vendor/github.com/klauspost/compress/.goreleaser.yml @@ -3,6 +3,7 @@ before: hooks: - ./gen.sh + - go install mvdan.cc/garble@v0.9.3 builds: - @@ -31,6 +32,7 @@ builds: - mips64le goarm: - 7 + gobinary: garble - id: "s2d" binary: s2d @@ -57,6 +59,7 @@ builds: - mips64le goarm: - 7 + gobinary: garble - id: "s2sx" binary: s2sx @@ -84,6 +87,7 @@ builds: - mips64le goarm: - 7 + gobinary: garble archives: - diff --git a/vendor/github.com/klauspost/compress/README.md b/vendor/github.com/klauspost/compress/README.md index 3429879..4002a16 100644 --- a/vendor/github.com/klauspost/compress/README.md +++ b/vendor/github.com/klauspost/compress/README.md @@ -9,7 +9,6 @@ This package provides various compression algorithms. * [huff0](https://github.com/klauspost/compress/tree/master/huff0) and [FSE](https://github.com/klauspost/compress/tree/master/fse) implementations for raw entropy encoding. * [gzhttp](https://github.com/klauspost/compress/tree/master/gzhttp) Provides client and server wrappers for handling gzipped requests efficiently. * [pgzip](https://github.com/klauspost/pgzip) is a separate package that provides a very fast parallel gzip implementation. -* [fuzz package](https://github.com/klauspost/compress-fuzz) for fuzz testing all compressors/decompressors here. [![Go Reference](https://pkg.go.dev/badge/klauspost/compress.svg)](https://pkg.go.dev/github.com/klauspost/compress?tab=subdirectories) [![Go](https://github.com/klauspost/compress/actions/workflows/go.yml/badge.svg)](https://github.com/klauspost/compress/actions/workflows/go.yml) @@ -17,6 +16,199 @@ This package provides various compression algorithms. # changelog +* June 13, 2023 - [v1.16.6](https://github.com/klauspost/compress/releases/tag/v1.16.6) + * zstd: correctly ignore WithEncoderPadding(1) by @ianlancetaylor in https://github.com/klauspost/compress/pull/806 + * zstd: Add amd64 match length assembly https://github.com/klauspost/compress/pull/824 + * gzhttp: Handle informational headers by @rtribotte in https://github.com/klauspost/compress/pull/815 + * s2: Improve Better compression slightly https://github.com/klauspost/compress/pull/663 + +* Apr 16, 2023 - [v1.16.5](https://github.com/klauspost/compress/releases/tag/v1.16.5) + * zstd: readByte needs to use io.ReadFull by @jnoxon in https://github.com/klauspost/compress/pull/802 + * gzip: Fix WriterTo after initial read https://github.com/klauspost/compress/pull/804 + +* Apr 5, 2023 - [v1.16.4](https://github.com/klauspost/compress/releases/tag/v1.16.4) + * zstd: Improve zstd best efficiency by @greatroar and @klauspost in https://github.com/klauspost/compress/pull/784 + * zstd: Respect WithAllLitEntropyCompression https://github.com/klauspost/compress/pull/792 + * zstd: Fix amd64 not always detecting corrupt data https://github.com/klauspost/compress/pull/785 + * zstd: Various minor improvements by @greatroar in https://github.com/klauspost/compress/pull/788 https://github.com/klauspost/compress/pull/794 https://github.com/klauspost/compress/pull/795 + * s2: Fix huge block overflow https://github.com/klauspost/compress/pull/779 + * s2: Allow CustomEncoder fallback https://github.com/klauspost/compress/pull/780 + * gzhttp: Suppport ResponseWriter Unwrap() in gzhttp handler by @jgimenez in https://github.com/klauspost/compress/pull/799 + +* Mar 13, 2023 - [v1.16.1](https://github.com/klauspost/compress/releases/tag/v1.16.1) + * zstd: Speed up + improve best encoder by @greatroar in https://github.com/klauspost/compress/pull/776 + * gzhttp: Add optional [BREACH mitigation](https://github.com/klauspost/compress/tree/master/gzhttp#breach-mitigation). https://github.com/klauspost/compress/pull/762 https://github.com/klauspost/compress/pull/768 https://github.com/klauspost/compress/pull/769 https://github.com/klauspost/compress/pull/770 https://github.com/klauspost/compress/pull/767 + * s2: Add Intel LZ4s converter https://github.com/klauspost/compress/pull/766 + * zstd: Minor bug fixes https://github.com/klauspost/compress/pull/771 https://github.com/klauspost/compress/pull/772 https://github.com/klauspost/compress/pull/773 + * huff0: Speed up compress1xDo by @greatroar in https://github.com/klauspost/compress/pull/774 + +* Feb 26, 2023 - [v1.16.0](https://github.com/klauspost/compress/releases/tag/v1.16.0) + * s2: Add [Dictionary](https://github.com/klauspost/compress/tree/master/s2#dictionaries) support. https://github.com/klauspost/compress/pull/685 + * s2: Add Compression Size Estimate. https://github.com/klauspost/compress/pull/752 + * s2: Add support for custom stream encoder. https://github.com/klauspost/compress/pull/755 + * s2: Add LZ4 block converter. https://github.com/klauspost/compress/pull/748 + * s2: Support io.ReaderAt in ReadSeeker. https://github.com/klauspost/compress/pull/747 + * s2c/s2sx: Use concurrent decoding. https://github.com/klauspost/compress/pull/746 + +* Jan 21st, 2023 (v1.15.15) + * deflate: Improve level 7-9 by @klauspost in https://github.com/klauspost/compress/pull/739 + * zstd: Add delta encoding support by @greatroar in https://github.com/klauspost/compress/pull/728 + * zstd: Various speed improvements by @greatroar https://github.com/klauspost/compress/pull/741 https://github.com/klauspost/compress/pull/734 https://github.com/klauspost/compress/pull/736 https://github.com/klauspost/compress/pull/744 https://github.com/klauspost/compress/pull/743 https://github.com/klauspost/compress/pull/745 + * gzhttp: Add SuffixETag() and DropETag() options to prevent ETag collisions on compressed responses by @willbicks in https://github.com/klauspost/compress/pull/740 + +* Jan 3rd, 2023 (v1.15.14) + + * flate: Improve speed in big stateless blocks https://github.com/klauspost/compress/pull/718 + * zstd: Minor speed tweaks by @greatroar in https://github.com/klauspost/compress/pull/716 https://github.com/klauspost/compress/pull/720 + * export NoGzipResponseWriter for custom ResponseWriter wrappers by @harshavardhana in https://github.com/klauspost/compress/pull/722 + * s2: Add example for indexing and existing stream https://github.com/klauspost/compress/pull/723 + +* Dec 11, 2022 (v1.15.13) + * zstd: Add [MaxEncodedSize](https://pkg.go.dev/github.com/klauspost/compress@v1.15.13/zstd#Encoder.MaxEncodedSize) to encoder https://github.com/klauspost/compress/pull/691 + * zstd: Various tweaks and improvements https://github.com/klauspost/compress/pull/693 https://github.com/klauspost/compress/pull/695 https://github.com/klauspost/compress/pull/696 https://github.com/klauspost/compress/pull/701 https://github.com/klauspost/compress/pull/702 https://github.com/klauspost/compress/pull/703 https://github.com/klauspost/compress/pull/704 https://github.com/klauspost/compress/pull/705 https://github.com/klauspost/compress/pull/706 https://github.com/klauspost/compress/pull/707 https://github.com/klauspost/compress/pull/708 + +* Oct 26, 2022 (v1.15.12) + + * zstd: Tweak decoder allocs. https://github.com/klauspost/compress/pull/680 + * gzhttp: Always delete `HeaderNoCompression` https://github.com/klauspost/compress/pull/683 + +* Sept 26, 2022 (v1.15.11) + + * flate: Improve level 1-3 compression https://github.com/klauspost/compress/pull/678 + * zstd: Improve "best" compression by @nightwolfz in https://github.com/klauspost/compress/pull/677 + * zstd: Fix+reduce decompression allocations https://github.com/klauspost/compress/pull/668 + * zstd: Fix non-effective noescape tag https://github.com/klauspost/compress/pull/667 + +* Sept 16, 2022 (v1.15.10) + + * zstd: Add [WithDecodeAllCapLimit](https://pkg.go.dev/github.com/klauspost/compress@v1.15.10/zstd#WithDecodeAllCapLimit) https://github.com/klauspost/compress/pull/649 + * Add Go 1.19 - deprecate Go 1.16 https://github.com/klauspost/compress/pull/651 + * flate: Improve level 5+6 compression https://github.com/klauspost/compress/pull/656 + * zstd: Improve "better" compresssion https://github.com/klauspost/compress/pull/657 + * s2: Improve "best" compression https://github.com/klauspost/compress/pull/658 + * s2: Improve "better" compression. https://github.com/klauspost/compress/pull/635 + * s2: Slightly faster non-assembly decompression https://github.com/klauspost/compress/pull/646 + * Use arrays for constant size copies https://github.com/klauspost/compress/pull/659 + +* July 21, 2022 (v1.15.9) + + * zstd: Fix decoder crash on amd64 (no BMI) on invalid input https://github.com/klauspost/compress/pull/645 + * zstd: Disable decoder extended memory copies (amd64) due to possible crashes https://github.com/klauspost/compress/pull/644 + * zstd: Allow single segments up to "max decoded size" by @klauspost in https://github.com/klauspost/compress/pull/643 + +* July 13, 2022 (v1.15.8) + + * gzip: fix stack exhaustion bug in Reader.Read https://github.com/klauspost/compress/pull/641 + * s2: Add Index header trim/restore https://github.com/klauspost/compress/pull/638 + * zstd: Optimize seqdeq amd64 asm by @greatroar in https://github.com/klauspost/compress/pull/636 + * zstd: Improve decoder memcopy https://github.com/klauspost/compress/pull/637 + * huff0: Pass a single bitReader pointer to asm by @greatroar in https://github.com/klauspost/compress/pull/634 + * zstd: Branchless getBits for amd64 w/o BMI2 by @greatroar in https://github.com/klauspost/compress/pull/640 + * gzhttp: Remove header before writing https://github.com/klauspost/compress/pull/639 + +* June 29, 2022 (v1.15.7) + + * s2: Fix absolute forward seeks https://github.com/klauspost/compress/pull/633 + * zip: Merge upstream https://github.com/klauspost/compress/pull/631 + * zip: Re-add zip64 fix https://github.com/klauspost/compress/pull/624 + * zstd: translate fseDecoder.buildDtable into asm by @WojciechMula in https://github.com/klauspost/compress/pull/598 + * flate: Faster histograms https://github.com/klauspost/compress/pull/620 + * deflate: Use compound hcode https://github.com/klauspost/compress/pull/622 + +* June 3, 2022 (v1.15.6) + * s2: Improve coding for long, close matches https://github.com/klauspost/compress/pull/613 + * s2c: Add Snappy/S2 stream recompression https://github.com/klauspost/compress/pull/611 + * zstd: Always use configured block size https://github.com/klauspost/compress/pull/605 + * zstd: Fix incorrect hash table placement for dict encoding in default https://github.com/klauspost/compress/pull/606 + * zstd: Apply default config to ZipDecompressor without options https://github.com/klauspost/compress/pull/608 + * gzhttp: Exclude more common archive formats https://github.com/klauspost/compress/pull/612 + * s2: Add ReaderIgnoreCRC https://github.com/klauspost/compress/pull/609 + * s2: Remove sanity load on index creation https://github.com/klauspost/compress/pull/607 + * snappy: Use dedicated function for scoring https://github.com/klauspost/compress/pull/614 + * s2c+s2d: Use official snappy framed extension https://github.com/klauspost/compress/pull/610 + +* May 25, 2022 (v1.15.5) + * s2: Add concurrent stream decompression https://github.com/klauspost/compress/pull/602 + * s2: Fix final emit oob read crash on amd64 https://github.com/klauspost/compress/pull/601 + * huff0: asm implementation of Decompress1X by @WojciechMula https://github.com/klauspost/compress/pull/596 + * zstd: Use 1 less goroutine for stream decoding https://github.com/klauspost/compress/pull/588 + * zstd: Copy literal in 16 byte blocks when possible https://github.com/klauspost/compress/pull/592 + * zstd: Speed up when WithDecoderLowmem(false) https://github.com/klauspost/compress/pull/599 + * zstd: faster next state update in BMI2 version of decode by @WojciechMula in https://github.com/klauspost/compress/pull/593 + * huff0: Do not check max size when reading table. https://github.com/klauspost/compress/pull/586 + * flate: Inplace hashing for level 7-9 by @klauspost in https://github.com/klauspost/compress/pull/590 + + +* May 11, 2022 (v1.15.4) + * huff0: decompress directly into output by @WojciechMula in [#577](https://github.com/klauspost/compress/pull/577) + * inflate: Keep dict on stack [#581](https://github.com/klauspost/compress/pull/581) + * zstd: Faster decoding memcopy in asm [#583](https://github.com/klauspost/compress/pull/583) + * zstd: Fix ignored crc [#580](https://github.com/klauspost/compress/pull/580) + +* May 5, 2022 (v1.15.3) + * zstd: Allow to ignore checksum checking by @WojciechMula [#572](https://github.com/klauspost/compress/pull/572) + * s2: Fix incorrect seek for io.SeekEnd in [#575](https://github.com/klauspost/compress/pull/575) + +* Apr 26, 2022 (v1.15.2) + * zstd: Add x86-64 assembly for decompression on streams and blocks. Contributed by [@WojciechMula](https://github.com/WojciechMula). Typically 2x faster. [#528](https://github.com/klauspost/compress/pull/528) [#531](https://github.com/klauspost/compress/pull/531) [#545](https://github.com/klauspost/compress/pull/545) [#537](https://github.com/klauspost/compress/pull/537) + * zstd: Add options to ZipDecompressor and fixes [#539](https://github.com/klauspost/compress/pull/539) + * s2: Use sorted search for index [#555](https://github.com/klauspost/compress/pull/555) + * Minimum version is Go 1.16, added CI test on 1.18. + +* Mar 11, 2022 (v1.15.1) + * huff0: Add x86 assembly of Decode4X by @WojciechMula in [#512](https://github.com/klauspost/compress/pull/512) + * zstd: Reuse zip decoders in [#514](https://github.com/klauspost/compress/pull/514) + * zstd: Detect extra block data and report as corrupted in [#520](https://github.com/klauspost/compress/pull/520) + * zstd: Handle zero sized frame content size stricter in [#521](https://github.com/klauspost/compress/pull/521) + * zstd: Add stricter block size checks in [#523](https://github.com/klauspost/compress/pull/523) + +* Mar 3, 2022 (v1.15.0) + * zstd: Refactor decoder by @klauspost in [#498](https://github.com/klauspost/compress/pull/498) + * zstd: Add stream encoding without goroutines by @klauspost in [#505](https://github.com/klauspost/compress/pull/505) + * huff0: Prevent single blocks exceeding 16 bits by @klauspost in[#507](https://github.com/klauspost/compress/pull/507) + * flate: Inline literal emission by @klauspost in [#509](https://github.com/klauspost/compress/pull/509) + * gzhttp: Add zstd to transport by @klauspost in [#400](https://github.com/klauspost/compress/pull/400) + * gzhttp: Make content-type optional by @klauspost in [#510](https://github.com/klauspost/compress/pull/510) + +Both compression and decompression now supports "synchronous" stream operations. This means that whenever "concurrency" is set to 1, they will operate without spawning goroutines. + +Stream decompression is now faster on asynchronous, since the goroutine allocation much more effectively splits the workload. On typical streams this will typically use 2 cores fully for decompression. When a stream has finished decoding no goroutines will be left over, so decoders can now safely be pooled and still be garbage collected. + +While the release has been extensively tested, it is recommended to testing when upgrading. + +
+ See changes to v1.14.x + +* Feb 22, 2022 (v1.14.4) + * flate: Fix rare huffman only (-2) corruption. [#503](https://github.com/klauspost/compress/pull/503) + * zip: Update deprecated CreateHeaderRaw to correctly call CreateRaw by @saracen in [#502](https://github.com/klauspost/compress/pull/502) + * zip: don't read data descriptor early by @saracen in [#501](https://github.com/klauspost/compress/pull/501) #501 + * huff0: Use static decompression buffer up to 30% faster by @klauspost in [#499](https://github.com/klauspost/compress/pull/499) [#500](https://github.com/klauspost/compress/pull/500) + +* Feb 17, 2022 (v1.14.3) + * flate: Improve fastest levels compression speed ~10% more throughput. [#482](https://github.com/klauspost/compress/pull/482) [#489](https://github.com/klauspost/compress/pull/489) [#490](https://github.com/klauspost/compress/pull/490) [#491](https://github.com/klauspost/compress/pull/491) [#494](https://github.com/klauspost/compress/pull/494) [#478](https://github.com/klauspost/compress/pull/478) + * flate: Faster decompression speed, ~5-10%. [#483](https://github.com/klauspost/compress/pull/483) + * s2: Faster compression with Go v1.18 and amd64 microarch level 3+. [#484](https://github.com/klauspost/compress/pull/484) [#486](https://github.com/klauspost/compress/pull/486) + +* Jan 25, 2022 (v1.14.2) + * zstd: improve header decoder by @dsnet [#476](https://github.com/klauspost/compress/pull/476) + * zstd: Add bigger default blocks [#469](https://github.com/klauspost/compress/pull/469) + * zstd: Remove unused decompression buffer [#470](https://github.com/klauspost/compress/pull/470) + * zstd: Fix logically dead code by @ningmingxiao [#472](https://github.com/klauspost/compress/pull/472) + * flate: Improve level 7-9 [#471](https://github.com/klauspost/compress/pull/471) [#473](https://github.com/klauspost/compress/pull/473) + * zstd: Add noasm tag for xxhash [#475](https://github.com/klauspost/compress/pull/475) + +* Jan 11, 2022 (v1.14.1) + * s2: Add stream index in [#462](https://github.com/klauspost/compress/pull/462) + * flate: Speed and efficiency improvements in [#439](https://github.com/klauspost/compress/pull/439) [#461](https://github.com/klauspost/compress/pull/461) [#455](https://github.com/klauspost/compress/pull/455) [#452](https://github.com/klauspost/compress/pull/452) [#458](https://github.com/klauspost/compress/pull/458) + * zstd: Performance improvement in [#420]( https://github.com/klauspost/compress/pull/420) [#456](https://github.com/klauspost/compress/pull/456) [#437](https://github.com/klauspost/compress/pull/437) [#467](https://github.com/klauspost/compress/pull/467) [#468](https://github.com/klauspost/compress/pull/468) + * zstd: add arm64 xxhash assembly in [#464](https://github.com/klauspost/compress/pull/464) + * Add garbled for binaries for s2 in [#445](https://github.com/klauspost/compress/pull/445) +
+ +
+ See changes to v1.13.x + * Aug 30, 2021 (v1.13.5) * gz/zlib/flate: Alias stdlib errors [#425](https://github.com/klauspost/compress/pull/425) * s2: Add block support to commandline tools [#413](https://github.com/klauspost/compress/pull/413) @@ -45,7 +237,12 @@ This package provides various compression algorithms. * Added [gzhttp](https://github.com/klauspost/compress/tree/master/gzhttp#gzip-handler) which allows wrapping HTTP servers and clients with GZIP compressors. * zstd: Detect short invalid signatures [#382](https://github.com/klauspost/compress/pull/382) * zstd: Spawn decoder goroutine only if needed. [#380](https://github.com/klauspost/compress/pull/380) +
+ +
+ See changes to v1.12.x + * May 25, 2021 (v1.12.3) * deflate: Better/faster Huffman encoding [#374](https://github.com/klauspost/compress/pull/374) * deflate: Allocate less for history. [#375](https://github.com/klauspost/compress/pull/375) @@ -67,9 +264,10 @@ This package provides various compression algorithms. * s2c/s2d/s2sx: Always truncate when writing files [#352](https://github.com/klauspost/compress/pull/352) * zstd: Reduce memory usage further when using [WithLowerEncoderMem](https://pkg.go.dev/github.com/klauspost/compress/zstd#WithLowerEncoderMem) [#346](https://github.com/klauspost/compress/pull/346) * s2: Fix potential problem with amd64 assembly and profilers [#349](https://github.com/klauspost/compress/pull/349) +
- See changes prior to v1.12.1 + See changes to v1.11.x * Mar 26, 2021 (v1.11.13) * zstd: Big speedup on small dictionary encodes [#344](https://github.com/klauspost/compress/pull/344) [#345](https://github.com/klauspost/compress/pull/345) @@ -128,7 +326,7 @@ This package provides various compression algorithms.
- See changes prior to v1.11.0 + See changes to v1.10.x * July 8, 2020 (v1.10.11) * zstd: Fix extra block when compressing with ReadFrom. [#278](https://github.com/klauspost/compress/pull/278) @@ -290,11 +488,6 @@ This package provides various compression algorithms. # deflate usage -* [High Throughput Benchmark](http://blog.klauspost.com/go-gzipdeflate-benchmarks/). -* [Small Payload/Webserver Benchmarks](http://blog.klauspost.com/gzip-performance-for-go-webservers/). -* [Linear Time Compression](http://blog.klauspost.com/constant-time-gzipzip-compression/). -* [Re-balancing Deflate Compression Levels](https://blog.klauspost.com/rebalancing-deflate-compression-levels/) - The packages are drop-in replacements for standard libraries. Simply replace the import path to use them: | old import | new import | Documentation @@ -316,6 +509,8 @@ Memory usage is typically 1MB for a Writer. stdlib is in the same range. If you expect to have a lot of concurrently allocated Writers consider using the stateless compress described below. +For compression performance, see: [this spreadsheet](https://docs.google.com/spreadsheets/d/1nuNE2nPfuINCZJRMt6wFWhKpToF95I47XjSsc-1rbPQ/edit?usp=sharing). + # Stateless compression This package offers stateless compression as a special option for gzip/deflate. @@ -432,6 +627,15 @@ For more information see my blog post on [Fast Linear Time Compression](http://b This is implemented on Go 1.7 as "Huffman Only" mode, though not exposed for gzip. +# Other packages + +Here are other packages of good quality and pure Go (no cgo wrappers or autoconverted code): + +* [github.com/pierrec/lz4](https://github.com/pierrec/lz4) - strong multithreaded LZ4 compression. +* [github.com/cosnicolaou/pbzip2](https://github.com/cosnicolaou/pbzip2) - multithreaded bzip2 decompression. +* [github.com/dsnet/compress](https://github.com/dsnet/compress) - brotli decompression, bzip2 writer. +* [github.com/ronanh/intcomp](https://github.com/ronanh/intcomp) - Integer compression. +* [github.com/spenczar/fpc](https://github.com/spenczar/fpc) - Float compression. # license diff --git a/vendor/github.com/klauspost/compress/SECURITY.md b/vendor/github.com/klauspost/compress/SECURITY.md new file mode 100644 index 0000000..ca6685e --- /dev/null +++ b/vendor/github.com/klauspost/compress/SECURITY.md @@ -0,0 +1,25 @@ +# Security Policy + +## Supported Versions + +Security updates are applied only to the latest release. + +## Vulnerability Definition + +A security vulnerability is a bug that with certain input triggers a crash or an infinite loop. Most calls will have varying execution time and only in rare cases will slow operation be considered a security vulnerability. + +Corrupted output generally is not considered a security vulnerability, unless independent operations are able to affect each other. Note that not all functionality is re-entrant and safe to use concurrently. + +Out-of-memory crashes only applies if the en/decoder uses an abnormal amount of memory, with appropriate options applied, to limit maximum window size, concurrency, etc. However, if you are in doubt you are welcome to file a security issue. + +It is assumed that all callers are trusted, meaning internal data exposed through reflection or inspection of returned data structures is not considered a vulnerability. + +Vulnerabilities resulting from compiler/assembler errors should be reported upstream. Depending on the severity this package may or may not implement a workaround. + +## Reporting a Vulnerability + +If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. + +Please disclose it at [security advisory](https://github.com/klauspost/compress/security/advisories/new). If possible please provide a minimal reproducer. If the issue only applies to a single platform, it would be helpful to provide access to that. + +This project is maintained by a team of volunteers on a reasonable-effort basis. As such, vulnerabilities will be disclosed in a best effort base. diff --git a/vendor/github.com/klauspost/compress/fse/compress.go b/vendor/github.com/klauspost/compress/fse/compress.go index 6f34191..dac97e5 100644 --- a/vendor/github.com/klauspost/compress/fse/compress.go +++ b/vendor/github.com/klauspost/compress/fse/compress.go @@ -146,54 +146,51 @@ func (s *Scratch) compress(src []byte) error { c1.encodeZero(tt[src[ip-2]]) ip -= 2 } + src = src[:ip] // Main compression loop. switch { case !s.zeroBits && s.actualTableLog <= 8: // We can encode 4 symbols without requiring a flush. // We do not need to check if any output is 0 bits. - for ip >= 4 { + for ; len(src) >= 4; src = src[:len(src)-4] { s.bw.flush32() - v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1] + v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] c2.encode(tt[v0]) c1.encode(tt[v1]) c2.encode(tt[v2]) c1.encode(tt[v3]) - ip -= 4 } case !s.zeroBits: // We do not need to check if any output is 0 bits. - for ip >= 4 { + for ; len(src) >= 4; src = src[:len(src)-4] { s.bw.flush32() - v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1] + v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] c2.encode(tt[v0]) c1.encode(tt[v1]) s.bw.flush32() c2.encode(tt[v2]) c1.encode(tt[v3]) - ip -= 4 } case s.actualTableLog <= 8: // We can encode 4 symbols without requiring a flush - for ip >= 4 { + for ; len(src) >= 4; src = src[:len(src)-4] { s.bw.flush32() - v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1] + v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] c2.encodeZero(tt[v0]) c1.encodeZero(tt[v1]) c2.encodeZero(tt[v2]) c1.encodeZero(tt[v3]) - ip -= 4 } default: - for ip >= 4 { + for ; len(src) >= 4; src = src[:len(src)-4] { s.bw.flush32() - v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1] + v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] c2.encodeZero(tt[v0]) c1.encodeZero(tt[v1]) s.bw.flush32() c2.encodeZero(tt[v2]) c1.encodeZero(tt[v3]) - ip -= 4 } } @@ -459,15 +456,17 @@ func (s *Scratch) countSimple(in []byte) (max int) { for _, v := range in { s.count[v]++ } - m := uint32(0) + m, symlen := uint32(0), s.symbolLen for i, v := range s.count[:] { + if v == 0 { + continue + } if v > m { m = v } - if v > 0 { - s.symbolLen = uint16(i) + 1 - } + symlen = uint16(i) + 1 } + s.symbolLen = symlen return int(m) } diff --git a/vendor/github.com/klauspost/compress/fse/decompress.go b/vendor/github.com/klauspost/compress/fse/decompress.go index 926f5f1..cc05d0f 100644 --- a/vendor/github.com/klauspost/compress/fse/decompress.go +++ b/vendor/github.com/klauspost/compress/fse/decompress.go @@ -260,7 +260,9 @@ func (s *Scratch) buildDtable() error { // If the buffer is over-read an error is returned. func (s *Scratch) decompress() error { br := &s.bits - br.init(s.br.unread()) + if err := br.init(s.br.unread()); err != nil { + return err + } var s1, s2 decoder // Initialize and decode first state and symbol. diff --git a/vendor/github.com/klauspost/compress/go.mod b/vendor/github.com/klauspost/compress/go.mod index 5aa64a4..44ba820 100644 --- a/vendor/github.com/klauspost/compress/go.mod +++ b/vendor/github.com/klauspost/compress/go.mod @@ -1,3 +1,10 @@ module github.com/klauspost/compress -go 1.15 +go 1.18 + +retract ( + // https://github.com/klauspost/compress/pull/503 + v1.14.3 + v1.14.2 + v1.14.1 +) diff --git a/vendor/github.com/klauspost/compress/huff0/bitreader.go b/vendor/github.com/klauspost/compress/huff0/bitreader.go index a4979e8..e36d974 100644 --- a/vendor/github.com/klauspost/compress/huff0/bitreader.go +++ b/vendor/github.com/klauspost/compress/huff0/bitreader.go @@ -8,115 +8,10 @@ package huff0 import ( "encoding/binary" "errors" + "fmt" "io" ) -// bitReader reads a bitstream in reverse. -// The last set bit indicates the start of the stream and is used -// for aligning the input. -type bitReader struct { - in []byte - off uint // next byte to read is at in[off - 1] - value uint64 - bitsRead uint8 -} - -// init initializes and resets the bit reader. -func (b *bitReader) init(in []byte) error { - if len(in) < 1 { - return errors.New("corrupt stream: too short") - } - b.in = in - b.off = uint(len(in)) - // The highest bit of the last byte indicates where to start - v := in[len(in)-1] - if v == 0 { - return errors.New("corrupt stream, did not find end of stream") - } - b.bitsRead = 64 - b.value = 0 - if len(in) >= 8 { - b.fillFastStart() - } else { - b.fill() - b.fill() - } - b.bitsRead += 8 - uint8(highBit32(uint32(v))) - return nil -} - -// peekBitsFast requires that at least one bit is requested every time. -// There are no checks if the buffer is filled. -func (b *bitReader) peekBitsFast(n uint8) uint16 { - const regMask = 64 - 1 - v := uint16((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask)) - return v -} - -// fillFast() will make sure at least 32 bits are available. -// There must be at least 4 bytes available. -func (b *bitReader) fillFast() { - if b.bitsRead < 32 { - return - } - - // 2 bounds checks. - v := b.in[b.off-4 : b.off] - v = v[:4] - low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) - b.value = (b.value << 32) | uint64(low) - b.bitsRead -= 32 - b.off -= 4 -} - -func (b *bitReader) advance(n uint8) { - b.bitsRead += n -} - -// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read. -func (b *bitReader) fillFastStart() { - // Do single re-slice to avoid bounds checks. - b.value = binary.LittleEndian.Uint64(b.in[b.off-8:]) - b.bitsRead = 0 - b.off -= 8 -} - -// fill() will make sure at least 32 bits are available. -func (b *bitReader) fill() { - if b.bitsRead < 32 { - return - } - if b.off > 4 { - v := b.in[b.off-4:] - v = v[:4] - low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) - b.value = (b.value << 32) | uint64(low) - b.bitsRead -= 32 - b.off -= 4 - return - } - for b.off > 0 { - b.value = (b.value << 8) | uint64(b.in[b.off-1]) - b.bitsRead -= 8 - b.off-- - } -} - -// finished returns true if all bits have been read from the bit stream. -func (b *bitReader) finished() bool { - return b.off == 0 && b.bitsRead >= 64 -} - -// close the bitstream and returns an error if out-of-buffer reads occurred. -func (b *bitReader) close() error { - // Release reference. - b.in = nil - if b.bitsRead > 64 { - return io.ErrUnexpectedEOF - } - return nil -} - // bitReader reads a bitstream in reverse. // The last set bit indicates the start of the stream and is used // for aligning the input. @@ -172,7 +67,6 @@ func (b *bitReaderBytes) fillFast() { // 2 bounds checks. v := b.in[b.off-4 : b.off] - v = v[:4] low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) b.value |= uint64(low) << (b.bitsRead - 32) b.bitsRead -= 32 @@ -193,8 +87,7 @@ func (b *bitReaderBytes) fill() { return } if b.off > 4 { - v := b.in[b.off-4:] - v = v[:4] + v := b.in[b.off-4 : b.off] low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) b.value |= uint64(low) << (b.bitsRead - 32) b.bitsRead -= 32 @@ -213,10 +106,17 @@ func (b *bitReaderBytes) finished() bool { return b.off == 0 && b.bitsRead >= 64 } +func (b *bitReaderBytes) remaining() uint { + return b.off*8 + uint(64-b.bitsRead) +} + // close the bitstream and returns an error if out-of-buffer reads occurred. func (b *bitReaderBytes) close() error { // Release reference. b.in = nil + if b.remaining() > 0 { + return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining()) + } if b.bitsRead > 64 { return io.ErrUnexpectedEOF } @@ -277,7 +177,6 @@ func (b *bitReaderShifted) fillFast() { // 2 bounds checks. v := b.in[b.off-4 : b.off] - v = v[:4] low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) b.value |= uint64(low) << ((b.bitsRead - 32) & 63) b.bitsRead -= 32 @@ -298,8 +197,7 @@ func (b *bitReaderShifted) fill() { return } if b.off > 4 { - v := b.in[b.off-4:] - v = v[:4] + v := b.in[b.off-4 : b.off] low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) b.value |= uint64(low) << ((b.bitsRead - 32) & 63) b.bitsRead -= 32 @@ -313,15 +211,17 @@ func (b *bitReaderShifted) fill() { } } -// finished returns true if all bits have been read from the bit stream. -func (b *bitReaderShifted) finished() bool { - return b.off == 0 && b.bitsRead >= 64 +func (b *bitReaderShifted) remaining() uint { + return b.off*8 + uint(64-b.bitsRead) } // close the bitstream and returns an error if out-of-buffer reads occurred. func (b *bitReaderShifted) close() error { // Release reference. b.in = nil + if b.remaining() > 0 { + return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining()) + } if b.bitsRead > 64 { return io.ErrUnexpectedEOF } diff --git a/vendor/github.com/klauspost/compress/huff0/bitwriter.go b/vendor/github.com/klauspost/compress/huff0/bitwriter.go index 6bce4e8..b4d7164 100644 --- a/vendor/github.com/klauspost/compress/huff0/bitwriter.go +++ b/vendor/github.com/klauspost/compress/huff0/bitwriter.go @@ -5,8 +5,6 @@ package huff0 -import "fmt" - // bitWriter will write bits. // First bit will be LSB of the first byte of output. type bitWriter struct { @@ -15,22 +13,6 @@ type bitWriter struct { out []byte } -// bitMask16 is bitmasks. Has extra to avoid bounds check. -var bitMask16 = [32]uint16{ - 0, 1, 3, 7, 0xF, 0x1F, - 0x3F, 0x7F, 0xFF, 0x1FF, 0x3FF, 0x7FF, - 0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF, 0xFFFF, - 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, - 0xFFFF, 0xFFFF} /* up to 16 bits */ - -// addBits16NC will add up to 16 bits. -// It will not check if there is space for them, -// so the caller must ensure that it has flushed recently. -func (b *bitWriter) addBits16NC(value uint16, bits uint8) { - b.bitContainer |= uint64(value&bitMask16[bits&31]) << (b.nBits & 63) - b.nBits += bits -} - // addBits16Clean will add up to 16 bits. value may not contain more set bits than indicated. // It will not check if there is space for them, so the caller must ensure that it has flushed recently. func (b *bitWriter) addBits16Clean(value uint16, bits uint8) { @@ -70,102 +52,20 @@ func (b *bitWriter) encTwoSymbols(ct cTable, av, bv byte) { b.nBits += encA.nBits + encB.nBits } -// addBits16ZeroNC will add up to 16 bits. +// encFourSymbols adds up to 32 bits from four symbols. // It will not check if there is space for them, -// so the caller must ensure that it has flushed recently. -// This is fastest if bits can be zero. -func (b *bitWriter) addBits16ZeroNC(value uint16, bits uint8) { - if bits == 0 { - return - } - value <<= (16 - bits) & 15 - value >>= (16 - bits) & 15 - b.bitContainer |= uint64(value) << (b.nBits & 63) - b.nBits += bits -} - -// flush will flush all pending full bytes. -// There will be at least 56 bits available for writing when this has been called. -// Using flush32 is faster, but leaves less space for writing. -func (b *bitWriter) flush() { - v := b.nBits >> 3 - switch v { - case 0: - return - case 1: - b.out = append(b.out, - byte(b.bitContainer), - ) - b.bitContainer >>= 1 << 3 - case 2: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - ) - b.bitContainer >>= 2 << 3 - case 3: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - ) - b.bitContainer >>= 3 << 3 - case 4: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - ) - b.bitContainer >>= 4 << 3 - case 5: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - ) - b.bitContainer >>= 5 << 3 - case 6: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - byte(b.bitContainer>>40), - ) - b.bitContainer >>= 6 << 3 - case 7: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - byte(b.bitContainer>>40), - byte(b.bitContainer>>48), - ) - b.bitContainer >>= 7 << 3 - case 8: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - byte(b.bitContainer>>40), - byte(b.bitContainer>>48), - byte(b.bitContainer>>56), - ) - b.bitContainer = 0 - b.nBits = 0 - return - default: - panic(fmt.Errorf("bits (%d) > 64", b.nBits)) - } - b.nBits &= 7 +// so the caller must ensure that b has been flushed recently. +func (b *bitWriter) encFourSymbols(encA, encB, encC, encD cTableEntry) { + bitsA := encA.nBits + bitsB := bitsA + encB.nBits + bitsC := bitsB + encC.nBits + bitsD := bitsC + encD.nBits + combined := uint64(encA.val) | + (uint64(encB.val) << (bitsA & 63)) | + (uint64(encC.val) << (bitsB & 63)) | + (uint64(encD.val) << (bitsC & 63)) + b.bitContainer |= combined << (b.nBits & 63) + b.nBits += bitsD } // flush32 will flush out, so there are at least 32 bits available for writing. @@ -201,10 +101,3 @@ func (b *bitWriter) close() error { b.flushAlign() return nil } - -// reset and continue writing by appending to out. -func (b *bitWriter) reset(out []byte) { - b.bitContainer = 0 - b.nBits = 0 - b.out = out -} diff --git a/vendor/github.com/klauspost/compress/huff0/bytereader.go b/vendor/github.com/klauspost/compress/huff0/bytereader.go index 50bcdf6..4dcab8d 100644 --- a/vendor/github.com/klauspost/compress/huff0/bytereader.go +++ b/vendor/github.com/klauspost/compress/huff0/bytereader.go @@ -20,11 +20,6 @@ func (b *byteReader) init(in []byte) { b.off = 0 } -// advance the stream b n bytes. -func (b *byteReader) advance(n uint) { - b.off += int(n) -} - // Int32 returns a little endian int32 starting at current offset. func (b byteReader) Int32() int32 { v3 := int32(b.b[b.off+3]) @@ -43,11 +38,6 @@ func (b byteReader) Uint32() uint32 { return (v3 << 24) | (v2 << 16) | (v1 << 8) | v0 } -// unread returns the unread portion of the input. -func (b byteReader) unread() []byte { - return b.b[b.off:] -} - // remain will return the number of bytes remaining. func (b byteReader) remain() int { return len(b.b) - b.off diff --git a/vendor/github.com/klauspost/compress/huff0/compress.go b/vendor/github.com/klauspost/compress/huff0/compress.go index 8323dc0..4ee4fa1 100644 --- a/vendor/github.com/klauspost/compress/huff0/compress.go +++ b/vendor/github.com/klauspost/compress/huff0/compress.go @@ -2,6 +2,7 @@ package huff0 import ( "fmt" + "math" "runtime" "sync" ) @@ -247,8 +248,7 @@ func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) { tmp := src[n : n+4] // tmp should be len 4 bw.flush32() - bw.encTwoSymbols(cTable, tmp[3], tmp[2]) - bw.encTwoSymbols(cTable, tmp[1], tmp[0]) + bw.encFourSymbols(cTable[tmp[3]], cTable[tmp[2]], cTable[tmp[1]], cTable[tmp[0]]) } } else { for ; n >= 0; n -= 4 { @@ -289,6 +289,10 @@ func (s *Scratch) compress4X(src []byte) ([]byte, error) { if err != nil { return nil, err } + if len(s.Out)-idx > math.MaxUint16 { + // We cannot store the size in the jump table + return nil, ErrIncompressible + } // Write compressed length as little endian before block. if i < 3 { // Last length is not written. @@ -332,6 +336,10 @@ func (s *Scratch) compress4Xp(src []byte) ([]byte, error) { return nil, errs[i] } o := s.tmpOut[i] + if len(o) > math.MaxUint16 { + // We cannot store the size in the jump table + return nil, ErrIncompressible + } // Write compressed length as little endian before block. if i < 3 { // Last length is not written. @@ -356,29 +364,29 @@ func (s *Scratch) countSimple(in []byte) (max int, reuse bool) { m := uint32(0) if len(s.prevTable) > 0 { for i, v := range s.count[:] { + if v == 0 { + continue + } if v > m { m = v } - if v > 0 { - s.symbolLen = uint16(i) + 1 - if i >= len(s.prevTable) { - reuse = false - } else { - if s.prevTable[i].nBits == 0 { - reuse = false - } - } + s.symbolLen = uint16(i) + 1 + if i >= len(s.prevTable) { + reuse = false + } else if s.prevTable[i].nBits == 0 { + reuse = false } } return int(m), reuse } for i, v := range s.count[:] { + if v == 0 { + continue + } if v > m { m = v } - if v > 0 { - s.symbolLen = uint16(i) + 1 - } + s.symbolLen = uint16(i) + 1 } return int(m), false } @@ -395,6 +403,7 @@ func (s *Scratch) canUseTable(c cTable) bool { return true } +//lint:ignore U1000 used for debugging func (s *Scratch) validateTable(c cTable) bool { if len(c) < int(s.symbolLen) { return false @@ -474,34 +483,35 @@ func (s *Scratch) buildCTable() error { // Different from reference implementation. huffNode0 := s.nodes[0 : huffNodesLen+1] - for huffNode[nonNullRank].count == 0 { + for huffNode[nonNullRank].count() == 0 { nonNullRank-- } lowS := int16(nonNullRank) nodeRoot := nodeNb + lowS - 1 lowN := nodeNb - huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count - huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb) + huffNode[nodeNb].setCount(huffNode[lowS].count() + huffNode[lowS-1].count()) + huffNode[lowS].setParent(nodeNb) + huffNode[lowS-1].setParent(nodeNb) nodeNb++ lowS -= 2 for n := nodeNb; n <= nodeRoot; n++ { - huffNode[n].count = 1 << 30 + huffNode[n].setCount(1 << 30) } // fake entry, strong barrier - huffNode0[0].count = 1 << 31 + huffNode0[0].setCount(1 << 31) // create parents for nodeNb <= nodeRoot { var n1, n2 int16 - if huffNode0[lowS+1].count < huffNode0[lowN+1].count { + if huffNode0[lowS+1].count() < huffNode0[lowN+1].count() { n1 = lowS lowS-- } else { n1 = lowN lowN++ } - if huffNode0[lowS+1].count < huffNode0[lowN+1].count { + if huffNode0[lowS+1].count() < huffNode0[lowN+1].count() { n2 = lowS lowS-- } else { @@ -509,18 +519,19 @@ func (s *Scratch) buildCTable() error { lowN++ } - huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count - huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb) + huffNode[nodeNb].setCount(huffNode0[n1+1].count() + huffNode0[n2+1].count()) + huffNode0[n1+1].setParent(nodeNb) + huffNode0[n2+1].setParent(nodeNb) nodeNb++ } // distribute weights (unlimited tree height) - huffNode[nodeRoot].nbBits = 0 + huffNode[nodeRoot].setNbBits(0) for n := nodeRoot - 1; n >= startNode; n-- { - huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1 + huffNode[n].setNbBits(huffNode[huffNode[n].parent()].nbBits() + 1) } for n := uint16(0); n <= nonNullRank; n++ { - huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1 + huffNode[n].setNbBits(huffNode[huffNode[n].parent()].nbBits() + 1) } s.actualTableLog = s.setMaxHeight(int(nonNullRank)) maxNbBits := s.actualTableLog @@ -532,7 +543,7 @@ func (s *Scratch) buildCTable() error { var nbPerRank [tableLogMax + 1]uint16 var valPerRank [16]uint16 for _, v := range huffNode[:nonNullRank+1] { - nbPerRank[v.nbBits]++ + nbPerRank[v.nbBits()]++ } // determine stating value per rank { @@ -547,7 +558,7 @@ func (s *Scratch) buildCTable() error { // push nbBits per symbol, symbol order for _, v := range huffNode[:nonNullRank+1] { - s.cTable[v.symbol].nBits = v.nbBits + s.cTable[v.symbol()].nBits = v.nbBits() } // assign value within rank, symbol order @@ -593,12 +604,12 @@ func (s *Scratch) huffSort() { pos := rank[r].current rank[r].current++ prev := nodes[(pos-1)&huffNodesMask] - for pos > rank[r].base && c > prev.count { + for pos > rank[r].base && c > prev.count() { nodes[pos&huffNodesMask] = prev pos-- prev = nodes[(pos-1)&huffNodesMask] } - nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)} + nodes[pos&huffNodesMask] = makeNodeElt(c, byte(n)) } } @@ -607,7 +618,7 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 { huffNode := s.nodes[1 : huffNodesLen+1] //huffNode = huffNode[: huffNodesLen] - largestBits := huffNode[lastNonNull].nbBits + largestBits := huffNode[lastNonNull].nbBits() // early exit : no elt > maxNbBits if largestBits <= maxNbBits { @@ -617,14 +628,14 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 { baseCost := int(1) << (largestBits - maxNbBits) n := uint32(lastNonNull) - for huffNode[n].nbBits > maxNbBits { - totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits)) - huffNode[n].nbBits = maxNbBits + for huffNode[n].nbBits() > maxNbBits { + totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits())) + huffNode[n].setNbBits(maxNbBits) n-- } // n stops at huffNode[n].nbBits <= maxNbBits - for huffNode[n].nbBits == maxNbBits { + for huffNode[n].nbBits() == maxNbBits { n-- } // n end at index of smallest symbol using < maxNbBits @@ -645,10 +656,10 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 { { currentNbBits := maxNbBits for pos := int(n); pos >= 0; pos-- { - if huffNode[pos].nbBits >= currentNbBits { + if huffNode[pos].nbBits() >= currentNbBits { continue } - currentNbBits = huffNode[pos].nbBits // < maxNbBits + currentNbBits = huffNode[pos].nbBits() // < maxNbBits rankLast[maxNbBits-currentNbBits] = uint32(pos) } } @@ -665,8 +676,8 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 { if lowPos == noSymbol { break } - highTotal := huffNode[highPos].count - lowTotal := 2 * huffNode[lowPos].count + highTotal := huffNode[highPos].count() + lowTotal := 2 * huffNode[lowPos].count() if highTotal <= lowTotal { break } @@ -682,13 +693,14 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 { // this rank is no longer empty rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease] } - huffNode[rankLast[nBitsToDecrease]].nbBits++ + huffNode[rankLast[nBitsToDecrease]].setNbBits(1 + + huffNode[rankLast[nBitsToDecrease]].nbBits()) if rankLast[nBitsToDecrease] == 0 { /* special case, reached largest symbol */ rankLast[nBitsToDecrease] = noSymbol } else { rankLast[nBitsToDecrease]-- - if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease { + if huffNode[rankLast[nBitsToDecrease]].nbBits() != maxNbBits-nBitsToDecrease { rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */ } } @@ -696,15 +708,15 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 { for totalCost < 0 { /* Sometimes, cost correction overshoot */ if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */ - for huffNode[n].nbBits == maxNbBits { + for huffNode[n].nbBits() == maxNbBits { n-- } - huffNode[n+1].nbBits-- + huffNode[n+1].setNbBits(huffNode[n+1].nbBits() - 1) rankLast[1] = n + 1 totalCost++ continue } - huffNode[rankLast[1]+1].nbBits-- + huffNode[rankLast[1]+1].setNbBits(huffNode[rankLast[1]+1].nbBits() - 1) rankLast[1]++ totalCost++ } @@ -712,9 +724,26 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 { return maxNbBits } -type nodeElt struct { - count uint32 - parent uint16 - symbol byte - nbBits uint8 +// A nodeElt is the fields +// +// count uint32 +// parent uint16 +// symbol byte +// nbBits uint8 +// +// in some order, all squashed into an integer so that the compiler +// always loads and stores entire nodeElts instead of separate fields. +type nodeElt uint64 + +func makeNodeElt(count uint32, symbol byte) nodeElt { + return nodeElt(count) | nodeElt(symbol)<<48 } + +func (e *nodeElt) count() uint32 { return uint32(*e) } +func (e *nodeElt) parent() uint16 { return uint16(*e >> 32) } +func (e *nodeElt) symbol() byte { return byte(*e >> 48) } +func (e *nodeElt) nbBits() uint8 { return uint8(*e >> 56) } + +func (e *nodeElt) setCount(c uint32) { *e = (*e)&0xffffffff00000000 | nodeElt(c) } +func (e *nodeElt) setParent(p int16) { *e = (*e)&0xffff0000ffffffff | nodeElt(uint16(p))<<32 } +func (e *nodeElt) setNbBits(n uint8) { *e = (*e)&0x00ffffffffffffff | nodeElt(n)<<56 } diff --git a/vendor/github.com/klauspost/compress/huff0/decompress.go b/vendor/github.com/klauspost/compress/huff0/decompress.go index 9b7cc8e..54bd08b 100644 --- a/vendor/github.com/klauspost/compress/huff0/decompress.go +++ b/vendor/github.com/klauspost/compress/huff0/decompress.go @@ -4,13 +4,13 @@ import ( "errors" "fmt" "io" + "sync" "github.com/klauspost/compress/fse" ) type dTable struct { single []dEntrySingle - double []dEntryDouble } // single-symbols decoding @@ -18,13 +18,6 @@ type dEntrySingle struct { entry uint16 } -// double-symbols decoding -type dEntryDouble struct { - seq uint16 - nBits uint8 - len uint8 -} - // Uses special code for all tables that are < 8 bits. const use8BitTables = true @@ -34,7 +27,7 @@ const use8BitTables = true // If no Scratch is provided a new one is allocated. // The returned Scratch can be used for encoding or decoding input using this table. func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { - s, err = s.prepare(in) + s, err = s.prepare(nil) if err != nil { return s, nil, err } @@ -68,7 +61,7 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { b, err := fse.Decompress(in[:iSize], s.fse) s.fse.Out = nil if err != nil { - return s, nil, err + return s, nil, fmt.Errorf("fse decompress returned: %w", err) } if len(b) > 255 { return s, nil, errors.New("corrupt input: output table too large") @@ -216,6 +209,7 @@ func (s *Scratch) Decoder() *Decoder { return &Decoder{ dt: s.dt, actualTableLog: s.actualTableLog, + bufs: &s.decPool, } } @@ -223,103 +217,15 @@ func (s *Scratch) Decoder() *Decoder { type Decoder struct { dt dTable actualTableLog uint8 + bufs *sync.Pool } -// Decompress1X will decompress a 1X encoded stream. -// The cap of the output buffer will be the maximum decompressed size. -// The length of the supplied input must match the end of a block exactly. -func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) { - if len(d.dt.single) == 0 { - return nil, errors.New("no table loaded") - } - if use8BitTables && d.actualTableLog <= 8 { - return d.decompress1X8Bit(dst, src) - } - var br bitReaderShifted - err := br.init(src) - if err != nil { - return dst, err - } - maxDecodedSize := cap(dst) - dst = dst[:0] - - // Avoid bounds check by always having full sized table. - const tlSize = 1 << tableLogMax - const tlMask = tlSize - 1 - dt := d.dt.single[:tlSize] - - // Use temp table to avoid bound checks/append penalty. - var buf [256]byte - var off uint8 - - for br.off >= 8 { - br.fillFast() - v := dt[br.peekBitsFast(d.actualTableLog)&tlMask] - br.advance(uint8(v.entry)) - buf[off+0] = uint8(v.entry >> 8) - - v = dt[br.peekBitsFast(d.actualTableLog)&tlMask] - br.advance(uint8(v.entry)) - buf[off+1] = uint8(v.entry >> 8) - - // Refill - br.fillFast() - - v = dt[br.peekBitsFast(d.actualTableLog)&tlMask] - br.advance(uint8(v.entry)) - buf[off+2] = uint8(v.entry >> 8) - - v = dt[br.peekBitsFast(d.actualTableLog)&tlMask] - br.advance(uint8(v.entry)) - buf[off+3] = uint8(v.entry >> 8) - - off += 4 - if off == 0 { - if len(dst)+256 > maxDecodedSize { - br.close() - return nil, ErrMaxDecodedSizeExceeded - } - dst = append(dst, buf[:]...) - } - } - - if len(dst)+int(off) > maxDecodedSize { - br.close() - return nil, ErrMaxDecodedSizeExceeded - } - dst = append(dst, buf[:off]...) - - // br < 8, so uint8 is fine - bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead - for bitsLeft > 0 { - br.fill() - if false && br.bitsRead >= 32 { - if br.off >= 4 { - v := br.in[br.off-4:] - v = v[:4] - low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) - br.value = (br.value << 32) | uint64(low) - br.bitsRead -= 32 - br.off -= 4 - } else { - for br.off > 0 { - br.value = (br.value << 8) | uint64(br.in[br.off-1]) - br.bitsRead -= 8 - br.off-- - } - } - } - if len(dst) >= maxDecodedSize { - br.close() - return nil, ErrMaxDecodedSizeExceeded - } - v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask] - nBits := uint8(v.entry) - br.advance(nBits) - bitsLeft -= nBits - dst = append(dst, uint8(v.entry>>8)) +func (d *Decoder) buffer() *[4][256]byte { + buf, ok := d.bufs.Get().(*[4][256]byte) + if ok { + return buf } - return dst, br.close() + return &[4][256]byte{} } // decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8. @@ -341,12 +247,13 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { dt := d.dt.single[:256] // Use temp table to avoid bound checks/append penalty. - var buf [256]byte + bufs := d.buffer() + buf := &bufs[0] var off uint8 switch d.actualTableLog { case 8: - const shift = 8 - 8 + const shift = 0 for br.off >= 4 { br.fillFast() v := dt[uint8(br.value>>(56+shift))] @@ -369,6 +276,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { if off == 0 { if len(dst)+256 > maxDecodedSize { br.close() + d.bufs.Put(bufs) return nil, ErrMaxDecodedSizeExceeded } dst = append(dst, buf[:]...) @@ -398,6 +306,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { if off == 0 { if len(dst)+256 > maxDecodedSize { br.close() + d.bufs.Put(bufs) return nil, ErrMaxDecodedSizeExceeded } dst = append(dst, buf[:]...) @@ -426,6 +335,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -455,6 +365,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -484,6 +395,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -513,6 +425,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -542,6 +455,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -571,6 +485,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -578,10 +493,12 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { } } default: + d.bufs.Put(bufs) return nil, fmt.Errorf("invalid tablelog: %d", d.actualTableLog) } if len(dst)+int(off) > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -601,6 +518,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { } if len(dst) >= maxDecodedSize { br.close() + d.bufs.Put(bufs) return nil, ErrMaxDecodedSizeExceeded } v := dt[br.peekByteFast()>>shift] @@ -609,6 +527,7 @@ func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) { bitsLeft -= int8(nBits) dst = append(dst, uint8(v.entry>>8)) } + d.bufs.Put(bufs) return dst, br.close() } @@ -628,7 +547,8 @@ func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) { dt := d.dt.single[:256] // Use temp table to avoid bound checks/append penalty. - var buf [256]byte + bufs := d.buffer() + buf := &bufs[0] var off uint8 const shift = 56 @@ -655,6 +575,7 @@ func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) { off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -663,6 +584,7 @@ func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) { } if len(dst)+int(off) > maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -679,6 +601,7 @@ func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) { } } if len(dst) >= maxDecodedSize { + d.bufs.Put(bufs) br.close() return nil, ErrMaxDecodedSizeExceeded } @@ -688,199 +611,10 @@ func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) { bitsLeft -= int8(nBits) dst = append(dst, uint8(v.entry>>8)) } + d.bufs.Put(bufs) return dst, br.close() } -// Decompress4X will decompress a 4X encoded stream. -// The length of the supplied input must match the end of a block exactly. -// The *capacity* of the dst slice must match the destination size of -// the uncompressed data exactly. -func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { - if len(d.dt.single) == 0 { - return nil, errors.New("no table loaded") - } - if len(src) < 6+(4*1) { - return nil, errors.New("input too small") - } - if use8BitTables && d.actualTableLog <= 8 { - return d.decompress4X8bit(dst, src) - } - - var br [4]bitReaderShifted - start := 6 - for i := 0; i < 3; i++ { - length := int(src[i*2]) | (int(src[i*2+1]) << 8) - if start+length >= len(src) { - return nil, errors.New("truncated input (or invalid offset)") - } - err := br[i].init(src[start : start+length]) - if err != nil { - return nil, err - } - start += length - } - err := br[3].init(src[start:]) - if err != nil { - return nil, err - } - - // destination, offset to match first output - dstSize := cap(dst) - dst = dst[:dstSize] - out := dst - dstEvery := (dstSize + 3) / 4 - - const tlSize = 1 << tableLogMax - const tlMask = tlSize - 1 - single := d.dt.single[:tlSize] - - // Use temp table to avoid bound checks/append penalty. - var buf [256]byte - var off uint8 - var decoded int - - // Decode 2 values from each decoder/loop. - const bufoff = 256 / 4 - for { - if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 { - break - } - - { - const stream = 0 - const stream2 = 1 - br[stream].fillFast() - br[stream2].fillFast() - - val := br[stream].peekBitsFast(d.actualTableLog) - v := single[val&tlMask] - br[stream].advance(uint8(v.entry)) - buf[off+bufoff*stream] = uint8(v.entry >> 8) - - val2 := br[stream2].peekBitsFast(d.actualTableLog) - v2 := single[val2&tlMask] - br[stream2].advance(uint8(v2.entry)) - buf[off+bufoff*stream2] = uint8(v2.entry >> 8) - - val = br[stream].peekBitsFast(d.actualTableLog) - v = single[val&tlMask] - br[stream].advance(uint8(v.entry)) - buf[off+bufoff*stream+1] = uint8(v.entry >> 8) - - val2 = br[stream2].peekBitsFast(d.actualTableLog) - v2 = single[val2&tlMask] - br[stream2].advance(uint8(v2.entry)) - buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8) - } - - { - const stream = 2 - const stream2 = 3 - br[stream].fillFast() - br[stream2].fillFast() - - val := br[stream].peekBitsFast(d.actualTableLog) - v := single[val&tlMask] - br[stream].advance(uint8(v.entry)) - buf[off+bufoff*stream] = uint8(v.entry >> 8) - - val2 := br[stream2].peekBitsFast(d.actualTableLog) - v2 := single[val2&tlMask] - br[stream2].advance(uint8(v2.entry)) - buf[off+bufoff*stream2] = uint8(v2.entry >> 8) - - val = br[stream].peekBitsFast(d.actualTableLog) - v = single[val&tlMask] - br[stream].advance(uint8(v.entry)) - buf[off+bufoff*stream+1] = uint8(v.entry >> 8) - - val2 = br[stream2].peekBitsFast(d.actualTableLog) - v2 = single[val2&tlMask] - br[stream2].advance(uint8(v2.entry)) - buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8) - } - - off += 2 - - if off == bufoff { - if bufoff > dstEvery { - return nil, errors.New("corruption detected: stream overrun 1") - } - copy(out, buf[:bufoff]) - copy(out[dstEvery:], buf[bufoff:bufoff*2]) - copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3]) - copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4]) - off = 0 - out = out[bufoff:] - decoded += 256 - // There must at least be 3 buffers left. - if len(out) < dstEvery*3 { - return nil, errors.New("corruption detected: stream overrun 2") - } - } - } - if off > 0 { - ioff := int(off) - if len(out) < dstEvery*3+ioff { - return nil, errors.New("corruption detected: stream overrun 3") - } - copy(out, buf[:off]) - copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2]) - copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3]) - copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4]) - decoded += int(off) * 4 - out = out[off:] - } - - // Decode remaining. - for i := range br { - offset := dstEvery * i - br := &br[i] - bitsLeft := br.off*8 + uint(64-br.bitsRead) - for bitsLeft > 0 { - br.fill() - if false && br.bitsRead >= 32 { - if br.off >= 4 { - v := br.in[br.off-4:] - v = v[:4] - low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) - br.value = (br.value << 32) | uint64(low) - br.bitsRead -= 32 - br.off -= 4 - } else { - for br.off > 0 { - br.value = (br.value << 8) | uint64(br.in[br.off-1]) - br.bitsRead -= 8 - br.off-- - } - } - } - // end inline... - if offset >= len(out) { - return nil, errors.New("corruption detected: stream overrun 4") - } - - // Read value and increment offset. - val := br.peekBitsFast(d.actualTableLog) - v := single[val&tlMask].entry - nBits := uint8(v) - br.advance(nBits) - bitsLeft -= uint(nBits) - out[offset] = uint8(v >> 8) - offset++ - } - decoded += offset - dstEvery*i - err = br.close() - if err != nil { - return nil, err - } - } - if dstSize != decoded { - return nil, errors.New("corruption detected: short output block") - } - return dst, nil -} - // Decompress4X will decompress a 4X encoded stream. // The length of the supplied input must match the end of a block exactly. // The *capacity* of the dst slice must match the destination size of @@ -914,18 +648,18 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) { out := dst dstEvery := (dstSize + 3) / 4 - shift := (8 - d.actualTableLog) & 7 + shift := (56 + (8 - d.actualTableLog)) & 63 const tlSize = 1 << 8 single := d.dt.single[:tlSize] // Use temp table to avoid bound checks/append penalty. - var buf [256]byte + buf := d.buffer() var off uint8 var decoded int // Decode 4 values from each decoder/loop. - const bufoff = 256 / 4 + const bufoff = 256 for { if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 { break @@ -935,120 +669,144 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) { // Interleave 2 decodes. const stream = 0 const stream2 = 1 - br[stream].fillFast() - br[stream2].fillFast() - - v := single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 := single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+1] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+1] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+2] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+3] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+3] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) + br1 := &br[stream] + br2 := &br[stream2] + br1.fillFast() + br2.fillFast() + + v := single[uint8(br1.value>>shift)].entry + v2 := single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off] = uint8(v >> 8) + buf[stream2][off] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+1] = uint8(v >> 8) + buf[stream2][off+1] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+2] = uint8(v >> 8) + buf[stream2][off+2] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+3] = uint8(v >> 8) + buf[stream2][off+3] = uint8(v2 >> 8) } { const stream = 2 const stream2 = 3 - br[stream].fillFast() - br[stream2].fillFast() - - v := single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 := single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+1] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+1] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+2] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+3] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+3] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) + br1 := &br[stream] + br2 := &br[stream2] + br1.fillFast() + br2.fillFast() + + v := single[uint8(br1.value>>shift)].entry + v2 := single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off] = uint8(v >> 8) + buf[stream2][off] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+1] = uint8(v >> 8) + buf[stream2][off+1] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+2] = uint8(v >> 8) + buf[stream2][off+2] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+3] = uint8(v >> 8) + buf[stream2][off+3] = uint8(v2 >> 8) } off += 4 - if off == bufoff { + if off == 0 { if bufoff > dstEvery { + d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 1") } - copy(out, buf[:bufoff]) - copy(out[dstEvery:], buf[bufoff:bufoff*2]) - copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3]) - copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4]) - off = 0 - out = out[bufoff:] - decoded += 256 // There must at least be 3 buffers left. - if len(out) < dstEvery*3 { + if len(out)-bufoff < dstEvery*3 { + d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 2") } + //copy(out, buf[0][:]) + //copy(out[dstEvery:], buf[1][:]) + //copy(out[dstEvery*2:], buf[2][:]) + *(*[bufoff]byte)(out) = buf[0] + *(*[bufoff]byte)(out[dstEvery:]) = buf[1] + *(*[bufoff]byte)(out[dstEvery*2:]) = buf[2] + *(*[bufoff]byte)(out[dstEvery*3:]) = buf[3] + out = out[bufoff:] + decoded += bufoff * 4 } } if off > 0 { ioff := int(off) if len(out) < dstEvery*3+ioff { + d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 3") } - copy(out, buf[:off]) - copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2]) - copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3]) - copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4]) + copy(out, buf[0][:off]) + copy(out[dstEvery:], buf[1][:off]) + copy(out[dstEvery*2:], buf[2][:off]) + copy(out[dstEvery*3:], buf[3][:off]) decoded += int(off) * 4 out = out[off:] } // Decode remaining. + // Decode remaining. + remainBytes := dstEvery - (decoded / 4) for i := range br { offset := dstEvery * i + endsAt := offset + remainBytes + if endsAt > len(out) { + endsAt = len(out) + } br := &br[i] - bitsLeft := int(br.off*8) + int(64-br.bitsRead) + bitsLeft := br.remaining() for bitsLeft > 0 { if br.finished() { + d.bufs.Put(buf) return nil, io.ErrUnexpectedEOF } if br.bitsRead >= 56 { @@ -1068,24 +826,31 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) { } } // end inline... - if offset >= len(out) { + if offset >= endsAt { + d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 4") } // Read value and increment offset. - v := single[br.peekByteFast()>>shift].entry + v := single[uint8(br.value>>shift)].entry nBits := uint8(v) br.advance(nBits) - bitsLeft -= int(nBits) + bitsLeft -= uint(nBits) out[offset] = uint8(v >> 8) offset++ } + if offset != endsAt { + d.bufs.Put(buf) + return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) + } decoded += offset - dstEvery*i err = br.close() if err != nil { + d.bufs.Put(buf) return nil, err } } + d.bufs.Put(buf) if dstSize != decoded { return nil, errors.New("corruption detected: short output block") } @@ -1121,18 +886,17 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) { out := dst dstEvery := (dstSize + 3) / 4 - const shift = 0 + const shift = 56 const tlSize = 1 << 8 - const tlMask = tlSize - 1 single := d.dt.single[:tlSize] // Use temp table to avoid bound checks/append penalty. - var buf [256]byte + buf := d.buffer() var off uint8 var decoded int // Decode 4 values from each decoder/loop. - const bufoff = 256 / 4 + const bufoff = 256 for { if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 { break @@ -1142,98 +906,116 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) { // Interleave 2 decodes. const stream = 0 const stream2 = 1 - br[stream].fillFast() - br[stream2].fillFast() - - v := single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 := single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+1] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+1] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+2] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+3] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+3] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) + br1 := &br[stream] + br2 := &br[stream2] + br1.fillFast() + br2.fillFast() + + v := single[uint8(br1.value>>shift)].entry + v2 := single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off] = uint8(v >> 8) + buf[stream2][off] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+1] = uint8(v >> 8) + buf[stream2][off+1] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+2] = uint8(v >> 8) + buf[stream2][off+2] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+3] = uint8(v >> 8) + buf[stream2][off+3] = uint8(v2 >> 8) } { const stream = 2 const stream2 = 3 - br[stream].fillFast() - br[stream2].fillFast() - - v := single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 := single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+1] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+1] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+2] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+2] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) - - v = single[br[stream].peekByteFast()>>shift].entry - buf[off+bufoff*stream+3] = uint8(v >> 8) - br[stream].advance(uint8(v)) - - v2 = single[br[stream2].peekByteFast()>>shift].entry - buf[off+bufoff*stream2+3] = uint8(v2 >> 8) - br[stream2].advance(uint8(v2)) + br1 := &br[stream] + br2 := &br[stream2] + br1.fillFast() + br2.fillFast() + + v := single[uint8(br1.value>>shift)].entry + v2 := single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off] = uint8(v >> 8) + buf[stream2][off] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+1] = uint8(v >> 8) + buf[stream2][off+1] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+2] = uint8(v >> 8) + buf[stream2][off+2] = uint8(v2 >> 8) + + v = single[uint8(br1.value>>shift)].entry + v2 = single[uint8(br2.value>>shift)].entry + br1.bitsRead += uint8(v) + br1.value <<= v & 63 + br2.bitsRead += uint8(v2) + br2.value <<= v2 & 63 + buf[stream][off+3] = uint8(v >> 8) + buf[stream2][off+3] = uint8(v2 >> 8) } off += 4 - if off == bufoff { + if off == 0 { if bufoff > dstEvery { + d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 1") } - copy(out, buf[:bufoff]) - copy(out[dstEvery:], buf[bufoff:bufoff*2]) - copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3]) - copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4]) - off = 0 - out = out[bufoff:] - decoded += 256 // There must at least be 3 buffers left. - if len(out) < dstEvery*3 { + if len(out)-bufoff < dstEvery*3 { + d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 2") } + + //copy(out, buf[0][:]) + //copy(out[dstEvery:], buf[1][:]) + //copy(out[dstEvery*2:], buf[2][:]) + // copy(out[dstEvery*3:], buf[3][:]) + *(*[bufoff]byte)(out) = buf[0] + *(*[bufoff]byte)(out[dstEvery:]) = buf[1] + *(*[bufoff]byte)(out[dstEvery*2:]) = buf[2] + *(*[bufoff]byte)(out[dstEvery*3:]) = buf[3] + out = out[bufoff:] + decoded += bufoff * 4 } } if off > 0 { @@ -1241,21 +1023,27 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) { if len(out) < dstEvery*3+ioff { return nil, errors.New("corruption detected: stream overrun 3") } - copy(out, buf[:off]) - copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2]) - copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3]) - copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4]) + copy(out, buf[0][:off]) + copy(out[dstEvery:], buf[1][:off]) + copy(out[dstEvery*2:], buf[2][:off]) + copy(out[dstEvery*3:], buf[3][:off]) decoded += int(off) * 4 out = out[off:] } // Decode remaining. + remainBytes := dstEvery - (decoded / 4) for i := range br { offset := dstEvery * i + endsAt := offset + remainBytes + if endsAt > len(out) { + endsAt = len(out) + } br := &br[i] - bitsLeft := int(br.off*8) + int(64-br.bitsRead) + bitsLeft := br.remaining() for bitsLeft > 0 { if br.finished() { + d.bufs.Put(buf) return nil, io.ErrUnexpectedEOF } if br.bitsRead >= 56 { @@ -1275,24 +1063,32 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) { } } // end inline... - if offset >= len(out) { + if offset >= endsAt { + d.bufs.Put(buf) return nil, errors.New("corruption detected: stream overrun 4") } // Read value and increment offset. - v := single[br.peekByteFast()>>shift].entry + v := single[br.peekByteFast()].entry nBits := uint8(v) br.advance(nBits) - bitsLeft -= int(nBits) + bitsLeft -= uint(nBits) out[offset] = uint8(v >> 8) offset++ } + if offset != endsAt { + d.bufs.Put(buf) + return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) + } + decoded += offset - dstEvery*i err = br.close() if err != nil { + d.bufs.Put(buf) return nil, err } } + d.bufs.Put(buf) if dstSize != decoded { return nil, errors.New("corruption detected: short output block") } diff --git a/vendor/github.com/klauspost/compress/huff0/decompress_amd64.go b/vendor/github.com/klauspost/compress/huff0/decompress_amd64.go new file mode 100644 index 0000000..ba7e8e6 --- /dev/null +++ b/vendor/github.com/klauspost/compress/huff0/decompress_amd64.go @@ -0,0 +1,226 @@ +//go:build amd64 && !appengine && !noasm && gc +// +build amd64,!appengine,!noasm,gc + +// This file contains the specialisation of Decoder.Decompress4X +// and Decoder.Decompress1X that use an asm implementation of thir main loops. +package huff0 + +import ( + "errors" + "fmt" + + "github.com/klauspost/compress/internal/cpuinfo" +) + +// decompress4x_main_loop_x86 is an x86 assembler implementation +// of Decompress4X when tablelog > 8. +// +//go:noescape +func decompress4x_main_loop_amd64(ctx *decompress4xContext) + +// decompress4x_8b_loop_x86 is an x86 assembler implementation +// of Decompress4X when tablelog <= 8 which decodes 4 entries +// per loop. +// +//go:noescape +func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) + +// fallback8BitSize is the size where using Go version is faster. +const fallback8BitSize = 800 + +type decompress4xContext struct { + pbr *[4]bitReaderShifted + peekBits uint8 + out *byte + dstEvery int + tbl *dEntrySingle + decoded int + limit *byte +} + +// Decompress4X will decompress a 4X encoded stream. +// The length of the supplied input must match the end of a block exactly. +// The *capacity* of the dst slice must match the destination size of +// the uncompressed data exactly. +func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { + if len(d.dt.single) == 0 { + return nil, errors.New("no table loaded") + } + if len(src) < 6+(4*1) { + return nil, errors.New("input too small") + } + + use8BitTables := d.actualTableLog <= 8 + if cap(dst) < fallback8BitSize && use8BitTables { + return d.decompress4X8bit(dst, src) + } + + var br [4]bitReaderShifted + // Decode "jump table" + start := 6 + for i := 0; i < 3; i++ { + length := int(src[i*2]) | (int(src[i*2+1]) << 8) + if start+length >= len(src) { + return nil, errors.New("truncated input (or invalid offset)") + } + err := br[i].init(src[start : start+length]) + if err != nil { + return nil, err + } + start += length + } + err := br[3].init(src[start:]) + if err != nil { + return nil, err + } + + // destination, offset to match first output + dstSize := cap(dst) + dst = dst[:dstSize] + out := dst + dstEvery := (dstSize + 3) / 4 + + const tlSize = 1 << tableLogMax + const tlMask = tlSize - 1 + single := d.dt.single[:tlSize] + + var decoded int + + if len(out) > 4*4 && !(br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4) { + ctx := decompress4xContext{ + pbr: &br, + peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast() + out: &out[0], + dstEvery: dstEvery, + tbl: &single[0], + limit: &out[dstEvery-4], // Always stop decoding when first buffer gets here to avoid writing OOB on last. + } + if use8BitTables { + decompress4x_8b_main_loop_amd64(&ctx) + } else { + decompress4x_main_loop_amd64(&ctx) + } + + decoded = ctx.decoded + out = out[decoded/4:] + } + + // Decode remaining. + remainBytes := dstEvery - (decoded / 4) + for i := range br { + offset := dstEvery * i + endsAt := offset + remainBytes + if endsAt > len(out) { + endsAt = len(out) + } + br := &br[i] + bitsLeft := br.remaining() + for bitsLeft > 0 { + br.fill() + if offset >= endsAt { + return nil, errors.New("corruption detected: stream overrun 4") + } + + // Read value and increment offset. + val := br.peekBitsFast(d.actualTableLog) + v := single[val&tlMask].entry + nBits := uint8(v) + br.advance(nBits) + bitsLeft -= uint(nBits) + out[offset] = uint8(v >> 8) + offset++ + } + if offset != endsAt { + return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) + } + decoded += offset - dstEvery*i + err = br.close() + if err != nil { + return nil, err + } + } + if dstSize != decoded { + return nil, errors.New("corruption detected: short output block") + } + return dst, nil +} + +// decompress4x_main_loop_x86 is an x86 assembler implementation +// of Decompress1X when tablelog > 8. +// +//go:noescape +func decompress1x_main_loop_amd64(ctx *decompress1xContext) + +// decompress4x_main_loop_x86 is an x86 with BMI2 assembler implementation +// of Decompress1X when tablelog > 8. +// +//go:noescape +func decompress1x_main_loop_bmi2(ctx *decompress1xContext) + +type decompress1xContext struct { + pbr *bitReaderShifted + peekBits uint8 + out *byte + outCap int + tbl *dEntrySingle + decoded int +} + +// Error reported by asm implementations +const error_max_decoded_size_exeeded = -1 + +// Decompress1X will decompress a 1X encoded stream. +// The cap of the output buffer will be the maximum decompressed size. +// The length of the supplied input must match the end of a block exactly. +func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) { + if len(d.dt.single) == 0 { + return nil, errors.New("no table loaded") + } + var br bitReaderShifted + err := br.init(src) + if err != nil { + return dst, err + } + maxDecodedSize := cap(dst) + dst = dst[:maxDecodedSize] + + const tlSize = 1 << tableLogMax + const tlMask = tlSize - 1 + + if maxDecodedSize >= 4 { + ctx := decompress1xContext{ + pbr: &br, + out: &dst[0], + outCap: maxDecodedSize, + peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast() + tbl: &d.dt.single[0], + } + + if cpuinfo.HasBMI2() { + decompress1x_main_loop_bmi2(&ctx) + } else { + decompress1x_main_loop_amd64(&ctx) + } + if ctx.decoded == error_max_decoded_size_exeeded { + return nil, ErrMaxDecodedSizeExceeded + } + + dst = dst[:ctx.decoded] + } + + // br < 8, so uint8 is fine + bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead + for bitsLeft > 0 { + br.fill() + if len(dst) >= maxDecodedSize { + br.close() + return nil, ErrMaxDecodedSizeExceeded + } + v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask] + nBits := uint8(v.entry) + br.advance(nBits) + bitsLeft -= nBits + dst = append(dst, uint8(v.entry>>8)) + } + return dst, br.close() +} diff --git a/vendor/github.com/klauspost/compress/huff0/decompress_amd64.s b/vendor/github.com/klauspost/compress/huff0/decompress_amd64.s new file mode 100644 index 0000000..c4c7ab2 --- /dev/null +++ b/vendor/github.com/klauspost/compress/huff0/decompress_amd64.s @@ -0,0 +1,830 @@ +// Code generated by command: go run gen.go -out ../decompress_amd64.s -pkg=huff0. DO NOT EDIT. + +//go:build amd64 && !appengine && !noasm && gc + +// func decompress4x_main_loop_amd64(ctx *decompress4xContext) +TEXT ·decompress4x_main_loop_amd64(SB), $0-8 + // Preload values + MOVQ ctx+0(FP), AX + MOVBQZX 8(AX), DI + MOVQ 16(AX), BX + MOVQ 48(AX), SI + MOVQ 24(AX), R8 + MOVQ 32(AX), R9 + MOVQ (AX), R10 + + // Main loop +main_loop: + XORL DX, DX + CMPQ BX, SI + SETGE DL + + // br0.fillFast32() + MOVQ 32(R10), R11 + MOVBQZX 40(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill0 + MOVQ 24(R10), AX + SUBQ $0x20, R12 + SUBQ $0x04, AX + MOVQ (R10), R13 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (AX)(R13*1), R13 + MOVQ R12, CX + SHLQ CL, R13 + MOVQ AX, 24(R10) + ORQ R13, R11 + + // exhausted += (br0.off < 4) + CMPQ AX, $0x04 + ADCB $+0, DL + +skip_fill0: + // val0 := br0.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br0.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br0.peekTopBits(peekBits) + MOVQ DI, CX + MOVQ R11, R13 + SHRQ CL, R13 + + // v1 := table[val1&mask] + MOVW (R9)(R13*2), CX + + // br0.advance(uint8(v1.entry)) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // these two writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + MOVW AX, (BX) + + // update the bitreader structure + MOVQ R11, 32(R10) + MOVB R12, 40(R10) + + // br1.fillFast32() + MOVQ 80(R10), R11 + MOVBQZX 88(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill1 + MOVQ 72(R10), AX + SUBQ $0x20, R12 + SUBQ $0x04, AX + MOVQ 48(R10), R13 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (AX)(R13*1), R13 + MOVQ R12, CX + SHLQ CL, R13 + MOVQ AX, 72(R10) + ORQ R13, R11 + + // exhausted += (br1.off < 4) + CMPQ AX, $0x04 + ADCB $+0, DL + +skip_fill1: + // val0 := br1.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br1.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br1.peekTopBits(peekBits) + MOVQ DI, CX + MOVQ R11, R13 + SHRQ CL, R13 + + // v1 := table[val1&mask] + MOVW (R9)(R13*2), CX + + // br1.advance(uint8(v1.entry)) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // these two writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + MOVW AX, (BX)(R8*1) + + // update the bitreader structure + MOVQ R11, 80(R10) + MOVB R12, 88(R10) + + // br2.fillFast32() + MOVQ 128(R10), R11 + MOVBQZX 136(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill2 + MOVQ 120(R10), AX + SUBQ $0x20, R12 + SUBQ $0x04, AX + MOVQ 96(R10), R13 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (AX)(R13*1), R13 + MOVQ R12, CX + SHLQ CL, R13 + MOVQ AX, 120(R10) + ORQ R13, R11 + + // exhausted += (br2.off < 4) + CMPQ AX, $0x04 + ADCB $+0, DL + +skip_fill2: + // val0 := br2.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br2.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br2.peekTopBits(peekBits) + MOVQ DI, CX + MOVQ R11, R13 + SHRQ CL, R13 + + // v1 := table[val1&mask] + MOVW (R9)(R13*2), CX + + // br2.advance(uint8(v1.entry)) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // these two writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + MOVW AX, (BX)(R8*2) + + // update the bitreader structure + MOVQ R11, 128(R10) + MOVB R12, 136(R10) + + // br3.fillFast32() + MOVQ 176(R10), R11 + MOVBQZX 184(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill3 + MOVQ 168(R10), AX + SUBQ $0x20, R12 + SUBQ $0x04, AX + MOVQ 144(R10), R13 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (AX)(R13*1), R13 + MOVQ R12, CX + SHLQ CL, R13 + MOVQ AX, 168(R10) + ORQ R13, R11 + + // exhausted += (br3.off < 4) + CMPQ AX, $0x04 + ADCB $+0, DL + +skip_fill3: + // val0 := br3.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br3.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br3.peekTopBits(peekBits) + MOVQ DI, CX + MOVQ R11, R13 + SHRQ CL, R13 + + // v1 := table[val1&mask] + MOVW (R9)(R13*2), CX + + // br3.advance(uint8(v1.entry)) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // these two writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + LEAQ (R8)(R8*2), CX + MOVW AX, (BX)(CX*1) + + // update the bitreader structure + MOVQ R11, 176(R10) + MOVB R12, 184(R10) + ADDQ $0x02, BX + TESTB DL, DL + JZ main_loop + MOVQ ctx+0(FP), AX + SUBQ 16(AX), BX + SHLQ $0x02, BX + MOVQ BX, 40(AX) + RET + +// func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) +TEXT ·decompress4x_8b_main_loop_amd64(SB), $0-8 + // Preload values + MOVQ ctx+0(FP), CX + MOVBQZX 8(CX), DI + MOVQ 16(CX), BX + MOVQ 48(CX), SI + MOVQ 24(CX), R8 + MOVQ 32(CX), R9 + MOVQ (CX), R10 + + // Main loop +main_loop: + XORL DX, DX + CMPQ BX, SI + SETGE DL + + // br0.fillFast32() + MOVQ 32(R10), R11 + MOVBQZX 40(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill0 + MOVQ 24(R10), R13 + SUBQ $0x20, R12 + SUBQ $0x04, R13 + MOVQ (R10), R14 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (R13)(R14*1), R14 + MOVQ R12, CX + SHLQ CL, R14 + MOVQ R13, 24(R10) + ORQ R14, R11 + + // exhausted += (br0.off < 4) + CMPQ R13, $0x04 + ADCB $+0, DL + +skip_fill0: + // val0 := br0.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br0.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br0.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v1 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br0.advance(uint8(v1.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // val2 := br0.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v2 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br0.advance(uint8(v2.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // val3 := br0.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v3 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br0.advance(uint8(v3.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // these four writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + // out[id * dstEvery + 3] = uint8(v2.entry >> 8) + // out[id * dstEvery + 4] = uint8(v3.entry >> 8) + MOVL AX, (BX) + + // update the bitreader structure + MOVQ R11, 32(R10) + MOVB R12, 40(R10) + + // br1.fillFast32() + MOVQ 80(R10), R11 + MOVBQZX 88(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill1 + MOVQ 72(R10), R13 + SUBQ $0x20, R12 + SUBQ $0x04, R13 + MOVQ 48(R10), R14 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (R13)(R14*1), R14 + MOVQ R12, CX + SHLQ CL, R14 + MOVQ R13, 72(R10) + ORQ R14, R11 + + // exhausted += (br1.off < 4) + CMPQ R13, $0x04 + ADCB $+0, DL + +skip_fill1: + // val0 := br1.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br1.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br1.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v1 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br1.advance(uint8(v1.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // val2 := br1.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v2 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br1.advance(uint8(v2.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // val3 := br1.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v3 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br1.advance(uint8(v3.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // these four writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + // out[id * dstEvery + 3] = uint8(v2.entry >> 8) + // out[id * dstEvery + 4] = uint8(v3.entry >> 8) + MOVL AX, (BX)(R8*1) + + // update the bitreader structure + MOVQ R11, 80(R10) + MOVB R12, 88(R10) + + // br2.fillFast32() + MOVQ 128(R10), R11 + MOVBQZX 136(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill2 + MOVQ 120(R10), R13 + SUBQ $0x20, R12 + SUBQ $0x04, R13 + MOVQ 96(R10), R14 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (R13)(R14*1), R14 + MOVQ R12, CX + SHLQ CL, R14 + MOVQ R13, 120(R10) + ORQ R14, R11 + + // exhausted += (br2.off < 4) + CMPQ R13, $0x04 + ADCB $+0, DL + +skip_fill2: + // val0 := br2.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br2.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br2.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v1 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br2.advance(uint8(v1.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // val2 := br2.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v2 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br2.advance(uint8(v2.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // val3 := br2.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v3 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br2.advance(uint8(v3.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // these four writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + // out[id * dstEvery + 3] = uint8(v2.entry >> 8) + // out[id * dstEvery + 4] = uint8(v3.entry >> 8) + MOVL AX, (BX)(R8*2) + + // update the bitreader structure + MOVQ R11, 128(R10) + MOVB R12, 136(R10) + + // br3.fillFast32() + MOVQ 176(R10), R11 + MOVBQZX 184(R10), R12 + CMPQ R12, $0x20 + JBE skip_fill3 + MOVQ 168(R10), R13 + SUBQ $0x20, R12 + SUBQ $0x04, R13 + MOVQ 144(R10), R14 + + // b.value |= uint64(low) << (b.bitsRead & 63) + MOVL (R13)(R14*1), R14 + MOVQ R12, CX + SHLQ CL, R14 + MOVQ R13, 168(R10) + ORQ R14, R11 + + // exhausted += (br3.off < 4) + CMPQ R13, $0x04 + ADCB $+0, DL + +skip_fill3: + // val0 := br3.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v0 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br3.advance(uint8(v0.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + + // val1 := br3.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v1 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br3.advance(uint8(v1.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // val2 := br3.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v2 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br3.advance(uint8(v2.entry) + MOVB CH, AH + SHLQ CL, R11 + ADDB CL, R12 + + // val3 := br3.peekTopBits(peekBits) + MOVQ R11, R13 + MOVQ DI, CX + SHRQ CL, R13 + + // v3 := table[val0&mask] + MOVW (R9)(R13*2), CX + + // br3.advance(uint8(v3.entry) + MOVB CH, AL + SHLQ CL, R11 + ADDB CL, R12 + BSWAPL AX + + // these four writes get coalesced + // out[id * dstEvery + 0] = uint8(v0.entry >> 8) + // out[id * dstEvery + 1] = uint8(v1.entry >> 8) + // out[id * dstEvery + 3] = uint8(v2.entry >> 8) + // out[id * dstEvery + 4] = uint8(v3.entry >> 8) + LEAQ (R8)(R8*2), CX + MOVL AX, (BX)(CX*1) + + // update the bitreader structure + MOVQ R11, 176(R10) + MOVB R12, 184(R10) + ADDQ $0x04, BX + TESTB DL, DL + JZ main_loop + MOVQ ctx+0(FP), AX + SUBQ 16(AX), BX + SHLQ $0x02, BX + MOVQ BX, 40(AX) + RET + +// func decompress1x_main_loop_amd64(ctx *decompress1xContext) +TEXT ·decompress1x_main_loop_amd64(SB), $0-8 + MOVQ ctx+0(FP), CX + MOVQ 16(CX), DX + MOVQ 24(CX), BX + CMPQ BX, $0x04 + JB error_max_decoded_size_exceeded + LEAQ (DX)(BX*1), BX + MOVQ (CX), SI + MOVQ (SI), R8 + MOVQ 24(SI), R9 + MOVQ 32(SI), R10 + MOVBQZX 40(SI), R11 + MOVQ 32(CX), SI + MOVBQZX 8(CX), DI + JMP loop_condition + +main_loop: + // Check if we have room for 4 bytes in the output buffer + LEAQ 4(DX), CX + CMPQ CX, BX + JGE error_max_decoded_size_exceeded + + // Decode 4 values + CMPQ R11, $0x20 + JL bitReader_fillFast_1_end + SUBQ $0x20, R11 + SUBQ $0x04, R9 + MOVL (R8)(R9*1), R12 + MOVQ R11, CX + SHLQ CL, R12 + ORQ R12, R10 + +bitReader_fillFast_1_end: + MOVQ DI, CX + MOVQ R10, R12 + SHRQ CL, R12 + MOVW (SI)(R12*2), CX + MOVB CH, AL + MOVBQZX CL, CX + ADDQ CX, R11 + SHLQ CL, R10 + MOVQ DI, CX + MOVQ R10, R12 + SHRQ CL, R12 + MOVW (SI)(R12*2), CX + MOVB CH, AH + MOVBQZX CL, CX + ADDQ CX, R11 + SHLQ CL, R10 + BSWAPL AX + CMPQ R11, $0x20 + JL bitReader_fillFast_2_end + SUBQ $0x20, R11 + SUBQ $0x04, R9 + MOVL (R8)(R9*1), R12 + MOVQ R11, CX + SHLQ CL, R12 + ORQ R12, R10 + +bitReader_fillFast_2_end: + MOVQ DI, CX + MOVQ R10, R12 + SHRQ CL, R12 + MOVW (SI)(R12*2), CX + MOVB CH, AH + MOVBQZX CL, CX + ADDQ CX, R11 + SHLQ CL, R10 + MOVQ DI, CX + MOVQ R10, R12 + SHRQ CL, R12 + MOVW (SI)(R12*2), CX + MOVB CH, AL + MOVBQZX CL, CX + ADDQ CX, R11 + SHLQ CL, R10 + BSWAPL AX + + // Store the decoded values + MOVL AX, (DX) + ADDQ $0x04, DX + +loop_condition: + CMPQ R9, $0x08 + JGE main_loop + + // Update ctx structure + MOVQ ctx+0(FP), AX + SUBQ 16(AX), DX + MOVQ DX, 40(AX) + MOVQ (AX), AX + MOVQ R9, 24(AX) + MOVQ R10, 32(AX) + MOVB R11, 40(AX) + RET + + // Report error +error_max_decoded_size_exceeded: + MOVQ ctx+0(FP), AX + MOVQ $-1, CX + MOVQ CX, 40(AX) + RET + +// func decompress1x_main_loop_bmi2(ctx *decompress1xContext) +// Requires: BMI2 +TEXT ·decompress1x_main_loop_bmi2(SB), $0-8 + MOVQ ctx+0(FP), CX + MOVQ 16(CX), DX + MOVQ 24(CX), BX + CMPQ BX, $0x04 + JB error_max_decoded_size_exceeded + LEAQ (DX)(BX*1), BX + MOVQ (CX), SI + MOVQ (SI), R8 + MOVQ 24(SI), R9 + MOVQ 32(SI), R10 + MOVBQZX 40(SI), R11 + MOVQ 32(CX), SI + MOVBQZX 8(CX), DI + JMP loop_condition + +main_loop: + // Check if we have room for 4 bytes in the output buffer + LEAQ 4(DX), CX + CMPQ CX, BX + JGE error_max_decoded_size_exceeded + + // Decode 4 values + CMPQ R11, $0x20 + JL bitReader_fillFast_1_end + SUBQ $0x20, R11 + SUBQ $0x04, R9 + MOVL (R8)(R9*1), CX + SHLXQ R11, CX, CX + ORQ CX, R10 + +bitReader_fillFast_1_end: + SHRXQ DI, R10, CX + MOVW (SI)(CX*2), CX + MOVB CH, AL + MOVBQZX CL, CX + ADDQ CX, R11 + SHLXQ CX, R10, R10 + SHRXQ DI, R10, CX + MOVW (SI)(CX*2), CX + MOVB CH, AH + MOVBQZX CL, CX + ADDQ CX, R11 + SHLXQ CX, R10, R10 + BSWAPL AX + CMPQ R11, $0x20 + JL bitReader_fillFast_2_end + SUBQ $0x20, R11 + SUBQ $0x04, R9 + MOVL (R8)(R9*1), CX + SHLXQ R11, CX, CX + ORQ CX, R10 + +bitReader_fillFast_2_end: + SHRXQ DI, R10, CX + MOVW (SI)(CX*2), CX + MOVB CH, AH + MOVBQZX CL, CX + ADDQ CX, R11 + SHLXQ CX, R10, R10 + SHRXQ DI, R10, CX + MOVW (SI)(CX*2), CX + MOVB CH, AL + MOVBQZX CL, CX + ADDQ CX, R11 + SHLXQ CX, R10, R10 + BSWAPL AX + + // Store the decoded values + MOVL AX, (DX) + ADDQ $0x04, DX + +loop_condition: + CMPQ R9, $0x08 + JGE main_loop + + // Update ctx structure + MOVQ ctx+0(FP), AX + SUBQ 16(AX), DX + MOVQ DX, 40(AX) + MOVQ (AX), AX + MOVQ R9, 24(AX) + MOVQ R10, 32(AX) + MOVB R11, 40(AX) + RET + + // Report error +error_max_decoded_size_exceeded: + MOVQ ctx+0(FP), AX + MOVQ $-1, CX + MOVQ CX, 40(AX) + RET diff --git a/vendor/github.com/klauspost/compress/huff0/decompress_generic.go b/vendor/github.com/klauspost/compress/huff0/decompress_generic.go new file mode 100644 index 0000000..908c17d --- /dev/null +++ b/vendor/github.com/klauspost/compress/huff0/decompress_generic.go @@ -0,0 +1,299 @@ +//go:build !amd64 || appengine || !gc || noasm +// +build !amd64 appengine !gc noasm + +// This file contains a generic implementation of Decoder.Decompress4X. +package huff0 + +import ( + "errors" + "fmt" +) + +// Decompress4X will decompress a 4X encoded stream. +// The length of the supplied input must match the end of a block exactly. +// The *capacity* of the dst slice must match the destination size of +// the uncompressed data exactly. +func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { + if len(d.dt.single) == 0 { + return nil, errors.New("no table loaded") + } + if len(src) < 6+(4*1) { + return nil, errors.New("input too small") + } + if use8BitTables && d.actualTableLog <= 8 { + return d.decompress4X8bit(dst, src) + } + + var br [4]bitReaderShifted + // Decode "jump table" + start := 6 + for i := 0; i < 3; i++ { + length := int(src[i*2]) | (int(src[i*2+1]) << 8) + if start+length >= len(src) { + return nil, errors.New("truncated input (or invalid offset)") + } + err := br[i].init(src[start : start+length]) + if err != nil { + return nil, err + } + start += length + } + err := br[3].init(src[start:]) + if err != nil { + return nil, err + } + + // destination, offset to match first output + dstSize := cap(dst) + dst = dst[:dstSize] + out := dst + dstEvery := (dstSize + 3) / 4 + + const tlSize = 1 << tableLogMax + const tlMask = tlSize - 1 + single := d.dt.single[:tlSize] + + // Use temp table to avoid bound checks/append penalty. + buf := d.buffer() + var off uint8 + var decoded int + + // Decode 2 values from each decoder/loop. + const bufoff = 256 + for { + if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 { + break + } + + { + const stream = 0 + const stream2 = 1 + br[stream].fillFast() + br[stream2].fillFast() + + val := br[stream].peekBitsFast(d.actualTableLog) + val2 := br[stream2].peekBitsFast(d.actualTableLog) + v := single[val&tlMask] + v2 := single[val2&tlMask] + br[stream].advance(uint8(v.entry)) + br[stream2].advance(uint8(v2.entry)) + buf[stream][off] = uint8(v.entry >> 8) + buf[stream2][off] = uint8(v2.entry >> 8) + + val = br[stream].peekBitsFast(d.actualTableLog) + val2 = br[stream2].peekBitsFast(d.actualTableLog) + v = single[val&tlMask] + v2 = single[val2&tlMask] + br[stream].advance(uint8(v.entry)) + br[stream2].advance(uint8(v2.entry)) + buf[stream][off+1] = uint8(v.entry >> 8) + buf[stream2][off+1] = uint8(v2.entry >> 8) + } + + { + const stream = 2 + const stream2 = 3 + br[stream].fillFast() + br[stream2].fillFast() + + val := br[stream].peekBitsFast(d.actualTableLog) + val2 := br[stream2].peekBitsFast(d.actualTableLog) + v := single[val&tlMask] + v2 := single[val2&tlMask] + br[stream].advance(uint8(v.entry)) + br[stream2].advance(uint8(v2.entry)) + buf[stream][off] = uint8(v.entry >> 8) + buf[stream2][off] = uint8(v2.entry >> 8) + + val = br[stream].peekBitsFast(d.actualTableLog) + val2 = br[stream2].peekBitsFast(d.actualTableLog) + v = single[val&tlMask] + v2 = single[val2&tlMask] + br[stream].advance(uint8(v.entry)) + br[stream2].advance(uint8(v2.entry)) + buf[stream][off+1] = uint8(v.entry >> 8) + buf[stream2][off+1] = uint8(v2.entry >> 8) + } + + off += 2 + + if off == 0 { + if bufoff > dstEvery { + d.bufs.Put(buf) + return nil, errors.New("corruption detected: stream overrun 1") + } + // There must at least be 3 buffers left. + if len(out)-bufoff < dstEvery*3 { + d.bufs.Put(buf) + return nil, errors.New("corruption detected: stream overrun 2") + } + //copy(out, buf[0][:]) + //copy(out[dstEvery:], buf[1][:]) + //copy(out[dstEvery*2:], buf[2][:]) + //copy(out[dstEvery*3:], buf[3][:]) + *(*[bufoff]byte)(out) = buf[0] + *(*[bufoff]byte)(out[dstEvery:]) = buf[1] + *(*[bufoff]byte)(out[dstEvery*2:]) = buf[2] + *(*[bufoff]byte)(out[dstEvery*3:]) = buf[3] + out = out[bufoff:] + decoded += bufoff * 4 + } + } + if off > 0 { + ioff := int(off) + if len(out) < dstEvery*3+ioff { + d.bufs.Put(buf) + return nil, errors.New("corruption detected: stream overrun 3") + } + copy(out, buf[0][:off]) + copy(out[dstEvery:], buf[1][:off]) + copy(out[dstEvery*2:], buf[2][:off]) + copy(out[dstEvery*3:], buf[3][:off]) + decoded += int(off) * 4 + out = out[off:] + } + + // Decode remaining. + remainBytes := dstEvery - (decoded / 4) + for i := range br { + offset := dstEvery * i + endsAt := offset + remainBytes + if endsAt > len(out) { + endsAt = len(out) + } + br := &br[i] + bitsLeft := br.remaining() + for bitsLeft > 0 { + br.fill() + if offset >= endsAt { + d.bufs.Put(buf) + return nil, errors.New("corruption detected: stream overrun 4") + } + + // Read value and increment offset. + val := br.peekBitsFast(d.actualTableLog) + v := single[val&tlMask].entry + nBits := uint8(v) + br.advance(nBits) + bitsLeft -= uint(nBits) + out[offset] = uint8(v >> 8) + offset++ + } + if offset != endsAt { + d.bufs.Put(buf) + return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) + } + decoded += offset - dstEvery*i + err = br.close() + if err != nil { + return nil, err + } + } + d.bufs.Put(buf) + if dstSize != decoded { + return nil, errors.New("corruption detected: short output block") + } + return dst, nil +} + +// Decompress1X will decompress a 1X encoded stream. +// The cap of the output buffer will be the maximum decompressed size. +// The length of the supplied input must match the end of a block exactly. +func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) { + if len(d.dt.single) == 0 { + return nil, errors.New("no table loaded") + } + if use8BitTables && d.actualTableLog <= 8 { + return d.decompress1X8Bit(dst, src) + } + var br bitReaderShifted + err := br.init(src) + if err != nil { + return dst, err + } + maxDecodedSize := cap(dst) + dst = dst[:0] + + // Avoid bounds check by always having full sized table. + const tlSize = 1 << tableLogMax + const tlMask = tlSize - 1 + dt := d.dt.single[:tlSize] + + // Use temp table to avoid bound checks/append penalty. + bufs := d.buffer() + buf := &bufs[0] + var off uint8 + + for br.off >= 8 { + br.fillFast() + v := dt[br.peekBitsFast(d.actualTableLog)&tlMask] + br.advance(uint8(v.entry)) + buf[off+0] = uint8(v.entry >> 8) + + v = dt[br.peekBitsFast(d.actualTableLog)&tlMask] + br.advance(uint8(v.entry)) + buf[off+1] = uint8(v.entry >> 8) + + // Refill + br.fillFast() + + v = dt[br.peekBitsFast(d.actualTableLog)&tlMask] + br.advance(uint8(v.entry)) + buf[off+2] = uint8(v.entry >> 8) + + v = dt[br.peekBitsFast(d.actualTableLog)&tlMask] + br.advance(uint8(v.entry)) + buf[off+3] = uint8(v.entry >> 8) + + off += 4 + if off == 0 { + if len(dst)+256 > maxDecodedSize { + br.close() + d.bufs.Put(bufs) + return nil, ErrMaxDecodedSizeExceeded + } + dst = append(dst, buf[:]...) + } + } + + if len(dst)+int(off) > maxDecodedSize { + d.bufs.Put(bufs) + br.close() + return nil, ErrMaxDecodedSizeExceeded + } + dst = append(dst, buf[:off]...) + + // br < 8, so uint8 is fine + bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead + for bitsLeft > 0 { + br.fill() + if false && br.bitsRead >= 32 { + if br.off >= 4 { + v := br.in[br.off-4:] + v = v[:4] + low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) + br.value = (br.value << 32) | uint64(low) + br.bitsRead -= 32 + br.off -= 4 + } else { + for br.off > 0 { + br.value = (br.value << 8) | uint64(br.in[br.off-1]) + br.bitsRead -= 8 + br.off-- + } + } + } + if len(dst) >= maxDecodedSize { + d.bufs.Put(bufs) + br.close() + return nil, ErrMaxDecodedSizeExceeded + } + v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask] + nBits := uint8(v.entry) + br.advance(nBits) + bitsLeft -= nBits + dst = append(dst, uint8(v.entry>>8)) + } + d.bufs.Put(bufs) + return dst, br.close() +} diff --git a/vendor/github.com/klauspost/compress/huff0/huff0.go b/vendor/github.com/klauspost/compress/huff0/huff0.go index 3ee00ec..e8ad17a 100644 --- a/vendor/github.com/klauspost/compress/huff0/huff0.go +++ b/vendor/github.com/klauspost/compress/huff0/huff0.go @@ -8,6 +8,7 @@ import ( "fmt" "math" "math/bits" + "sync" "github.com/klauspost/compress/fse" ) @@ -116,6 +117,7 @@ type Scratch struct { nodes []nodeElt tmpOut [4][]byte fse *fse.Scratch + decPool sync.Pool // *[4][256]byte buffers. huffWeight [maxSymbolValue + 1]byte } diff --git a/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo.go b/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo.go new file mode 100644 index 0000000..3954c51 --- /dev/null +++ b/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo.go @@ -0,0 +1,34 @@ +// Package cpuinfo gives runtime info about the current CPU. +// +// This is a very limited module meant for use internally +// in this project. For more versatile solution check +// https://github.com/klauspost/cpuid. +package cpuinfo + +// HasBMI1 checks whether an x86 CPU supports the BMI1 extension. +func HasBMI1() bool { + return hasBMI1 +} + +// HasBMI2 checks whether an x86 CPU supports the BMI2 extension. +func HasBMI2() bool { + return hasBMI2 +} + +// DisableBMI2 will disable BMI2, for testing purposes. +// Call returned function to restore previous state. +func DisableBMI2() func() { + old := hasBMI2 + hasBMI2 = false + return func() { + hasBMI2 = old + } +} + +// HasBMI checks whether an x86 CPU supports both BMI1 and BMI2 extensions. +func HasBMI() bool { + return HasBMI1() && HasBMI2() +} + +var hasBMI1 bool +var hasBMI2 bool diff --git a/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.go b/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.go new file mode 100644 index 0000000..e802579 --- /dev/null +++ b/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.go @@ -0,0 +1,11 @@ +//go:build amd64 && !appengine && !noasm && gc +// +build amd64,!appengine,!noasm,gc + +package cpuinfo + +// go:noescape +func x86extensions() (bmi1, bmi2 bool) + +func init() { + hasBMI1, hasBMI2 = x86extensions() +} diff --git a/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.s b/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.s new file mode 100644 index 0000000..4465fbe --- /dev/null +++ b/vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.s @@ -0,0 +1,36 @@ +// +build !appengine +// +build gc +// +build !noasm + +#include "textflag.h" +#include "funcdata.h" +#include "go_asm.h" + +TEXT ·x86extensions(SB), NOSPLIT, $0 + // 1. determine max EAX value + XORQ AX, AX + CPUID + + CMPQ AX, $7 + JB unsupported + + // 2. EAX = 7, ECX = 0 --- see Table 3-8 "Information Returned by CPUID Instruction" + MOVQ $7, AX + MOVQ $0, CX + CPUID + + BTQ $3, BX // bit 3 = BMI1 + SETCS AL + + BTQ $8, BX // bit 8 = BMI2 + SETCS AH + + MOVB AL, bmi1+0(FP) + MOVB AH, bmi2+1(FP) + RET + +unsupported: + XORQ AX, AX + MOVB AL, bmi1+0(FP) + MOVB AL, bmi2+1(FP) + RET diff --git a/vendor/github.com/klauspost/compress/internal/snapref/encode_other.go b/vendor/github.com/klauspost/compress/internal/snapref/encode_other.go index 511bba6..2aa6a95 100644 --- a/vendor/github.com/klauspost/compress/internal/snapref/encode_other.go +++ b/vendor/github.com/klauspost/compress/internal/snapref/encode_other.go @@ -18,6 +18,7 @@ func load64(b []byte, i int) uint64 { // emitLiteral writes a literal chunk and returns the number of bytes written. // // It assumes that: +// // dst is long enough to hold the encoded bytes // 1 <= len(lit) && len(lit) <= 65536 func emitLiteral(dst, lit []byte) int { @@ -42,6 +43,7 @@ func emitLiteral(dst, lit []byte) int { // emitCopy writes a copy chunk and returns the number of bytes written. // // It assumes that: +// // dst is long enough to hold the encoded bytes // 1 <= offset && offset <= 65535 // 4 <= length && length <= 65535 @@ -85,28 +87,40 @@ func emitCopy(dst []byte, offset, length int) int { return i + 2 } -// extendMatch returns the largest k such that k <= len(src) and that -// src[i:i+k-j] and src[j:k] have the same contents. -// -// It assumes that: -// 0 <= i && i < j && j <= len(src) -func extendMatch(src []byte, i, j int) int { - for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 { - } - return j -} - func hash(u, shift uint32) uint32 { return (u * 0x1e35a7bd) >> shift } +// EncodeBlockInto exposes encodeBlock but checks dst size. +func EncodeBlockInto(dst, src []byte) (d int) { + if MaxEncodedLen(len(src)) > len(dst) { + return 0 + } + + // encodeBlock breaks on too big blocks, so split. + for len(src) > 0 { + p := src + src = nil + if len(p) > maxBlockSize { + p, src = p[:maxBlockSize], p[maxBlockSize:] + } + if len(p) < minNonLiteralBlockSize { + d += emitLiteral(dst[d:], p) + } else { + d += encodeBlock(dst[d:], p) + } + } + return d +} + // encodeBlock encodes a non-empty src to a guaranteed-large-enough dst. It // assumes that the varint-encoded length of the decompressed bytes has already // been written. // // It also assumes that: +// // len(dst) >= MaxEncodedLen(len(src)) && -// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize +// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize func encodeBlock(dst, src []byte) (d int) { // Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive. // The table element type is uint16, as s < sLimit and sLimit < len(src) diff --git a/vendor/github.com/klauspost/compress/zstd/README.md b/vendor/github.com/klauspost/compress/zstd/README.md index c8f0f16..bdd49c8 100644 --- a/vendor/github.com/klauspost/compress/zstd/README.md +++ b/vendor/github.com/klauspost/compress/zstd/README.md @@ -12,6 +12,8 @@ The `zstd` package is provided as open source software using a Go standard licen Currently the package is heavily optimized for 64 bit processors and will be significantly slower on 32 bit processors. +For seekable zstd streams, see [this excellent package](https://github.com/SaveTheRbtz/zstd-seekable-format-go). + ## Installation Install using `go get -u github.com/klauspost/compress`. The package is located in `github.com/klauspost/compress/zstd`. @@ -78,6 +80,9 @@ of a stream. This is independent of the `WithEncoderConcurrency(n)`, but that is in the future. So if you want to limit concurrency for future updates, specify the concurrency you would like. +If you would like stream encoding to be done without spawning async goroutines, use `WithEncoderConcurrency(1)` +which will compress input as each block is completed, blocking on writes until each has completed. + You can specify your desired compression level using `WithEncoderLevel()` option. Currently only pre-defined compression settings can be specified. @@ -104,7 +109,8 @@ and seems to ignore concatenated streams, even though [it is part of the spec](h For compressing small blocks, the returned encoder has a function called `EncodeAll(src, dst []byte) []byte`. `EncodeAll` will encode all input in src and append it to dst. -This function can be called concurrently, but each call will only run on a single goroutine. +This function can be called concurrently. +Each call will only run on a same goroutine as the caller. Encoded blocks can be concatenated and the result will be the combined input stream. Data compressed with EncodeAll can be decoded with the Decoder, using either a stream or `DecodeAll`. @@ -149,10 +155,10 @@ http://sun.aei.polsl.pl/~sdeor/corpus/silesia.zip This package: file out level insize outsize millis mb/s -silesia.tar zskp 1 211947520 73101992 643 313.87 -silesia.tar zskp 2 211947520 67504318 969 208.38 -silesia.tar zskp 3 211947520 64595893 2007 100.68 -silesia.tar zskp 4 211947520 60995370 8825 22.90 +silesia.tar zskp 1 211947520 73821326 634 318.47 +silesia.tar zskp 2 211947520 67655404 1508 133.96 +silesia.tar zskp 3 211947520 64746933 3000 67.37 +silesia.tar zskp 4 211947520 60073508 16926 11.94 cgo zstd: silesia.tar zstd 1 211947520 73605392 543 371.56 @@ -161,94 +167,94 @@ silesia.tar zstd 6 211947520 62916450 1913 105.66 silesia.tar zstd 9 211947520 60212393 5063 39.92 gzip, stdlib/this package: -silesia.tar gzstd 1 211947520 80007735 1654 122.21 -silesia.tar gzkp 1 211947520 80136201 1152 175.45 +silesia.tar gzstd 1 211947520 80007735 1498 134.87 +silesia.tar gzkp 1 211947520 80088272 1009 200.31 GOB stream of binary data. Highly compressible. https://files.klauspost.com/compress/gob-stream.7z file out level insize outsize millis mb/s -gob-stream zskp 1 1911399616 235022249 3088 590.30 -gob-stream zskp 2 1911399616 205669791 3786 481.34 -gob-stream zskp 3 1911399616 175034659 9636 189.17 -gob-stream zskp 4 1911399616 165609838 50369 36.19 +gob-stream zskp 1 1911399616 233948096 3230 564.34 +gob-stream zskp 2 1911399616 203997694 4997 364.73 +gob-stream zskp 3 1911399616 173526523 13435 135.68 +gob-stream zskp 4 1911399616 162195235 47559 38.33 gob-stream zstd 1 1911399616 249810424 2637 691.26 gob-stream zstd 3 1911399616 208192146 3490 522.31 gob-stream zstd 6 1911399616 193632038 6687 272.56 gob-stream zstd 9 1911399616 177620386 16175 112.70 -gob-stream gzstd 1 1911399616 357382641 10251 177.82 -gob-stream gzkp 1 1911399616 359753026 5438 335.20 +gob-stream gzstd 1 1911399616 357382013 9046 201.49 +gob-stream gzkp 1 1911399616 359136669 4885 373.08 The test data for the Large Text Compression Benchmark is the first 10^9 bytes of the English Wikipedia dump on Mar. 3, 2006. http://mattmahoney.net/dc/textdata.html file out level insize outsize millis mb/s -enwik9 zskp 1 1000000000 343848582 3609 264.18 -enwik9 zskp 2 1000000000 317276632 5746 165.97 -enwik9 zskp 3 1000000000 292243069 12162 78.41 -enwik9 zskp 4 1000000000 262183768 82837 11.51 +enwik9 zskp 1 1000000000 343833605 3687 258.64 +enwik9 zskp 2 1000000000 317001237 7672 124.29 +enwik9 zskp 3 1000000000 291915823 15923 59.89 +enwik9 zskp 4 1000000000 261710291 77697 12.27 enwik9 zstd 1 1000000000 358072021 3110 306.65 enwik9 zstd 3 1000000000 313734672 4784 199.35 enwik9 zstd 6 1000000000 295138875 10290 92.68 enwik9 zstd 9 1000000000 278348700 28549 33.40 -enwik9 gzstd 1 1000000000 382578136 9604 99.30 -enwik9 gzkp 1 1000000000 383825945 6544 145.73 +enwik9 gzstd 1 1000000000 382578136 8608 110.78 +enwik9 gzkp 1 1000000000 382781160 5628 169.45 Highly compressible JSON file. https://files.klauspost.com/compress/github-june-2days-2019.json.zst file out level insize outsize millis mb/s -github-june-2days-2019.json zskp 1 6273951764 699045015 10620 563.40 -github-june-2days-2019.json zskp 2 6273951764 617881763 11687 511.96 -github-june-2days-2019.json zskp 3 6273951764 524340691 34043 175.75 -github-june-2days-2019.json zskp 4 6273951764 470320075 170190 35.16 +github-june-2days-2019.json zskp 1 6273951764 697439532 9789 611.17 +github-june-2days-2019.json zskp 2 6273951764 610876538 18553 322.49 +github-june-2days-2019.json zskp 3 6273951764 517662858 44186 135.41 +github-june-2days-2019.json zskp 4 6273951764 464617114 165373 36.18 github-june-2days-2019.json zstd 1 6273951764 766284037 8450 708.00 github-june-2days-2019.json zstd 3 6273951764 661889476 10927 547.57 github-june-2days-2019.json zstd 6 6273951764 642756859 22996 260.18 github-june-2days-2019.json zstd 9 6273951764 601974523 52413 114.16 -github-june-2days-2019.json gzstd 1 6273951764 1164400847 29948 199.79 -github-june-2days-2019.json gzkp 1 6273951764 1125417694 21788 274.61 +github-june-2days-2019.json gzstd 1 6273951764 1164397768 26793 223.32 +github-june-2days-2019.json gzkp 1 6273951764 1120631856 17693 338.16 VM Image, Linux mint with a few installed applications: https://files.klauspost.com/compress/rawstudio-mint14.7z file out level insize outsize millis mb/s -rawstudio-mint14.tar zskp 1 8558382592 3667489370 20210 403.84 -rawstudio-mint14.tar zskp 2 8558382592 3364592300 31873 256.07 -rawstudio-mint14.tar zskp 3 8558382592 3158085214 77675 105.08 -rawstudio-mint14.tar zskp 4 8558382592 2965110639 857750 9.52 +rawstudio-mint14.tar zskp 1 8558382592 3718400221 18206 448.29 +rawstudio-mint14.tar zskp 2 8558382592 3326118337 37074 220.15 +rawstudio-mint14.tar zskp 3 8558382592 3163842361 87306 93.49 +rawstudio-mint14.tar zskp 4 8558382592 2970480650 783862 10.41 rawstudio-mint14.tar zstd 1 8558382592 3609250104 17136 476.27 rawstudio-mint14.tar zstd 3 8558382592 3341679997 29262 278.92 rawstudio-mint14.tar zstd 6 8558382592 3235846406 77904 104.77 rawstudio-mint14.tar zstd 9 8558382592 3160778861 140946 57.91 -rawstudio-mint14.tar gzstd 1 8558382592 3926257486 57722 141.40 -rawstudio-mint14.tar gzkp 1 8558382592 3962605659 45113 180.92 +rawstudio-mint14.tar gzstd 1 8558382592 3926234992 51345 158.96 +rawstudio-mint14.tar gzkp 1 8558382592 3960117298 36722 222.26 CSV data: https://files.klauspost.com/compress/nyc-taxi-data-10M.csv.zst file out level insize outsize millis mb/s -nyc-taxi-data-10M.csv zskp 1 3325605752 641339945 8925 355.35 -nyc-taxi-data-10M.csv zskp 2 3325605752 591748091 11268 281.44 -nyc-taxi-data-10M.csv zskp 3 3325605752 530289687 25239 125.66 -nyc-taxi-data-10M.csv zskp 4 3325605752 476268884 135958 23.33 +nyc-taxi-data-10M.csv zskp 1 3325605752 641319332 9462 335.17 +nyc-taxi-data-10M.csv zskp 2 3325605752 588976126 17570 180.50 +nyc-taxi-data-10M.csv zskp 3 3325605752 529329260 32432 97.79 +nyc-taxi-data-10M.csv zskp 4 3325605752 474949772 138025 22.98 nyc-taxi-data-10M.csv zstd 1 3325605752 687399637 8233 385.18 nyc-taxi-data-10M.csv zstd 3 3325605752 598514411 10065 315.07 nyc-taxi-data-10M.csv zstd 6 3325605752 570522953 20038 158.27 nyc-taxi-data-10M.csv zstd 9 3325605752 517554797 64565 49.12 -nyc-taxi-data-10M.csv gzstd 1 3325605752 928656485 23876 132.83 -nyc-taxi-data-10M.csv gzkp 1 3325605752 922257165 16780 189.00 +nyc-taxi-data-10M.csv gzstd 1 3325605752 928654908 21270 149.11 +nyc-taxi-data-10M.csv gzkp 1 3325605752 922273214 13929 227.68 ``` ## Decompressor @@ -283,8 +289,13 @@ func Decompress(in io.Reader, out io.Writer) error { } ``` -It is important to use the "Close" function when you no longer need the Reader to stop running goroutines. -See "Allocation-less operation" below. +It is important to use the "Close" function when you no longer need the Reader to stop running goroutines, +when running with default settings. +Goroutines will exit once an error has been returned, including `io.EOF` at the end of a stream. + +Streams are decoded concurrently in 4 asynchronous stages to give the best possible throughput. +However, if you prefer synchronous decompression, use `WithDecoderConcurrency(1)` which will decompress data +as it is being requested only. For decoding buffers, it could look something like this: @@ -293,7 +304,7 @@ import "github.com/klauspost/compress/zstd" // Create a reader that caches decompressors. // For this operation type we supply a nil Reader. -var decoder, _ = zstd.NewReader(nil) +var decoder, _ = zstd.NewReader(nil, zstd.WithDecoderConcurrency(0)) // Decompress a buffer. We don't supply a destination buffer, // so it will be allocated by the decoder. @@ -303,9 +314,12 @@ func Decompress(src []byte) ([]byte, error) { ``` Both of these cases should provide the functionality needed. -The decoder can be used for *concurrent* decompression of multiple buffers. +The decoder can be used for *concurrent* decompression of multiple buffers. +By default 4 decompressors will be created. + It will only allow a certain number of concurrent operations to run. -To tweak that yourself use the `WithDecoderConcurrency(n)` option when creating the decoder. +To tweak that yourself use the `WithDecoderConcurrency(n)` option when creating the decoder. +It is possible to use `WithDecoderConcurrency(0)` to create GOMAXPROCS decoders. ### Dictionaries @@ -357,62 +371,48 @@ In this case no unneeded allocations should be made. The buffer decoder does everything on the same goroutine and does nothing concurrently. It can however decode several buffers concurrently. Use `WithDecoderConcurrency(n)` to limit that. -The stream decoder operates on +The stream decoder will create goroutines that: -* One goroutine reads input and splits the input to several block decoders. -* A number of decoders will decode blocks. -* A goroutine coordinates these blocks and sends history from one to the next. +1) Reads input and splits the input into blocks. +2) Decompression of literals. +3) Decompression of sequences. +4) Reconstruction of output stream. So effectively this also means the decoder will "read ahead" and prepare data to always be available for output. +The concurrency level will, for streams, determine how many blocks ahead the compression will start. + Since "blocks" are quite dependent on the output of the previous block stream decoding will only have limited concurrency. -In practice this means that concurrency is often limited to utilizing about 2 cores effectively. - - +In practice this means that concurrency is often limited to utilizing about 3 cores effectively. + ### Benchmarks -These are some examples of performance compared to [datadog cgo library](https://github.com/DataDog/zstd). - The first two are streaming decodes and the last are smaller inputs. - + +Running on AMD Ryzen 9 3950X 16-Core Processor. AMD64 assembly used. + ``` -BenchmarkDecoderSilesia-8 3 385000067 ns/op 550.51 MB/s 5498 B/op 8 allocs/op -BenchmarkDecoderSilesiaCgo-8 6 197666567 ns/op 1072.25 MB/s 270672 B/op 8 allocs/op - -BenchmarkDecoderEnwik9-8 1 2027001600 ns/op 493.34 MB/s 10496 B/op 18 allocs/op -BenchmarkDecoderEnwik9Cgo-8 2 979499200 ns/op 1020.93 MB/s 270672 B/op 8 allocs/op - -Concurrent performance: - -BenchmarkDecoder_DecodeAllParallel/kppkn.gtb.zst-16 28915 42469 ns/op 4340.07 MB/s 114 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/geo.protodata.zst-16 116505 9965 ns/op 11900.16 MB/s 16 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/plrabn12.txt.zst-16 8952 134272 ns/op 3588.70 MB/s 915 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/lcet10.txt.zst-16 11820 102538 ns/op 4161.90 MB/s 594 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/asyoulik.txt.zst-16 34782 34184 ns/op 3661.88 MB/s 60 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/alice29.txt.zst-16 27712 43447 ns/op 3500.58 MB/s 99 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/html_x_4.zst-16 62826 18750 ns/op 21845.10 MB/s 104 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/paper-100k.pdf.zst-16 631545 1794 ns/op 57078.74 MB/s 2 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/fireworks.jpeg.zst-16 1690140 712 ns/op 172938.13 MB/s 1 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/urls.10K.zst-16 10432 113593 ns/op 6180.73 MB/s 1143 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/html.zst-16 113206 10671 ns/op 9596.27 MB/s 15 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallel/comp-data.bin.zst-16 1530615 779 ns/op 5229.49 MB/s 0 B/op 0 allocs/op - -BenchmarkDecoder_DecodeAllParallelCgo/kppkn.gtb.zst-16 65217 16192 ns/op 11383.34 MB/s 46 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/geo.protodata.zst-16 292671 4039 ns/op 29363.19 MB/s 6 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/plrabn12.txt.zst-16 26314 46021 ns/op 10470.43 MB/s 293 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/lcet10.txt.zst-16 33897 34900 ns/op 12227.96 MB/s 205 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/asyoulik.txt.zst-16 104348 11433 ns/op 10949.01 MB/s 20 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/alice29.txt.zst-16 75949 15510 ns/op 9805.60 MB/s 32 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/html_x_4.zst-16 173910 6756 ns/op 60624.29 MB/s 37 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/paper-100k.pdf.zst-16 923076 1339 ns/op 76474.87 MB/s 1 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/fireworks.jpeg.zst-16 922920 1351 ns/op 91102.57 MB/s 2 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/urls.10K.zst-16 27649 43618 ns/op 16096.19 MB/s 407 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/html.zst-16 279073 4160 ns/op 24614.18 MB/s 6 B/op 0 allocs/op -BenchmarkDecoder_DecodeAllParallelCgo/comp-data.bin.zst-16 749938 1579 ns/op 2581.71 MB/s 0 B/op 0 allocs/op +BenchmarkDecoderSilesia-32 5 206878840 ns/op 1024.50 MB/s 49808 B/op 43 allocs/op +BenchmarkDecoderEnwik9-32 1 1271809000 ns/op 786.28 MB/s 72048 B/op 52 allocs/op + +Concurrent blocks, performance: + +BenchmarkDecoder_DecodeAllParallel/kppkn.gtb.zst-32 67356 17857 ns/op 10321.96 MB/s 22.48 pct 102 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/geo.protodata.zst-32 266656 4421 ns/op 26823.21 MB/s 11.89 pct 19 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/plrabn12.txt.zst-32 20992 56842 ns/op 8477.17 MB/s 39.90 pct 754 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/lcet10.txt.zst-32 27456 43932 ns/op 9714.01 MB/s 33.27 pct 524 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/asyoulik.txt.zst-32 78432 15047 ns/op 8319.15 MB/s 40.34 pct 66 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/alice29.txt.zst-32 65800 18436 ns/op 8249.63 MB/s 37.75 pct 88 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/html_x_4.zst-32 102993 11523 ns/op 35546.09 MB/s 3.637 pct 143 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/paper-100k.pdf.zst-32 1000000 1070 ns/op 95720.98 MB/s 80.53 pct 3 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/fireworks.jpeg.zst-32 749802 1752 ns/op 70272.35 MB/s 100.0 pct 5 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/urls.10K.zst-32 22640 52934 ns/op 13263.37 MB/s 26.25 pct 1014 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/html.zst-32 226412 5232 ns/op 19572.27 MB/s 14.49 pct 20 B/op 0 allocs/op +BenchmarkDecoder_DecodeAllParallel/comp-data.bin.zst-32 923041 1276 ns/op 3194.71 MB/s 31.26 pct 0 B/op 0 allocs/op ``` -This reflects the performance around May 2020, but this may be out of date. +This reflects the performance around May 2022, but this may be out of date. ## Zstd inside ZIP files diff --git a/vendor/github.com/klauspost/compress/zstd/bitreader.go b/vendor/github.com/klauspost/compress/zstd/bitreader.go index 8544585..97299d4 100644 --- a/vendor/github.com/klauspost/compress/zstd/bitreader.go +++ b/vendor/github.com/klauspost/compress/zstd/bitreader.go @@ -7,6 +7,7 @@ package zstd import ( "encoding/binary" "errors" + "fmt" "io" "math/bits" ) @@ -50,16 +51,16 @@ func (b *bitReader) getBits(n uint8) int { if n == 0 /*|| b.bitsRead >= 64 */ { return 0 } - return b.getBitsFast(n) + return int(b.get32BitsFast(n)) } -// getBitsFast requires that at least one bit is requested every time. +// get32BitsFast requires that at least one bit is requested every time. // There are no checks if the buffer is filled. -func (b *bitReader) getBitsFast(n uint8) int { +func (b *bitReader) get32BitsFast(n uint8) uint32 { const regMask = 64 - 1 v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask)) b.bitsRead += n - return int(v) + return v } // fillFast() will make sure at least 32 bits are available. @@ -125,6 +126,9 @@ func (b *bitReader) remain() uint { func (b *bitReader) close() error { // Release reference. b.in = nil + if !b.finished() { + return fmt.Errorf("%d extra bits on block, should be 0", b.remain()) + } if b.bitsRead > 64 { return io.ErrUnexpectedEOF } diff --git a/vendor/github.com/klauspost/compress/zstd/bitwriter.go b/vendor/github.com/klauspost/compress/zstd/bitwriter.go index 303ae90..78b3c61 100644 --- a/vendor/github.com/klauspost/compress/zstd/bitwriter.go +++ b/vendor/github.com/klauspost/compress/zstd/bitwriter.go @@ -5,8 +5,6 @@ package zstd -import "fmt" - // bitWriter will write bits. // First bit will be LSB of the first byte of output. type bitWriter struct { @@ -38,7 +36,7 @@ func (b *bitWriter) addBits16NC(value uint16, bits uint8) { b.nBits += bits } -// addBits32NC will add up to 32 bits. +// addBits32NC will add up to 31 bits. // It will not check if there is space for them, // so the caller must ensure that it has flushed recently. func (b *bitWriter) addBits32NC(value uint32, bits uint8) { @@ -46,6 +44,26 @@ func (b *bitWriter) addBits32NC(value uint32, bits uint8) { b.nBits += bits } +// addBits64NC will add up to 64 bits. +// There must be space for 32 bits. +func (b *bitWriter) addBits64NC(value uint64, bits uint8) { + if bits <= 31 { + b.addBits32Clean(uint32(value), bits) + return + } + b.addBits32Clean(uint32(value), 32) + b.flush32() + b.addBits32Clean(uint32(value>>32), bits-32) +} + +// addBits32Clean will add up to 32 bits. +// It will not check if there is space for them. +// The input must not contain more bits than specified. +func (b *bitWriter) addBits32Clean(value uint32, bits uint8) { + b.bitContainer |= uint64(value) << (b.nBits & 63) + b.nBits += bits +} + // addBits16Clean will add up to 16 bits. value may not contain more set bits than indicated. // It will not check if there is space for them, so the caller must ensure that it has flushed recently. func (b *bitWriter) addBits16Clean(value uint16, bits uint8) { @@ -53,80 +71,6 @@ func (b *bitWriter) addBits16Clean(value uint16, bits uint8) { b.nBits += bits } -// flush will flush all pending full bytes. -// There will be at least 56 bits available for writing when this has been called. -// Using flush32 is faster, but leaves less space for writing. -func (b *bitWriter) flush() { - v := b.nBits >> 3 - switch v { - case 0: - case 1: - b.out = append(b.out, - byte(b.bitContainer), - ) - case 2: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - ) - case 3: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - ) - case 4: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - ) - case 5: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - ) - case 6: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - byte(b.bitContainer>>40), - ) - case 7: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - byte(b.bitContainer>>40), - byte(b.bitContainer>>48), - ) - case 8: - b.out = append(b.out, - byte(b.bitContainer), - byte(b.bitContainer>>8), - byte(b.bitContainer>>16), - byte(b.bitContainer>>24), - byte(b.bitContainer>>32), - byte(b.bitContainer>>40), - byte(b.bitContainer>>48), - byte(b.bitContainer>>56), - ) - default: - panic(fmt.Errorf("bits (%d) > 64", b.nBits)) - } - b.bitContainer >>= v << 3 - b.nBits &= 7 -} - // flush32 will flush out, so there are at least 32 bits available for writing. func (b *bitWriter) flush32() { if b.nBits < 32 { diff --git a/vendor/github.com/klauspost/compress/zstd/blockdec.go b/vendor/github.com/klauspost/compress/zstd/blockdec.go index 8a98c45..9f17ce6 100644 --- a/vendor/github.com/klauspost/compress/zstd/blockdec.go +++ b/vendor/github.com/klauspost/compress/zstd/blockdec.go @@ -5,9 +5,14 @@ package zstd import ( + "bytes" + "encoding/binary" "errors" "fmt" + "hash/crc32" "io" + "os" + "path/filepath" "sync" "github.com/klauspost/compress/huff0" @@ -38,14 +43,14 @@ const ( // maxCompressedBlockSize is the biggest allowed compressed block size (128KB) maxCompressedBlockSize = 128 << 10 + compressedBlockOverAlloc = 16 + maxCompressedBlockSizeAlloc = 128<<10 + compressedBlockOverAlloc + // Maximum possible block size (all Raw+Uncompressed). maxBlockSize = (1 << 21) - 1 - // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header - maxCompressedLiteralSize = 1 << 18 - maxRLELiteralSize = 1 << 20 - maxMatchLen = 131074 - maxSequences = 0x7f00 + 0xffff + maxMatchLen = 131074 + maxSequences = 0x7f00 + 0xffff // We support slightly less than the reference decoder to be able to // use ints on 32 bit archs. @@ -76,20 +81,28 @@ type blockDec struct { // Window size of the block. WindowSize uint64 - history chan *history - input chan struct{} - result chan decodeOutput - sequenceBuf []seq - err error - decWG sync.WaitGroup + err error + + // Check against this crc, if hasCRC is true. + checkCRC uint32 + hasCRC bool // Frame to use for singlethreaded decoding. // Should not be used by the decoder itself since parent may be another frame. localFrame *frameDec + sequence []seqVals + + async struct { + newHist *history + literals []byte + seqData []byte + seqSize int // Size of uncompressed sequences + fcs uint64 + } + // Block is RLE, this is the size. RLESize uint32 - tmp [4]byte Type blockType @@ -109,13 +122,8 @@ func (b *blockDec) String() string { func newBlockDec(lowMem bool) *blockDec { b := blockDec{ - lowMem: lowMem, - result: make(chan decodeOutput, 1), - input: make(chan struct{}, 1), - history: make(chan *history, 1), + lowMem: lowMem, } - b.decWG.Add(1) - go b.startDecoder() return &b } @@ -133,11 +141,17 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { b.Type = blockType((bh >> 1) & 3) // find size. cSize := int(bh >> 3) - maxSize := maxBlockSize + maxSize := maxCompressedBlockSizeAlloc switch b.Type { case blockTypeReserved: return ErrReservedBlockType case blockTypeRLE: + if cSize > maxCompressedBlockSize || cSize > int(b.WindowSize) { + if debugDecoder { + printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b) + } + return ErrWindowSizeExceeded + } b.RLESize = uint32(cSize) if b.lowMem { maxSize = cSize @@ -148,9 +162,9 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { println("Data size on stream:", cSize) } b.RLESize = 0 - maxSize = maxCompressedBlockSize + maxSize = maxCompressedBlockSizeAlloc if windowSize < maxCompressedBlockSize && b.lowMem { - maxSize = int(windowSize) + maxSize = int(windowSize) + compressedBlockOverAlloc } if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize { if debugDecoder { @@ -158,7 +172,19 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { } return ErrCompressedSizeTooBig } + // Empty compressed blocks must at least be 2 bytes + // for Literals_Block_Type and one for Sequences_Section_Header. + if cSize < 2 { + return ErrBlockTooSmall + } case blockTypeRaw: + if cSize > maxCompressedBlockSize || cSize > int(b.WindowSize) { + if debugDecoder { + printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b) + } + return ErrWindowSizeExceeded + } + b.RLESize = 0 // We do not need a destination for raw blocks. maxSize = -1 @@ -167,16 +193,14 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { } // Read block data. - if cap(b.dataStorage) < cSize { + if _, ok := br.(*byteBuf); !ok && cap(b.dataStorage) < cSize { + // byteBuf doesn't need a destination buffer. if b.lowMem || cSize > maxCompressedBlockSize { - b.dataStorage = make([]byte, 0, cSize) + b.dataStorage = make([]byte, 0, cSize+compressedBlockOverAlloc) } else { - b.dataStorage = make([]byte, 0, maxCompressedBlockSize) + b.dataStorage = make([]byte, 0, maxCompressedBlockSizeAlloc) } } - if cap(b.dst) <= maxSize { - b.dst = make([]byte, 0, maxSize+1) - } b.data, err = br.readBig(cSize, b.dataStorage) if err != nil { if debugDecoder { @@ -185,6 +209,9 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { } return err } + if cap(b.dst) <= maxSize { + b.dst = make([]byte, 0, maxSize+1) + } return nil } @@ -193,85 +220,14 @@ func (b *blockDec) sendErr(err error) { b.Last = true b.Type = blockTypeReserved b.err = err - b.input <- struct{}{} } // Close will release resources. // Closed blockDec cannot be reset. func (b *blockDec) Close() { - close(b.input) - close(b.history) - close(b.result) - b.decWG.Wait() } -// decodeAsync will prepare decoding the block when it receives input. -// This will separate output and history. -func (b *blockDec) startDecoder() { - defer b.decWG.Done() - for range b.input { - //println("blockDec: Got block input") - switch b.Type { - case blockTypeRLE: - if cap(b.dst) < int(b.RLESize) { - if b.lowMem { - b.dst = make([]byte, b.RLESize) - } else { - b.dst = make([]byte, maxBlockSize) - } - } - o := decodeOutput{ - d: b, - b: b.dst[:b.RLESize], - err: nil, - } - v := b.data[0] - for i := range o.b { - o.b[i] = v - } - hist := <-b.history - hist.append(o.b) - b.result <- o - case blockTypeRaw: - o := decodeOutput{ - d: b, - b: b.data, - err: nil, - } - hist := <-b.history - hist.append(o.b) - b.result <- o - case blockTypeCompressed: - b.dst = b.dst[:0] - err := b.decodeCompressed(nil) - o := decodeOutput{ - d: b, - b: b.dst, - err: err, - } - if debugDecoder { - println("Decompressed to", len(b.dst), "bytes, error:", err) - } - b.result <- o - case blockTypeReserved: - // Used for returning errors. - <-b.history - b.result <- decodeOutput{ - d: b, - b: nil, - err: b.err, - } - default: - panic("Invalid block type") - } - if debugDecoder { - println("blockDec: Finished block") - } - } -} - -// decodeAsync will prepare decoding the block when it receives the history. -// If history is provided, it will not fetch it from the channel. +// decodeBuf func (b *blockDec) decodeBuf(hist *history) error { switch b.Type { case blockTypeRLE: @@ -279,7 +235,7 @@ func (b *blockDec) decodeBuf(hist *history) error { if b.lowMem { b.dst = make([]byte, b.RLESize) } else { - b.dst = make([]byte, maxBlockSize) + b.dst = make([]byte, maxCompressedBlockSize) } } b.dst = b.dst[:b.RLESize] @@ -294,14 +250,23 @@ func (b *blockDec) decodeBuf(hist *history) error { return nil case blockTypeCompressed: saved := b.dst - b.dst = hist.b - hist.b = nil + // Append directly to history + if hist.ignoreBuffer == 0 { + b.dst = hist.b + hist.b = nil + } else { + b.dst = b.dst[:0] + } err := b.decodeCompressed(hist) if debugDecoder { println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err) } - hist.b = b.dst - b.dst = saved + if hist.ignoreBuffer == 0 { + hist.b = b.dst + b.dst = saved + } else { + hist.appendKeep(b.dst) + } return err case blockTypeReserved: // Used for returning errors. @@ -311,30 +276,18 @@ func (b *blockDec) decodeBuf(hist *history) error { } } -// decodeCompressed will start decompressing a block. -// If no history is supplied the decoder will decodeAsync as much as possible -// before fetching from blockDec.history -func (b *blockDec) decodeCompressed(hist *history) error { - in := b.data - delayedHistory := hist == nil - - if delayedHistory { - // We must always grab history. - defer func() { - if hist == nil { - <-b.history - } - }() - } +func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err error) { // There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header if len(in) < 2 { - return ErrBlockTooSmall + return in, ErrBlockTooSmall } + litType := literalsBlockType(in[0] & 3) var litRegenSize int var litCompSize int sizeFormat := (in[0] >> 2) & 3 var fourStreams bool + var literals []byte switch litType { case literalsBlockRaw, literalsBlockRLE: switch sizeFormat { @@ -350,7 +303,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { // Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes. if len(in) < 3 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12) in = in[3:] @@ -361,7 +314,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { // Both Regenerated_Size and Compressed_Size use 10 bits (0-1023). if len(in) < 3 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) litRegenSize = int(n & 1023) @@ -372,7 +325,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { fourStreams = true if len(in) < 4 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) litRegenSize = int(n & 16383) @@ -382,7 +335,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { fourStreams = true if len(in) < 5 { println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in)) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28) litRegenSize = int(n & 262143) @@ -393,13 +346,15 @@ func (b *blockDec) decodeCompressed(hist *history) error { if debugDecoder { println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams) } - var literals []byte - var huff *huff0.Scratch + if litRegenSize > int(b.WindowSize) || litRegenSize > maxCompressedBlockSize { + return in, ErrWindowSizeExceeded + } + switch litType { case literalsBlockRaw: if len(in) < litRegenSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } literals = in[:litRegenSize] in = in[litRegenSize:] @@ -407,19 +362,13 @@ func (b *blockDec) decodeCompressed(hist *history) error { case literalsBlockRLE: if len(in) < 1 { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } if cap(b.literalBuf) < litRegenSize { if b.lowMem { - b.literalBuf = make([]byte, litRegenSize) + b.literalBuf = make([]byte, litRegenSize, litRegenSize+compressedBlockOverAlloc) } else { - if litRegenSize > maxCompressedLiteralSize { - // Exceptional - b.literalBuf = make([]byte, litRegenSize) - } else { - b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize) - - } + b.literalBuf = make([]byte, litRegenSize, maxCompressedBlockSize+compressedBlockOverAlloc) } } literals = b.literalBuf[:litRegenSize] @@ -434,7 +383,7 @@ func (b *blockDec) decodeCompressed(hist *history) error { case literalsBlockTreeless: if len(in) < litCompSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } // Store compressed literals, so we defer decoding until we get history. literals = in[:litCompSize] @@ -442,31 +391,68 @@ func (b *blockDec) decodeCompressed(hist *history) error { if debugDecoder { printf("Found %d compressed literals\n", litCompSize) } + huff := hist.huffTree + if huff == nil { + return in, errors.New("literal block was treeless, but no history was defined") + } + // Ensure we have space to store it. + if cap(b.literalBuf) < litRegenSize { + if b.lowMem { + b.literalBuf = make([]byte, 0, litRegenSize+compressedBlockOverAlloc) + } else { + b.literalBuf = make([]byte, 0, maxCompressedBlockSize+compressedBlockOverAlloc) + } + } + var err error + // Use our out buffer. + huff.MaxDecodedSize = litRegenSize + if fourStreams { + literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) + } else { + literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals) + } + // Make sure we don't leak our literals buffer + if err != nil { + println("decompressing literals:", err) + return in, err + } + if len(literals) != litRegenSize { + return in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) + } + case literalsBlockCompressed: if len(in) < litCompSize { println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize) - return ErrBlockTooSmall + return in, ErrBlockTooSmall } literals = in[:litCompSize] in = in[litCompSize:] - huff = huffDecoderPool.Get().(*huff0.Scratch) - var err error // Ensure we have space to store it. if cap(b.literalBuf) < litRegenSize { if b.lowMem { - b.literalBuf = make([]byte, 0, litRegenSize) + b.literalBuf = make([]byte, 0, litRegenSize+compressedBlockOverAlloc) } else { - b.literalBuf = make([]byte, 0, maxCompressedLiteralSize) + b.literalBuf = make([]byte, 0, maxCompressedBlockSize+compressedBlockOverAlloc) } } - if huff == nil { - huff = &huff0.Scratch{} + huff := hist.huffTree + if huff == nil || (hist.dict != nil && huff == hist.dict.litEnc) { + huff = huffDecoderPool.Get().(*huff0.Scratch) + if huff == nil { + huff = &huff0.Scratch{} + } + } + var err error + if debugDecoder { + println("huff table input:", len(literals), "CRC:", crc32.ChecksumIEEE(literals)) } huff, literals, err = huff0.ReadTable(literals, huff) if err != nil { println("reading huffman table:", err) - return err + return in, err } + hist.huffTree = huff + huff.MaxDecodedSize = litRegenSize // Use our out buffer. if fourStreams { literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) @@ -475,27 +461,63 @@ func (b *blockDec) decodeCompressed(hist *history) error { } if err != nil { println("decoding compressed literals:", err) - return err + return in, err } // Make sure we don't leak our literals buffer if len(literals) != litRegenSize { - return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) + return in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) } + // Re-cap to get extra size. + literals = b.literalBuf[:len(literals)] if debugDecoder { printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize) } } + hist.decoders.literals = literals + return in, nil +} + +// decodeCompressed will start decompressing a block. +func (b *blockDec) decodeCompressed(hist *history) error { + in := b.data + in, err := b.decodeLiterals(in, hist) + if err != nil { + return err + } + err = b.prepareSequences(in, hist) + if err != nil { + return err + } + if hist.decoders.nSeqs == 0 { + b.dst = append(b.dst, hist.decoders.literals...) + return nil + } + before := len(hist.decoders.out) + err = hist.decoders.decodeSync(hist.b[hist.ignoreBuffer:]) + if err != nil { + return err + } + if hist.decoders.maxSyncLen > 0 { + hist.decoders.maxSyncLen += uint64(before) + hist.decoders.maxSyncLen -= uint64(len(hist.decoders.out)) + } + b.dst = hist.decoders.out + hist.recentOffsets = hist.decoders.prevOffset + return nil +} +func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { + if debugDecoder { + printf("prepareSequences: %d byte(s) input\n", len(in)) + } // Decode Sequences // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section if len(in) < 1 { return ErrBlockTooSmall } + var nSeqs int seqHeader := in[0] - nSeqs := 0 switch { - case seqHeader == 0: - in = in[1:] case seqHeader < 128: nSeqs = int(seqHeader) in = in[1:] @@ -512,19 +534,16 @@ func (b *blockDec) decodeCompressed(hist *history) error { nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8) in = in[3:] } - // Allocate sequences - if cap(b.sequenceBuf) < nSeqs { - if b.lowMem { - b.sequenceBuf = make([]seq, nSeqs) - } else { - // Allocate max - b.sequenceBuf = make([]seq, nSeqs, maxSequences) + if nSeqs == 0 && len(in) != 0 { + // When no sequences, there should not be any more data... + if debugDecoder { + printf("prepareSequences: 0 sequences, but %d byte(s) left on stream\n", len(in)) } - } else { - // Reuse buffer - b.sequenceBuf = b.sequenceBuf[:nSeqs] + return ErrUnexpectedBlockSize } - var seqs = &sequenceDecs{} + + var seqs = &hist.decoders + seqs.nSeqs = nSeqs if nSeqs > 0 { if len(in) < 1 { return ErrBlockTooSmall @@ -553,6 +572,9 @@ func (b *blockDec) decodeCompressed(hist *history) error { } switch mode { case compModePredefined: + if seq.fse != nil && !seq.fse.preDefined { + fseDecoderPool.Put(seq.fse) + } seq.fse = &fsePredef[i] case compModeRLE: if br.remain() < 1 { @@ -560,34 +582,36 @@ func (b *blockDec) decodeCompressed(hist *history) error { } v := br.Uint8() br.advance(1) - dec := fseDecoderPool.Get().(*fseDecoder) + if seq.fse == nil || seq.fse.preDefined { + seq.fse = fseDecoderPool.Get().(*fseDecoder) + } symb, err := decSymbolValue(v, symbolTableX[i]) if err != nil { printf("RLE Transform table (%v) error: %v", tableIndex(i), err) return err } - dec.setRLE(symb) - seq.fse = dec + seq.fse.setRLE(symb) if debugDecoder { - printf("RLE set to %+v, code: %v", symb, v) + printf("RLE set to 0x%x, code: %v", symb, v) } case compModeFSE: println("Reading table for", tableIndex(i)) - dec := fseDecoderPool.Get().(*fseDecoder) - err := dec.readNCount(&br, uint16(maxTableSymbol[i])) + if seq.fse == nil || seq.fse.preDefined { + seq.fse = fseDecoderPool.Get().(*fseDecoder) + } + err := seq.fse.readNCount(&br, uint16(maxTableSymbol[i])) if err != nil { println("Read table error:", err) return err } - err = dec.transform(symbolTableX[i]) + err = seq.fse.transform(symbolTableX[i]) if err != nil { println("Transform table error:", err) return err } if debugDecoder { - println("Read table ok", "symbolLen:", dec.symbolLen) + println("Read table ok", "symbolLen:", seq.fse.symbolLen) } - seq.fse = dec case compModeRepeat: seq.repeat = true } @@ -597,140 +621,106 @@ func (b *blockDec) decodeCompressed(hist *history) error { } in = br.unread() } - - // Wait for history. - // All time spent after this is critical since it is strictly sequential. - if hist == nil { - hist = <-b.history - if hist.error { - return ErrDecoderClosed - } - } - - // Decode treeless literal block. - if litType == literalsBlockTreeless { - // TODO: We could send the history early WITHOUT the stream history. - // This would allow decoding treeless literals before the byte history is available. - // Silencia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless. - // So not much obvious gain here. - - if hist.huffTree == nil { - return errors.New("literal block was treeless, but no history was defined") - } - // Ensure we have space to store it. - if cap(b.literalBuf) < litRegenSize { - if b.lowMem { - b.literalBuf = make([]byte, 0, litRegenSize) - } else { - b.literalBuf = make([]byte, 0, maxCompressedLiteralSize) - } - } - var err error - // Use our out buffer. - huff = hist.huffTree - if fourStreams { - literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals) - } else { - literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals) - } - // Make sure we don't leak our literals buffer - if err != nil { - println("decompressing literals:", err) - return err - } - if len(literals) != litRegenSize { - return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals)) - } - } else { - if hist.huffTree != nil && huff != nil { - if hist.dict == nil || hist.dict.litEnc != hist.huffTree { - huffDecoderPool.Put(hist.huffTree) - } - hist.huffTree = nil - } - } - if huff != nil { - hist.huffTree = huff - } if debugDecoder { - println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.") + println("Literals:", len(seqs.literals), "hash:", xxhash.Sum64(seqs.literals), "and", seqs.nSeqs, "sequences.") } if nSeqs == 0 { - // Decompressed content is defined entirely as Literals Section content. - b.dst = append(b.dst, literals...) - if delayedHistory { - hist.append(literals) + if len(b.sequence) > 0 { + b.sequence = b.sequence[:0] } return nil } + br := seqs.br + if br == nil { + br = &bitReader{} + } + if err := br.init(in); err != nil { + return err + } - seqs, err := seqs.mergeHistory(&hist.decoders) - if err != nil { + if err := seqs.initialize(br, hist, b.dst); err != nil { + println("initializing sequences:", err) return err } - if debugDecoder { - println("History merged ok") + // Extract blocks... + if false && hist.dict == nil { + fatalErr := func(err error) { + if err != nil { + panic(err) + } + } + fn := fmt.Sprintf("n-%d-lits-%d-prev-%d-%d-%d-win-%d.blk", hist.decoders.nSeqs, len(hist.decoders.literals), hist.recentOffsets[0], hist.recentOffsets[1], hist.recentOffsets[2], hist.windowSize) + var buf bytes.Buffer + fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.litLengths.fse)) + fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse)) + fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse)) + buf.Write(in) + os.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm) } - br := &bitReader{} - if err := br.init(in); err != nil { - return err + + return nil +} + +func (b *blockDec) decodeSequences(hist *history) error { + if cap(b.sequence) < hist.decoders.nSeqs { + if b.lowMem { + b.sequence = make([]seqVals, 0, hist.decoders.nSeqs) + } else { + b.sequence = make([]seqVals, 0, 0x7F00+0xffff) + } + } + b.sequence = b.sequence[:hist.decoders.nSeqs] + if hist.decoders.nSeqs == 0 { + hist.decoders.seqSize = len(hist.decoders.literals) + return nil } + hist.decoders.windowSize = hist.windowSize + hist.decoders.prevOffset = hist.recentOffsets - // TODO: Investigate if sending history without decoders are faster. - // This would allow the sequences to be decoded async and only have to construct stream history. - // If only recent offsets were not transferred, this would be an obvious win. - // Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded. + err := hist.decoders.decode(b.sequence) + hist.recentOffsets = hist.decoders.prevOffset + return err +} +func (b *blockDec) executeSequences(hist *history) error { hbytes := hist.b if len(hbytes) > hist.windowSize { hbytes = hbytes[len(hbytes)-hist.windowSize:] - // We do not need history any more. + // We do not need history anymore. if hist.dict != nil { hist.dict.content = nil } } - - if err := seqs.initialize(br, hist, literals, b.dst); err != nil { - println("initializing sequences:", err) - return err - } - - err = seqs.decode(nSeqs, br, hbytes) + hist.decoders.windowSize = hist.windowSize + hist.decoders.out = b.dst[:0] + err := hist.decoders.execute(b.sequence, hbytes) if err != nil { return err } - if !br.finished() { - return fmt.Errorf("%d extra bits on block, should be 0", br.remain()) - } + return b.updateHistory(hist) +} - err = br.close() - if err != nil { - printf("Closing sequences: %v, %+v\n", err, *br) - } +func (b *blockDec) updateHistory(hist *history) error { if len(b.data) > maxCompressedBlockSize { return fmt.Errorf("compressed block size too large (%d)", len(b.data)) } // Set output and release references. - b.dst = seqs.out - seqs.out, seqs.literals, seqs.hist = nil, nil, nil + b.dst = hist.decoders.out + hist.recentOffsets = hist.decoders.prevOffset - if !delayedHistory { - // If we don't have delayed history, no need to update. - hist.recentOffsets = seqs.prevOffset - return nil - } if b.Last { // if last block we don't care about history. println("Last block, no history returned") hist.b = hist.b[:0] return nil + } else { + hist.append(b.dst) + if debugDecoder { + println("Finished block with ", len(b.sequence), "sequences. Added", len(b.dst), "to history, now length", len(hist.b)) + } } - hist.append(b.dst) - hist.recentOffsets = seqs.prevOffset - if debugDecoder { - println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.") - } + hist.decoders.out, hist.decoders.literals = nil, nil return nil } diff --git a/vendor/github.com/klauspost/compress/zstd/blockenc.go b/vendor/github.com/klauspost/compress/zstd/blockenc.go index 3df185e..fd4a36f 100644 --- a/vendor/github.com/klauspost/compress/zstd/blockenc.go +++ b/vendor/github.com/klauspost/compress/zstd/blockenc.go @@ -51,7 +51,7 @@ func (b *blockEnc) init() { if cap(b.literals) < maxCompressedBlockSize { b.literals = make([]byte, 0, maxCompressedBlockSize) } - const defSeqs = 200 + const defSeqs = 2000 if cap(b.sequences) < defSeqs { b.sequences = make([]seq, 0, defSeqs) } @@ -426,7 +426,7 @@ func fuzzFseEncoder(data []byte) int { return 0 } enc := fseEncoder{} - hist := enc.Histogram()[:256] + hist := enc.Histogram() maxSym := uint8(0) for i, v := range data { v = v & 63 @@ -473,7 +473,7 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error { return b.encodeLits(b.literals, rawAllLits) } // We want some difference to at least account for the headers. - saved := b.size - len(b.literals) - (b.size >> 5) + saved := b.size - len(b.literals) - (b.size >> 6) if saved < 16 { if org == nil { return errIncompressible @@ -722,52 +722,53 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error { println("Encoded seq", seq, s, "codes:", s.llCode, s.mlCode, s.ofCode, "states:", ll.state, ml.state, of.state, "bits:", llB, mlB, ofB) } seq-- - if llEnc.maxBits+mlEnc.maxBits+ofEnc.maxBits <= 32 { - // No need to flush (common) - for seq >= 0 { - s = b.sequences[seq] - wr.flush32() - llB, ofB, mlB := llTT[s.llCode], ofTT[s.ofCode], mlTT[s.mlCode] - // tabelog max is 8 for all. - of.encode(ofB) - ml.encode(mlB) - ll.encode(llB) - wr.flush32() - - // We checked that all can stay within 32 bits - wr.addBits32NC(s.litLen, llB.outBits) - wr.addBits32NC(s.matchLen, mlB.outBits) - wr.addBits32NC(s.offset, ofB.outBits) - - if debugSequences { - println("Encoded seq", seq, s) - } - - seq-- - } - } else { - for seq >= 0 { - s = b.sequences[seq] - wr.flush32() - llB, ofB, mlB := llTT[s.llCode], ofTT[s.ofCode], mlTT[s.mlCode] - // tabelog max is below 8 for each. - of.encode(ofB) - ml.encode(mlB) - ll.encode(llB) - wr.flush32() - - // ml+ll = max 32 bits total - wr.addBits32NC(s.litLen, llB.outBits) - wr.addBits32NC(s.matchLen, mlB.outBits) - wr.flush32() - wr.addBits32NC(s.offset, ofB.outBits) - - if debugSequences { - println("Encoded seq", seq, s) - } - - seq-- - } + // Store sequences in reverse... + for seq >= 0 { + s = b.sequences[seq] + + ofB := ofTT[s.ofCode] + wr.flush32() // tablelog max is below 8 for each, so it will fill max 24 bits. + //of.encode(ofB) + nbBitsOut := (uint32(of.state) + ofB.deltaNbBits) >> 16 + dstState := int32(of.state>>(nbBitsOut&15)) + int32(ofB.deltaFindState) + wr.addBits16NC(of.state, uint8(nbBitsOut)) + of.state = of.stateTable[dstState] + + // Accumulate extra bits. + outBits := ofB.outBits & 31 + extraBits := uint64(s.offset & bitMask32[outBits]) + extraBitsN := outBits + + mlB := mlTT[s.mlCode] + //ml.encode(mlB) + nbBitsOut = (uint32(ml.state) + mlB.deltaNbBits) >> 16 + dstState = int32(ml.state>>(nbBitsOut&15)) + int32(mlB.deltaFindState) + wr.addBits16NC(ml.state, uint8(nbBitsOut)) + ml.state = ml.stateTable[dstState] + + outBits = mlB.outBits & 31 + extraBits = extraBits<> 16 + dstState = int32(ll.state>>(nbBitsOut&15)) + int32(llB.deltaFindState) + wr.addBits16NC(ll.state, uint8(nbBitsOut)) + ll.state = ll.stateTable[dstState] + + outBits = llB.outBits & 31 + extraBits = extraBits<= b.size { - // Maybe even add a bigger margin. + // Discard and encode as raw block. + b.output = b.encodeRawTo(b.output[:bhOffset], org) + b.popOffsets() b.litEnc.Reuse = huff0.ReusePolicyNone - return errIncompressible + return nil } // Size is output minus block header. @@ -801,14 +805,13 @@ func (b *blockEnc) genCodes() { // nothing to do return } - if len(b.sequences) > math.MaxUint16 { panic("can only encode up to 64K sequences") } // No bounds checks after here: - llH := b.coders.llEnc.Histogram()[:256] - ofH := b.coders.ofEnc.Histogram()[:256] - mlH := b.coders.mlEnc.Histogram()[:256] + llH := b.coders.llEnc.Histogram() + ofH := b.coders.ofEnc.Histogram() + mlH := b.coders.mlEnc.Histogram() for i := range llH { llH[i] = 0 } @@ -820,7 +823,8 @@ func (b *blockEnc) genCodes() { } var llMax, ofMax, mlMax uint8 - for i, seq := range b.sequences { + for i := range b.sequences { + seq := &b.sequences[i] v := llCode(seq.litLen) seq.llCode = v llH[v]++ @@ -844,7 +848,6 @@ func (b *blockEnc) genCodes() { panic(fmt.Errorf("mlMax > maxMatchLengthSymbol (%d), matchlen: %d", mlMax, seq.matchLen)) } } - b.sequences[i] = seq } maxCount := func(a []uint32) int { var max uint32 diff --git a/vendor/github.com/klauspost/compress/zstd/bytebuf.go b/vendor/github.com/klauspost/compress/zstd/bytebuf.go index aab71c6..55a3885 100644 --- a/vendor/github.com/klauspost/compress/zstd/bytebuf.go +++ b/vendor/github.com/klauspost/compress/zstd/bytebuf.go @@ -7,7 +7,6 @@ package zstd import ( "fmt" "io" - "io/ioutil" ) type byteBuffer interface { @@ -23,7 +22,7 @@ type byteBuffer interface { readByte() (byte, error) // Skip n bytes. - skipN(n int) error + skipN(n int64) error } // in-memory buffer @@ -52,23 +51,22 @@ func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) { return r, nil } -func (b *byteBuf) remain() []byte { - return *b -} - func (b *byteBuf) readByte() (byte, error) { bb := *b if len(bb) < 1 { - return 0, nil + return 0, io.ErrUnexpectedEOF } r := bb[0] *b = bb[1:] return r, nil } -func (b *byteBuf) skipN(n int) error { +func (b *byteBuf) skipN(n int64) error { bb := *b - if len(bb) < n { + if n < 0 { + return fmt.Errorf("negative skip (%d) requested", n) + } + if int64(len(bb)) < n { return io.ErrUnexpectedEOF } *b = bb[n:] @@ -111,8 +109,11 @@ func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) { } func (r *readerWrapper) readByte() (byte, error) { - n2, err := r.r.Read(r.tmp[:1]) + n2, err := io.ReadFull(r.r, r.tmp[:1]) if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return 0, err } if n2 != 1 { @@ -121,9 +122,9 @@ func (r *readerWrapper) readByte() (byte, error) { return r.tmp[0], nil } -func (r *readerWrapper) skipN(n int) error { - n2, err := io.CopyN(ioutil.Discard, r.r, int64(n)) - if n2 != int64(n) { +func (r *readerWrapper) skipN(n int64) error { + n2, err := io.CopyN(io.Discard, r.r, n) + if n2 != n { err = io.ErrUnexpectedEOF } return err diff --git a/vendor/github.com/klauspost/compress/zstd/bytereader.go b/vendor/github.com/klauspost/compress/zstd/bytereader.go index 2c4fca1..0e59a24 100644 --- a/vendor/github.com/klauspost/compress/zstd/bytereader.go +++ b/vendor/github.com/klauspost/compress/zstd/bytereader.go @@ -13,12 +13,6 @@ type byteReader struct { off int } -// init will initialize the reader and set the input. -func (b *byteReader) init(in []byte) { - b.b = in - b.off = 0 -} - // advance the stream b n bytes. func (b *byteReader) advance(n uint) { b.off += int(n) diff --git a/vendor/github.com/klauspost/compress/zstd/decodeheader.go b/vendor/github.com/klauspost/compress/zstd/decodeheader.go index 69736e8..f6a2409 100644 --- a/vendor/github.com/klauspost/compress/zstd/decodeheader.go +++ b/vendor/github.com/klauspost/compress/zstd/decodeheader.go @@ -4,7 +4,7 @@ package zstd import ( - "bytes" + "encoding/binary" "errors" "io" ) @@ -15,18 +15,50 @@ const HeaderMaxSize = 14 + 3 // Header contains information about the first frame and block within that. type Header struct { - // Window Size the window of data to keep while decoding. - // Will only be set if HasFCS is false. - WindowSize uint64 + // SingleSegment specifies whether the data is to be decompressed into a + // single contiguous memory segment. + // It implies that WindowSize is invalid and that FrameContentSize is valid. + SingleSegment bool - // Frame content size. - // Expected size of the entire frame. - FrameContentSize uint64 + // WindowSize is the window of data to keep while decoding. + // Will only be set if SingleSegment is false. + WindowSize uint64 // Dictionary ID. // If 0, no dictionary. DictionaryID uint32 + // HasFCS specifies whether FrameContentSize has a valid value. + HasFCS bool + + // FrameContentSize is the expected uncompressed size of the entire frame. + FrameContentSize uint64 + + // Skippable will be true if the frame is meant to be skipped. + // This implies that FirstBlock.OK is false. + Skippable bool + + // SkippableID is the user-specific ID for the skippable frame. + // Valid values are between 0 to 15, inclusive. + SkippableID int + + // SkippableSize is the length of the user data to skip following + // the header. + SkippableSize uint32 + + // HeaderSize is the raw size of the frame header. + // + // For normal frames, it includes the size of the magic number and + // the size of the header (per section 3.1.1.1). + // It does not include the size for any data blocks (section 3.1.1.2) nor + // the size for the trailing content checksum. + // + // For skippable frames, this counts the size of the magic number + // along with the size of the size field of the payload. + // It does not include the size of the skippable payload itself. + // The total frame size is the HeaderSize plus the SkippableSize. + HeaderSize int + // First block information. FirstBlock struct { // OK will be set if first block could be decoded. @@ -51,17 +83,9 @@ type Header struct { CompressedSize int } - // Skippable will be true if the frame is meant to be skipped. - // No other information will be populated. - Skippable bool - // If set there is a checksum present for the block content. + // The checksum field at the end is always 4 bytes long. HasCheckSum bool - - // If this is true FrameContentSize will have a valid value - HasFCS bool - - SingleSegment bool } // Decode the header from the beginning of the stream. @@ -71,39 +95,46 @@ type Header struct { // If there isn't enough input, io.ErrUnexpectedEOF is returned. // The FirstBlock.OK will indicate if enough information was available to decode the first block header. func (h *Header) Decode(in []byte) error { + *h = Header{} if len(in) < 4 { return io.ErrUnexpectedEOF } + h.HeaderSize += 4 b, in := in[:4], in[4:] - if !bytes.Equal(b, frameMagic) { - if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 { + if string(b) != frameMagic { + if string(b[1:4]) != skippableFrameMagic || b[0]&0xf0 != 0x50 { return ErrMagicMismatch } - *h = Header{Skippable: true} + if len(in) < 4 { + return io.ErrUnexpectedEOF + } + h.HeaderSize += 4 + h.Skippable = true + h.SkippableID = int(b[0] & 0xf) + h.SkippableSize = binary.LittleEndian.Uint32(in) return nil } + + // Read Window_Descriptor + // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor if len(in) < 1 { return io.ErrUnexpectedEOF } - - // Clear output - *h = Header{} fhd, in := in[0], in[1:] + h.HeaderSize++ h.SingleSegment = fhd&(1<<5) != 0 h.HasCheckSum = fhd&(1<<2) != 0 - if fhd&(1<<3) != 0 { return errors.New("reserved bit set on frame header") } - // Read Window_Descriptor - // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor if !h.SingleSegment { if len(in) < 1 { return io.ErrUnexpectedEOF } var wd byte wd, in = in[0], in[1:] + h.HeaderSize++ windowLog := 10 + (wd >> 3) windowBase := uint64(1) << windowLog windowAdd := (windowBase / 8) * uint64(wd&0x7) @@ -120,10 +151,8 @@ func (h *Header) Decode(in []byte) error { return io.ErrUnexpectedEOF } b, in = in[:size], in[size:] - if b == nil { - return io.ErrUnexpectedEOF - } - switch size { + h.HeaderSize += int(size) + switch len(b) { case 1: h.DictionaryID = uint32(b[0]) case 2: @@ -152,10 +181,8 @@ func (h *Header) Decode(in []byte) error { return io.ErrUnexpectedEOF } b, in = in[:fcsSize], in[fcsSize:] - if b == nil { - return io.ErrUnexpectedEOF - } - switch fcsSize { + h.HeaderSize += int(fcsSize) + switch len(b) { case 1: h.FrameContentSize = uint64(b[0]) case 2: diff --git a/vendor/github.com/klauspost/compress/zstd/decoder.go b/vendor/github.com/klauspost/compress/zstd/decoder.go index f430f58..f04aaa2 100644 --- a/vendor/github.com/klauspost/compress/zstd/decoder.go +++ b/vendor/github.com/klauspost/compress/zstd/decoder.go @@ -5,9 +5,12 @@ package zstd import ( - "errors" + "context" + "encoding/binary" "io" "sync" + + "github.com/klauspost/compress/zstd/internal/xxhash" ) // Decoder provides decoding of zstandard streams. @@ -22,15 +25,22 @@ type Decoder struct { // Unreferenced decoders, ready for use. decoders chan *blockDec - // Streams ready to be decoded. - stream chan decodeStream - // Current read position used for Reader functionality. current decoderState + // sync stream decoding + syncStream struct { + decodedFrame uint64 + br readerWrapper + enabled bool + inFrame bool + dstBuf []byte + } + + frame *frameDec + // Custom dictionaries. - // Always uses copies. - dicts map[uint32]dict + dicts map[uint32]*dict // streamWg is the waitgroup for all streams streamWg sync.WaitGroup @@ -46,7 +56,10 @@ type decoderState struct { output chan decodeOutput // cancel remaining output. - cancel chan struct{} + cancel context.CancelFunc + + // crc of current frame + crc *xxhash.Digest flushed bool } @@ -81,7 +94,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) { return nil, err } } - d.current.output = make(chan decodeOutput, d.o.concurrent) + d.current.crc = xxhash.New() d.current.flushed = true if r == nil { @@ -89,7 +102,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) { } // Transfer option dicts. - d.dicts = make(map[uint32]dict, len(d.o.dicts)) + d.dicts = make(map[uint32]*dict, len(d.o.dicts)) for _, dc := range d.o.dicts { d.dicts[dc.id] = dc } @@ -130,7 +143,7 @@ func (d *Decoder) Read(p []byte) (int, error) { break } if !d.nextBlock(n == 0) { - return n, nil + return n, d.current.err } } } @@ -162,6 +175,7 @@ func (d *Decoder) Reset(r io.Reader) error { d.drainOutput() + d.syncStream.br.r = nil if r == nil { d.current.err = ErrDecoderNilInput if len(d.current.b) > 0 { @@ -172,21 +186,23 @@ func (d *Decoder) Reset(r io.Reader) error { } // If bytes buffer and < 5MB, do sync decoding anyway. - if bb, ok := r.(byter); ok && bb.Len() < 5<<20 { + if bb, ok := r.(byter); ok && bb.Len() < d.o.decodeBufsBelow && !d.o.limitToCap { bb2 := bb if debugDecoder { println("*bytes.Buffer detected, doing sync decode, len:", bb.Len()) } b := bb2.Bytes() var dst []byte - if cap(d.current.b) > 0 { - dst = d.current.b + if cap(d.syncStream.dstBuf) > 0 { + dst = d.syncStream.dstBuf[:0] } - dst, err := d.DecodeAll(b, dst[:0]) + dst, err := d.DecodeAll(b, dst) if err == nil { err = io.EOF } + // Save output buffer + d.syncStream.dstBuf = dst d.current.b = dst d.current.err = err d.current.flushed = true @@ -195,33 +211,40 @@ func (d *Decoder) Reset(r io.Reader) error { } return nil } - - if d.stream == nil { - d.stream = make(chan decodeStream, 1) - d.streamWg.Add(1) - go d.startStreamDecoder(d.stream) - } - // Remove current block. + d.stashDecoder() d.current.decodeOutput = decodeOutput{} d.current.err = nil - d.current.cancel = make(chan struct{}) d.current.flushed = false d.current.d = nil + d.syncStream.dstBuf = nil - d.stream <- decodeStream{ - r: r, - output: d.current.output, - cancel: d.current.cancel, + // Ensure no-one else is still running... + d.streamWg.Wait() + if d.frame == nil { + d.frame = newFrameDec(d.o) } + + if d.o.concurrent == 1 { + return d.startSyncDecoder(r) + } + + d.current.output = make(chan decodeOutput, d.o.concurrent) + ctx, cancel := context.WithCancel(context.Background()) + d.current.cancel = cancel + d.streamWg.Add(1) + go d.startStreamDecoder(ctx, r, d.current.output) + return nil } // drainOutput will drain the output until errEndOfStream is sent. func (d *Decoder) drainOutput() { if d.current.cancel != nil { - println("cancelling current") - close(d.current.cancel) + if debugDecoder { + println("cancelling current") + } + d.current.cancel() d.current.cancel = nil } if d.current.d != nil { @@ -243,12 +266,9 @@ func (d *Decoder) drainOutput() { } d.decoders <- v.d } - if v.err == errEndOfStream { - println("current flushed") - d.current.flushed = true - return - } } + d.current.output = nil + d.current.flushed = true } // WriteTo writes data to w until there's no more data to write or when an error occurs. @@ -287,19 +307,23 @@ func (d *Decoder) WriteTo(w io.Writer) (int64, error) { // DecodeAll can be used concurrently. // The Decoder concurrency limits will be respected. func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { - if d.current.err == ErrDecoderClosed { + if d.decoders == nil { return dst, ErrDecoderClosed } // Grab a block decoder and frame decoder. block := <-d.decoders frame := block.localFrame + initialSize := len(dst) defer func() { if debugDecoder { printf("re-adding decoder: %p", block) } frame.rawInput = nil frame.bBuf = nil + if frame.history.decoders.br != nil { + frame.history.decoders.br.in = nil + } d.decoders <- block }() frame.bBuf = input @@ -307,34 +331,45 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { for { frame.history.reset() err := frame.reset(&frame.bBuf) - if err == io.EOF { - if debugDecoder { - println("frame reset return EOF") - } - return dst, nil - } - if frame.DictionaryID != nil { - dict, ok := d.dicts[*frame.DictionaryID] - if !ok { - return nil, ErrUnknownDictionary - } - frame.history.setDict(&dict) - } if err != nil { + if err == io.EOF { + if debugDecoder { + println("frame reset return EOF") + } + return dst, nil + } return dst, err } - if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { - return dst, ErrDecoderSizeExceeded + if err = d.setDict(frame); err != nil { + return nil, err } - if frame.FrameContentSize > 0 && frame.FrameContentSize < 1<<30 { - // Never preallocate moe than 1 GB up front. + if frame.WindowSize > d.o.maxWindowSize { + if debugDecoder { + println("window size exceeded:", frame.WindowSize, ">", d.o.maxWindowSize) + } + return dst, ErrWindowSizeExceeded + } + if frame.FrameContentSize != fcsUnknown { + if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) { + if debugDecoder { + println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst)) + } + return dst, ErrDecoderSizeExceeded + } + if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) { + if debugDecoder { + println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst)) + } + return dst, ErrDecoderSizeExceeded + } if cap(dst)-len(dst) < int(frame.FrameContentSize) { - dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)) + dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)+compressedBlockOverAlloc) copy(dst2, dst) dst = dst2 } } - if cap(dst) == 0 { + + if cap(dst) == 0 && !d.o.limitToCap { // Allocate len(input) * 2 by default if nothing is provided // and we didn't get frame content size. size := len(input) * 2 @@ -352,6 +387,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { if err != nil { return dst, err } + if uint64(len(dst)-initialSize) > d.o.maxDecodedSize { + return dst, ErrDecoderSizeExceeded + } if len(frame.bBuf) == 0 { if debugDecoder { println("frame dbuf empty") @@ -368,33 +406,167 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { // If non-blocking mode is used the returned boolean will be false // if no data was available without blocking. func (d *Decoder) nextBlock(blocking bool) (ok bool) { - if d.current.d != nil { - if debugDecoder { - printf("re-adding current decoder %p", d.current.d) - } - d.decoders <- d.current.d - d.current.d = nil - } if d.current.err != nil { // Keep error state. - return blocking + return false } + d.current.b = d.current.b[:0] + // SYNC: + if d.syncStream.enabled { + if !blocking { + return false + } + ok = d.nextBlockSync() + if !ok { + d.stashDecoder() + } + return ok + } + + //ASYNC: + d.stashDecoder() if blocking { - d.current.decodeOutput = <-d.current.output + d.current.decodeOutput, ok = <-d.current.output } else { select { - case d.current.decodeOutput = <-d.current.output: + case d.current.decodeOutput, ok = <-d.current.output: default: return false } } + if !ok { + // This should not happen, so signal error state... + d.current.err = io.ErrUnexpectedEOF + return false + } + next := d.current.decodeOutput + if next.d != nil && next.d.async.newHist != nil { + d.current.crc.Reset() + } if debugDecoder { - println("got", len(d.current.b), "bytes, error:", d.current.err) + var tmp [4]byte + binary.LittleEndian.PutUint32(tmp[:], uint32(xxhash.Sum64(next.b))) + println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp) + } + + if d.o.ignoreChecksum { + return true + } + + if len(next.b) > 0 { + d.current.crc.Write(next.b) + } + if next.err == nil && next.d != nil && next.d.hasCRC { + got := uint32(d.current.crc.Sum64()) + if got != next.d.checkCRC { + if debugDecoder { + printf("CRC Check Failed: %08x (got) != %08x (on stream)\n", got, next.d.checkCRC) + } + d.current.err = ErrCRCMismatch + } else { + if debugDecoder { + printf("CRC ok %08x\n", got) + } + } + } + + return true +} + +func (d *Decoder) nextBlockSync() (ok bool) { + if d.current.d == nil { + d.current.d = <-d.decoders + } + for len(d.current.b) == 0 { + if !d.syncStream.inFrame { + d.frame.history.reset() + d.current.err = d.frame.reset(&d.syncStream.br) + if d.current.err == nil { + d.current.err = d.setDict(d.frame) + } + if d.current.err != nil { + return false + } + if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { + d.current.err = ErrDecoderSizeExceeded + return false + } + + d.syncStream.decodedFrame = 0 + d.syncStream.inFrame = true + } + d.current.err = d.frame.next(d.current.d) + if d.current.err != nil { + return false + } + d.frame.history.ensureBlock() + if debugDecoder { + println("History trimmed:", len(d.frame.history.b), "decoded already:", d.syncStream.decodedFrame) + } + histBefore := len(d.frame.history.b) + d.current.err = d.current.d.decodeBuf(&d.frame.history) + + if d.current.err != nil { + println("error after:", d.current.err) + return false + } + d.current.b = d.frame.history.b[histBefore:] + if debugDecoder { + println("history after:", len(d.frame.history.b)) + } + + // Check frame size (before CRC) + d.syncStream.decodedFrame += uint64(len(d.current.b)) + if d.syncStream.decodedFrame > d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } + d.current.err = ErrFrameSizeExceeded + return false + } + + // Check FCS + if d.current.d.Last && d.frame.FrameContentSize != fcsUnknown && d.syncStream.decodedFrame != d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } + d.current.err = ErrFrameSizeMismatch + return false + } + + // Update/Check CRC + if d.frame.HasCheckSum { + if !d.o.ignoreChecksum { + d.frame.crc.Write(d.current.b) + } + if d.current.d.Last { + if !d.o.ignoreChecksum { + d.current.err = d.frame.checkCRC() + } else { + d.current.err = d.frame.consumeCRC() + } + if d.current.err != nil { + println("CRC error:", d.current.err) + return false + } + } + } + d.syncStream.inFrame = !d.current.d.Last } return true } +func (d *Decoder) stashDecoder() { + if d.current.d != nil { + if debugDecoder { + printf("re-adding current decoder %p", d.current.d) + } + d.decoders <- d.current.d + d.current.d = nil + } +} + // Close will release all resources. // It is NOT possible to reuse the decoder after this. func (d *Decoder) Close() { @@ -402,10 +574,10 @@ func (d *Decoder) Close() { return } d.drainOutput() - if d.stream != nil { - close(d.stream) + if d.current.cancel != nil { + d.current.cancel() d.streamWg.Wait() - d.stream = nil + d.current.cancel = nil } if d.decoders != nil { close(d.decoders) @@ -456,100 +628,321 @@ type decodeOutput struct { err error } -type decodeStream struct { - r io.Reader - - // Blocks ready to be written to output. - output chan decodeOutput - - // cancel reading from the input - cancel chan struct{} +func (d *Decoder) startSyncDecoder(r io.Reader) error { + d.frame.history.reset() + d.syncStream.br = readerWrapper{r: r} + d.syncStream.inFrame = false + d.syncStream.enabled = true + d.syncStream.decodedFrame = 0 + return nil } -// errEndOfStream indicates that everything from the stream was read. -var errEndOfStream = errors.New("end-of-stream") - // Create Decoder: -// Spawn n block decoders. These accept tasks to decode a block. -// Create goroutine that handles stream processing, this will send history to decoders as they are available. -// Decoders update the history as they decode. -// When a block is returned: -// a) history is sent to the next decoder, -// b) content written to CRC. -// c) return data to WRITER. -// d) wait for next block to return data. -// Once WRITTEN, the decoders reused by the writer frame decoder for re-use. -func (d *Decoder) startStreamDecoder(inStream chan decodeStream) { +// ASYNC: +// Spawn 3 go routines. +// 0: Read frames and decode block literals. +// 1: Decode sequences. +// 2: Execute sequences, send to output. +func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output chan decodeOutput) { defer d.streamWg.Done() - frame := newFrameDec(d.o) - for stream := range inStream { - if debugDecoder { - println("got new stream") + br := readerWrapper{r: r} + + var seqDecode = make(chan *blockDec, d.o.concurrent) + var seqExecute = make(chan *blockDec, d.o.concurrent) + + // Async 1: Decode sequences... + go func() { + var hist history + var hasErr bool + + for block := range seqDecode { + if hasErr { + if block != nil { + seqExecute <- block + } + continue + } + if block.async.newHist != nil { + if debugDecoder { + println("Async 1: new history, recent:", block.async.newHist.recentOffsets) + } + hist.reset() + hist.decoders = block.async.newHist.decoders + hist.recentOffsets = block.async.newHist.recentOffsets + hist.windowSize = block.async.newHist.windowSize + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) + } + } + if block.err != nil || block.Type != blockTypeCompressed { + hasErr = block.err != nil + seqExecute <- block + continue + } + + hist.decoders.literals = block.async.literals + block.err = block.prepareSequences(block.async.seqData, &hist) + if debugDecoder && block.err != nil { + println("prepareSequences returned:", block.err) + } + hasErr = block.err != nil + if block.err == nil { + block.err = block.decodeSequences(&hist) + if debugDecoder && block.err != nil { + println("decodeSequences returned:", block.err) + } + hasErr = block.err != nil + // block.async.sequence = hist.decoders.seq[:hist.decoders.nSeqs] + block.async.seqSize = hist.decoders.seqSize + } + seqExecute <- block } - br := readerWrapper{r: stream.r} - decodeStream: - for { - frame.history.reset() - err := frame.reset(&br) - if debugDecoder && err != nil { - println("Frame decoder returned", err) + close(seqExecute) + hist.reset() + }() + + var wg sync.WaitGroup + wg.Add(1) + + // Async 3: Execute sequences... + frameHistCache := d.frame.history.b + go func() { + var hist history + var decodedFrame uint64 + var fcs uint64 + var hasErr bool + for block := range seqExecute { + out := decodeOutput{err: block.err, d: block} + if block.err != nil || hasErr { + hasErr = true + output <- out + continue + } + if block.async.newHist != nil { + if debugDecoder { + println("Async 2: new history") + } + hist.reset() + hist.windowSize = block.async.newHist.windowSize + hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) + } + + if cap(hist.b) < hist.allocFrameBuffer { + if cap(frameHistCache) >= hist.allocFrameBuffer { + hist.b = frameHistCache + } else { + hist.b = make([]byte, 0, hist.allocFrameBuffer) + println("Alloc history sized", hist.allocFrameBuffer) + } + } + hist.b = hist.b[:0] + fcs = block.async.fcs + decodedFrame = 0 } - if err == nil && frame.DictionaryID != nil { - dict, ok := d.dicts[*frame.DictionaryID] - if !ok { - err = ErrUnknownDictionary + do := decodeOutput{err: block.err, d: block} + switch block.Type { + case blockTypeRLE: + if debugDecoder { + println("add rle block length:", block.RLESize) + } + + if cap(block.dst) < int(block.RLESize) { + if block.lowMem { + block.dst = make([]byte, block.RLESize) + } else { + block.dst = make([]byte, maxCompressedBlockSize) + } + } + block.dst = block.dst[:block.RLESize] + v := block.data[0] + for i := range block.dst { + block.dst[i] = v + } + hist.append(block.dst) + do.b = block.dst + case blockTypeRaw: + if debugDecoder { + println("add raw block length:", len(block.data)) + } + hist.append(block.data) + do.b = block.data + case blockTypeCompressed: + if debugDecoder { + println("execute with history length:", len(hist.b), "window:", hist.windowSize) + } + hist.decoders.seqSize = block.async.seqSize + hist.decoders.literals = block.async.literals + do.err = block.executeSequences(&hist) + hasErr = do.err != nil + if debugDecoder && hasErr { + println("executeSequences returned:", do.err) + } + do.b = block.dst + } + if !hasErr { + decodedFrame += uint64(len(do.b)) + if decodedFrame > fcs { + println("fcs exceeded", block.Last, fcs, decodedFrame) + do.err = ErrFrameSizeExceeded + hasErr = true + } else if block.Last && fcs != fcsUnknown && decodedFrame != fcs { + do.err = ErrFrameSizeMismatch + hasErr = true } else { - frame.history.setDict(&dict) + if debugDecoder { + println("fcs ok", block.Last, fcs, decodedFrame) + } } } - if err != nil { - stream.output <- decodeOutput{ - err: err, + output <- do + } + close(output) + frameHistCache = hist.b + wg.Done() + if debugDecoder { + println("decoder goroutines finished") + } + hist.reset() + }() + + var hist history +decodeStream: + for { + var hasErr bool + hist.reset() + decodeBlock := func(block *blockDec) { + if hasErr { + if block != nil { + seqDecode <- block } - break + return } + if block.err != nil || block.Type != blockTypeCompressed { + hasErr = block.err != nil + seqDecode <- block + return + } + + remain, err := block.decodeLiterals(block.data, &hist) + block.err = err + hasErr = block.err != nil + if err == nil { + block.async.literals = hist.decoders.literals + block.async.seqData = remain + } else if debugDecoder { + println("decodeLiterals error:", err) + } + seqDecode <- block + } + frame := d.frame + if debugDecoder { + println("New frame...") + } + var historySent bool + frame.history.reset() + err := frame.reset(&br) + if debugDecoder && err != nil { + println("Frame decoder returned", err) + } + if err == nil { + err = d.setDict(frame) + } + if err == nil && d.frame.WindowSize > d.o.maxWindowSize { if debugDecoder { - println("starting frame decoder") - } - - // This goroutine will forward history between frames. - frame.frameDone.Add(1) - frame.initAsync() - - go frame.startDecoder(stream.output) - decodeFrame: - // Go through all blocks of the frame. - for { - dec := <-d.decoders - select { - case <-stream.cancel: - if !frame.sendErr(dec, io.EOF) { - // To not let the decoder dangle, send it back. - stream.output <- decodeOutput{d: dec} - } - break decodeStream - default: + println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize) + } + + err = ErrDecoderSizeExceeded + } + if err != nil { + select { + case <-ctx.Done(): + case dec := <-d.decoders: + dec.sendErr(err) + decodeBlock(dec) + } + break decodeStream + } + + // Go through all blocks of the frame. + for { + var dec *blockDec + select { + case <-ctx.Done(): + break decodeStream + case dec = <-d.decoders: + // Once we have a decoder, we MUST return it. + } + err := frame.next(dec) + if !historySent { + h := frame.history + if debugDecoder { + println("Alloc History:", h.allocFrameBuffer) + } + hist.reset() + if h.dict != nil { + hist.setDict(h.dict) } - err := frame.next(dec) - switch err { - case io.EOF: - // End of current frame, no error - println("EOF on next block") - break decodeFrame - case nil: - continue - default: - println("block decoder returned", err) - break decodeStream + dec.async.newHist = &h + dec.async.fcs = frame.FrameContentSize + historySent = true + } else { + dec.async.newHist = nil + } + if debugDecoder && err != nil { + println("next block returned error:", err) + } + dec.err = err + dec.hasCRC = false + if dec.Last && frame.HasCheckSum && err == nil { + crc, err := frame.rawInput.readSmall(4) + if len(crc) < 4 { + if err == nil { + err = io.ErrUnexpectedEOF + + } + println("CRC missing?", err) + dec.err = err + } else { + dec.checkCRC = binary.LittleEndian.Uint32(crc) + dec.hasCRC = true + if debugDecoder { + printf("found crc to check: %08x\n", dec.checkCRC) + } } } - // All blocks have started decoding, check if there are more frames. - println("waiting for done") - frame.frameDone.Wait() - println("done waiting...") + err = dec.err + last := dec.Last + decodeBlock(dec) + if err != nil { + break decodeStream + } + if last { + break + } } - frame.frameDone.Wait() - println("Sending EOS") - stream.output <- decodeOutput{err: errEndOfStream} } + close(seqDecode) + wg.Wait() + hist.reset() + d.frame.history.b = frameHistCache +} + +func (d *Decoder) setDict(frame *frameDec) (err error) { + dict, ok := d.dicts[frame.DictionaryID] + if ok { + if debugDecoder { + println("setting dict", frame.DictionaryID) + } + frame.history.setDict(dict) + } else if frame.DictionaryID != 0 { + // A zero or missing dictionary id is ambiguous: + // either dictionary zero, or no dictionary. In particular, + // zstd --patch-from uses this id for the source file, + // so only return an error if the dictionary id is not zero. + err = ErrUnknownDictionary + } + return err } diff --git a/vendor/github.com/klauspost/compress/zstd/decoder_options.go b/vendor/github.com/klauspost/compress/zstd/decoder_options.go index 95cc9b8..774c5f0 100644 --- a/vendor/github.com/klauspost/compress/zstd/decoder_options.go +++ b/vendor/github.com/klauspost/compress/zstd/decoder_options.go @@ -6,6 +6,8 @@ package zstd import ( "errors" + "fmt" + "math/bits" "runtime" ) @@ -14,21 +16,28 @@ type DOption func(*decoderOptions) error // options retains accumulated state of multiple options. type decoderOptions struct { - lowMem bool - concurrent int - maxDecodedSize uint64 - maxWindowSize uint64 - dicts []dict + lowMem bool + concurrent int + maxDecodedSize uint64 + maxWindowSize uint64 + dicts []*dict + ignoreChecksum bool + limitToCap bool + decodeBufsBelow int } func (o *decoderOptions) setDefault() { *o = decoderOptions{ // use less ram: true for now, but may change. - lowMem: true, - concurrent: runtime.GOMAXPROCS(0), - maxWindowSize: MaxWindowSize, + lowMem: true, + concurrent: runtime.GOMAXPROCS(0), + maxWindowSize: MaxWindowSize, + decodeBufsBelow: 128 << 10, } - o.maxDecodedSize = 1 << 63 + if o.concurrent > 4 { + o.concurrent = 4 + } + o.maxDecodedSize = 64 << 30 } // WithDecoderLowmem will set whether to use a lower amount of memory, @@ -37,16 +46,25 @@ func WithDecoderLowmem(b bool) DOption { return func(o *decoderOptions) error { o.lowMem = b; return nil } } -// WithDecoderConcurrency will set the concurrency, -// meaning the maximum number of decoders to run concurrently. -// The value supplied must be at least 1. -// By default this will be set to GOMAXPROCS. +// WithDecoderConcurrency sets the number of created decoders. +// When decoding block with DecodeAll, this will limit the number +// of possible concurrently running decodes. +// When decoding streams, this will limit the number of +// inflight blocks. +// When decoding streams and setting maximum to 1, +// no async decoding will be done. +// When a value of 0 is provided GOMAXPROCS will be used. +// By default this will be set to 4 or GOMAXPROCS, whatever is lower. func WithDecoderConcurrency(n int) DOption { return func(o *decoderOptions) error { - if n <= 0 { + if n < 0 { return errors.New("concurrency must be at least 1") } - o.concurrent = n + if n == 0 { + o.concurrent = runtime.GOMAXPROCS(0) + } else { + o.concurrent = n + } return nil } } @@ -54,7 +72,7 @@ func WithDecoderConcurrency(n int) DOption { // WithDecoderMaxMemory allows to set a maximum decoded size for in-memory // non-streaming operations or maximum window size for streaming operations. // This can be used to control memory usage of potentially hostile content. -// Maximum and default is 1 << 63 bytes. +// Maximum is 1 << 63 bytes. Default is 64GiB. func WithDecoderMaxMemory(n uint64) DOption { return func(o *decoderOptions) error { if n == 0 { @@ -69,7 +87,13 @@ func WithDecoderMaxMemory(n uint64) DOption { } // WithDecoderDicts allows to register one or more dictionaries for the decoder. -// If several dictionaries with the same ID is provided the last one will be used. +// +// Each slice in dict must be in the [dictionary format] produced by +// "zstd --train" from the Zstandard reference implementation. +// +// If several dictionaries with the same ID are provided, the last one will be used. +// +// [dictionary format]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format func WithDecoderDicts(dicts ...[]byte) DOption { return func(o *decoderOptions) error { for _, b := range dicts { @@ -77,8 +101,20 @@ func WithDecoderDicts(dicts ...[]byte) DOption { if err != nil { return err } - o.dicts = append(o.dicts, *d) + o.dicts = append(o.dicts, d) + } + return nil + } +} + +// WithDecoderDictRaw registers a dictionary that may be used by the decoder. +// The slice content can be arbitrary data. +func WithDecoderDictRaw(id uint32, content []byte) DOption { + return func(o *decoderOptions) error { + if bits.UintSize > 32 && uint(len(content)) > dictMaxLength { + return fmt.Errorf("dictionary of size %d > 2GiB too large", len(content)) } + o.dicts = append(o.dicts, &dict{id: id, content: content, offsets: [3]int{1, 4, 8}}) return nil } } @@ -100,3 +136,34 @@ func WithDecoderMaxWindow(size uint64) DOption { return nil } } + +// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, +// or any size set in WithDecoderMaxMemory. +// This can be used to limit decoding to a specific maximum output size. +// Disabled by default. +func WithDecodeAllCapLimit(b bool) DOption { + return func(o *decoderOptions) error { + o.limitToCap = b + return nil + } +} + +// WithDecodeBuffersBelow will fully decode readers that have a +// `Bytes() []byte` and `Len() int` interface similar to bytes.Buffer. +// This typically uses less allocations but will have the full decompressed object in memory. +// Note that DecodeAllCapLimit will disable this, as well as giving a size of 0 or less. +// Default is 128KiB. +func WithDecodeBuffersBelow(size int) DOption { + return func(o *decoderOptions) error { + o.decodeBufsBelow = size + return nil + } +} + +// IgnoreChecksum allows to forcibly ignore checksum checking. +func IgnoreChecksum(b bool) DOption { + return func(o *decoderOptions) error { + o.ignoreChecksum = b + return nil + } +} diff --git a/vendor/github.com/klauspost/compress/zstd/dict.go b/vendor/github.com/klauspost/compress/zstd/dict.go index a36ae83..ca09514 100644 --- a/vendor/github.com/klauspost/compress/zstd/dict.go +++ b/vendor/github.com/klauspost/compress/zstd/dict.go @@ -1,7 +1,6 @@ package zstd import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -20,7 +19,10 @@ type dict struct { content []byte } -var dictMagic = [4]byte{0x37, 0xa4, 0x30, 0xec} +const dictMagic = "\x37\xa4\x30\xec" + +// Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB. +const dictMaxLength = 1 << 31 // ID returns the dictionary id or 0 if d is nil. func (d *dict) ID() uint32 { @@ -30,14 +32,38 @@ func (d *dict) ID() uint32 { return d.id } -// DictContentSize returns the dictionary content size or 0 if d is nil. -func (d *dict) DictContentSize() int { +// ContentSize returns the dictionary content size or 0 if d is nil. +func (d *dict) ContentSize() int { if d == nil { return 0 } return len(d.content) } +// Content returns the dictionary content. +func (d *dict) Content() []byte { + if d == nil { + return nil + } + return d.content +} + +// Offsets returns the initial offsets. +func (d *dict) Offsets() [3]int { + if d == nil { + return [3]int{} + } + return d.offsets +} + +// LitEncoder returns the literal encoder. +func (d *dict) LitEncoder() *huff0.Scratch { + if d == nil { + return nil + } + return d.litEnc +} + // Load a dictionary as described in // https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format func loadDict(b []byte) (*dict, error) { @@ -50,7 +76,7 @@ func loadDict(b []byte) (*dict, error) { ofDec: sequenceDec{fse: &fseDecoder{}}, mlDec: sequenceDec{fse: &fseDecoder{}}, } - if !bytes.Equal(b[:4], dictMagic[:]) { + if string(b[:4]) != dictMagic { return nil, ErrMagicMismatch } d.id = binary.LittleEndian.Uint32(b[4:8]) @@ -62,7 +88,7 @@ func loadDict(b []byte) (*dict, error) { var err error d.litEnc, b, err = huff0.ReadTable(b[8:], nil) if err != nil { - return nil, err + return nil, fmt.Errorf("loading literal table: %w", err) } d.litEnc.Reuse = huff0.ReusePolicyMust @@ -120,3 +146,16 @@ func loadDict(b []byte) (*dict, error) { return &d, nil } + +// InspectDictionary loads a zstd dictionary and provides functions to inspect the content. +func InspectDictionary(b []byte) (interface { + ID() uint32 + ContentSize() int + Content() []byte + Offsets() [3]int + LitEncoder() *huff0.Scratch +}, error) { + initPredefined() + d, err := loadDict(b) + return d, err +} diff --git a/vendor/github.com/klauspost/compress/zstd/enc_base.go b/vendor/github.com/klauspost/compress/zstd/enc_base.go index 295cd60..5ca4603 100644 --- a/vendor/github.com/klauspost/compress/zstd/enc_base.go +++ b/vendor/github.com/klauspost/compress/zstd/enc_base.go @@ -16,6 +16,7 @@ type fastBase struct { cur int32 // maximum offset. Should be at least 2x block size. maxMatchOff int32 + bufferReset int32 hist []byte crc *xxhash.Digest tmp [8]byte @@ -56,8 +57,8 @@ func (e *fastBase) Block() *blockEnc { } func (e *fastBase) addBlock(src []byte) int32 { - if debugAsserts && e.cur > bufferReset { - panic(fmt.Sprintf("ecur (%d) > buffer reset (%d)", e.cur, bufferReset)) + if debugAsserts && e.cur > e.bufferReset { + panic(fmt.Sprintf("ecur (%d) > buffer reset (%d)", e.cur, e.bufferReset)) } // check if we have space already if len(e.hist)+len(src) > cap(e.hist) { @@ -108,11 +109,6 @@ func (e *fastBase) UseBlock(enc *blockEnc) { e.blk = enc } -func (e *fastBase) matchlenNoHist(s, t int32, src []byte) int32 { - // Extend the match to be as long as possible. - return int32(matchLen(src[s:], src[t:])) -} - func (e *fastBase) matchlen(s, t int32, src []byte) int32 { if debugAsserts { if s < 0 { @@ -131,8 +127,6 @@ func (e *fastBase) matchlen(s, t int32, src []byte) int32 { panic(fmt.Sprintf("len(src)-s (%d) > maxCompressedBlockSize (%d)", len(src)-int(s), maxCompressedBlockSize)) } } - - // Extend the match to be as long as possible. return int32(matchLen(src[s:], src[t:])) } @@ -150,18 +144,19 @@ func (e *fastBase) resetBase(d *dict, singleBlock bool) { } else { e.crc.Reset() } + e.blk.dictLitEnc = nil if d != nil { low := e.lowMem if singleBlock { e.lowMem = true } - e.ensureHist(d.DictContentSize() + maxCompressedBlockSize) + e.ensureHist(d.ContentSize() + maxCompressedBlockSize) e.lowMem = low } // We offset current position so everything will be out of reach. // If above reset line, history will be purged. - if e.cur < bufferReset { + if e.cur < e.bufferReset { e.cur += e.maxMatchOff + int32(len(e.hist)) } e.hist = e.hist[:0] diff --git a/vendor/github.com/klauspost/compress/zstd/enc_best.go b/vendor/github.com/klauspost/compress/zstd/enc_best.go index 96028ec..9819d41 100644 --- a/vendor/github.com/klauspost/compress/zstd/enc_best.go +++ b/vendor/github.com/klauspost/compress/zstd/enc_best.go @@ -34,7 +34,7 @@ type match struct { est int32 } -const highScore = 25000 +const highScore = maxMatchLen * 8 // estBits will estimate output bits from predefined tables. func (m *match) estBits(bitsPerByte int32) { @@ -84,14 +84,10 @@ func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) { ) // Protect against e.cur wraparound. - for e.cur >= bufferReset { + for e.cur >= e.bufferReset-int32(len(e.hist)) { if len(e.hist) == 0 { - for i := range e.table[:] { - e.table[i] = prevEntry{} - } - for i := range e.longTable[:] { - e.longTable[i] = prevEntry{} - } + e.table = [bestShortTableSize]prevEntry{} + e.longTable = [bestLongTableSize]prevEntry{} e.cur = e.maxMatchOff break } @@ -163,7 +159,6 @@ func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) { // nextEmit is where in src the next emitLiteral should start from. nextEmit := s - cv := load6432(src, s) // Relative offsets offset1 := int32(blk.recentOffsets[0]) @@ -177,7 +172,6 @@ func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) { blk.literals = append(blk.literals, src[nextEmit:until]...) s.litLen = uint32(until - nextEmit) } - _ = addLiterals if debugEncoder { println("recent offsets:", blk.recentOffsets) @@ -192,49 +186,96 @@ encodeLoop: panic("offset0 was 0") } - bestOf := func(a, b match) match { - if a.est+(a.s-b.s)*bitsPerByte>>10 < b.est+(b.s-a.s)*bitsPerByte>>10 { - return a - } - return b - } - const goodEnough = 100 + const goodEnough = 250 + + cv := load6432(src, s) nextHashL := hashLen(cv, bestLongTableBits, bestLongLen) nextHashS := hashLen(cv, bestShortTableBits, bestShortLen) candidateL := e.longTable[nextHashL] candidateS := e.table[nextHashS] - matchAt := func(offset int32, s int32, first uint32, rep int32) match { + // Set m to a match at offset if it looks like that will improve compression. + improve := func(m *match, offset int32, s int32, first uint32, rep int32) { if s-offset >= e.maxMatchOff || load3232(src, offset) != first { - return match{s: s, est: highScore} + return } if debugAsserts { + if offset <= 0 { + panic(offset) + } if !bytes.Equal(src[s:s+4], src[offset:offset+4]) { panic(fmt.Sprintf("first match mismatch: %v != %v, first: %08x", src[s:s+4], src[offset:offset+4], first)) } } - m := match{offset: offset, s: s, length: 4 + e.matchlen(s+4, offset+4, src), rep: rep} - m.estBits(bitsPerByte) - return m + // Try to quick reject if we already have a long match. + if m.length > 16 { + left := len(src) - int(m.s+m.length) + // If we are too close to the end, keep as is. + if left <= 0 { + return + } + checkLen := m.length - (s - m.s) - 8 + if left > 2 && checkLen > 4 { + // Check 4 bytes, 4 bytes from the end of the current match. + a := load3232(src, offset+checkLen) + b := load3232(src, s+checkLen) + if a != b { + return + } + } + } + l := 4 + e.matchlen(s+4, offset+4, src) + if rep < 0 { + // Extend candidate match backwards as far as possible. + tMin := s - e.maxMatchOff + if tMin < 0 { + tMin = 0 + } + for offset > tMin && s > nextEmit && src[offset-1] == src[s-1] && l < maxMatchLength { + s-- + offset-- + l++ + } + } + + cand := match{offset: offset, s: s, length: l, rep: rep} + cand.estBits(bitsPerByte) + if m.est >= highScore || cand.est-m.est+(cand.s-m.s)*bitsPerByte>>10 < 0 { + *m = cand + } } - best := bestOf(matchAt(candidateL.offset-e.cur, s, uint32(cv), -1), matchAt(candidateL.prev-e.cur, s, uint32(cv), -1)) - best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1)) - best = bestOf(best, matchAt(candidateS.prev-e.cur, s, uint32(cv), -1)) + best := match{s: s, est: highScore} + improve(&best, candidateL.offset-e.cur, s, uint32(cv), -1) + improve(&best, candidateL.prev-e.cur, s, uint32(cv), -1) + improve(&best, candidateS.offset-e.cur, s, uint32(cv), -1) + improve(&best, candidateS.prev-e.cur, s, uint32(cv), -1) if canRepeat && best.length < goodEnough { - cv32 := uint32(cv >> 8) - spp := s + 1 - best = bestOf(best, matchAt(spp-offset1, spp, cv32, 1)) - best = bestOf(best, matchAt(spp-offset2, spp, cv32, 2)) - best = bestOf(best, matchAt(spp-offset3, spp, cv32, 3)) - if best.length > 0 { - cv32 = uint32(cv >> 24) - spp += 2 - best = bestOf(best, matchAt(spp-offset1, spp, cv32, 1)) - best = bestOf(best, matchAt(spp-offset2, spp, cv32, 2)) - best = bestOf(best, matchAt(spp-offset3, spp, cv32, 3)) + if s == nextEmit { + // Check repeats straight after a match. + improve(&best, s-offset2, s, uint32(cv), 1|4) + improve(&best, s-offset3, s, uint32(cv), 2|4) + if offset1 > 1 { + improve(&best, s-(offset1-1), s, uint32(cv), 3|4) + } + } + + // If either no match or a non-repeat match, check at + 1 + if best.rep <= 0 { + cv32 := uint32(cv >> 8) + spp := s + 1 + improve(&best, spp-offset1, spp, cv32, 1) + improve(&best, spp-offset2, spp, cv32, 2) + improve(&best, spp-offset3, spp, cv32, 3) + if best.rep < 0 { + cv32 = uint32(cv >> 24) + spp += 2 + improve(&best, spp-offset1, spp, cv32, 1) + improve(&best, spp-offset2, spp, cv32, 2) + improve(&best, spp-offset3, spp, cv32, 3) + } } } // Load next and check... @@ -249,40 +290,45 @@ encodeLoop: if s >= sLimit { break encodeLoop } - cv = load6432(src, s) continue } - s++ candidateS = e.table[hashLen(cv>>8, bestShortTableBits, bestShortLen)] - cv = load6432(src, s) - cv2 := load6432(src, s+1) + cv = load6432(src, s+1) + cv2 := load6432(src, s+2) candidateL = e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)] candidateL2 := e.longTable[hashLen(cv2, bestLongTableBits, bestLongLen)] // Short at s+1 - best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1)) + improve(&best, candidateS.offset-e.cur, s+1, uint32(cv), -1) // Long at s+1, s+2 - best = bestOf(best, matchAt(candidateL.offset-e.cur, s, uint32(cv), -1)) - best = bestOf(best, matchAt(candidateL.prev-e.cur, s, uint32(cv), -1)) - best = bestOf(best, matchAt(candidateL2.offset-e.cur, s+1, uint32(cv2), -1)) - best = bestOf(best, matchAt(candidateL2.prev-e.cur, s+1, uint32(cv2), -1)) + improve(&best, candidateL.offset-e.cur, s+1, uint32(cv), -1) + improve(&best, candidateL.prev-e.cur, s+1, uint32(cv), -1) + improve(&best, candidateL2.offset-e.cur, s+2, uint32(cv2), -1) + improve(&best, candidateL2.prev-e.cur, s+2, uint32(cv2), -1) if false { // Short at s+3. // Too often worse... - best = bestOf(best, matchAt(e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+2, uint32(cv2>>8), -1)) + improve(&best, e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+3, uint32(cv2>>8), -1) } - // See if we can find a better match by checking where the current best ends. - // Use that offset to see if we can find a better full match. - if sAt := best.s + best.length; sAt < sLimit { - nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen) - candidateEnd := e.longTable[nextHashL] - if pos := candidateEnd.offset - e.cur - best.length; pos >= 0 { - bestEnd := bestOf(best, matchAt(pos, best.s, load3232(src, best.s), -1)) - if pos := candidateEnd.prev - e.cur - best.length; pos >= 0 { - bestEnd = bestOf(bestEnd, matchAt(pos, best.s, load3232(src, best.s), -1)) + + // Start check at a fixed offset to allow for a few mismatches. + // For this compression level 2 yields the best results. + // We cannot do this if we have already indexed this position. + const skipBeginning = 2 + if best.s > s-skipBeginning { + // See if we can find a better match by checking where the current best ends. + // Use that offset to see if we can find a better full match. + if sAt := best.s + best.length; sAt < sLimit { + nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen) + candidateEnd := e.longTable[nextHashL] + + if off := candidateEnd.offset - e.cur - best.length + skipBeginning; off >= 0 { + improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1) + if off := candidateEnd.prev - e.cur - best.length + skipBeginning; off >= 0 { + improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1) + } } - best = bestEnd } } } @@ -295,51 +341,34 @@ encodeLoop: // We have a match, we can store the forward value if best.rep > 0 { - s = best.s var seq seq seq.matchLen = uint32(best.length - zstdMinMatch) - - // We might be able to match backwards. - // Extend as long as we can. - start := best.s - // We end the search early, so we don't risk 0 literals - // and have to do special offset treatment. - startLimit := nextEmit + 1 - - tMin := s - e.maxMatchOff - if tMin < 0 { - tMin = 0 - } - repIndex := best.offset - for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 { - repIndex-- - start-- - seq.matchLen++ + if debugAsserts && s <= nextEmit { + panic("s <= nextEmit") } - addLiterals(&seq, start) + addLiterals(&seq, best.s) - // rep 0 - seq.offset = uint32(best.rep) + // Repeat. If bit 4 is set, this is a non-lit repeat. + seq.offset = uint32(best.rep & 3) if debugSequences { println("repeat sequence", seq, "next s:", s) } blk.sequences = append(blk.sequences, seq) - // Index match start+1 (long) -> s - 1 - index0 := s + // Index old s + 1 -> s - 1 + index0 := s + 1 s = best.s + best.length nextEmit = s if s >= sLimit { if debugEncoder { println("repeat ended", s, best.length) - } break encodeLoop } // Index skipped... off := index0 + e.cur - for index0 < s-1 { + for index0 < s { cv0 := load6432(src, index0) h0 := hashLen(cv0, bestLongTableBits, bestLongLen) h1 := hashLen(cv0, bestShortTableBits, bestShortLen) @@ -349,17 +378,19 @@ encodeLoop: index0++ } switch best.rep { - case 2: + case 2, 4 | 1: offset1, offset2 = offset2, offset1 - case 3: + case 3, 4 | 2: offset1, offset2, offset3 = offset3, offset1, offset2 + case 4 | 3: + offset1, offset2, offset3 = offset1-1, offset1, offset2 } - cv = load6432(src, s) continue } // A 4-byte match has been found. Update recent offsets. // We'll later see if more than 4 bytes. + index0 := s + 1 s = best.s t := best.offset offset1, offset2, offset3 = s-t, offset1, offset2 @@ -372,22 +403,9 @@ encodeLoop: panic("invalid offset") } - // Extend the n-byte match as long as possible. - l := best.length - - // Extend backwards - tMin := s - e.maxMatchOff - if tMin < 0 { - tMin = 0 - } - for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength { - s-- - t-- - l++ - } - // Write our sequence var seq seq + l := best.length seq.litLen = uint32(s - nextEmit) seq.matchLen = uint32(l - zstdMinMatch) if seq.litLen > 0 { @@ -404,10 +422,8 @@ encodeLoop: break encodeLoop } - // Index match start+1 (long) -> s - 1 - index0 := s - l + 1 - // every entry - for index0 < s-1 { + // Index old s + 1 -> s - 1 + for index0 < s { cv0 := load6432(src, index0) h0 := hashLen(cv0, bestLongTableBits, bestLongLen) h1 := hashLen(cv0, bestShortTableBits, bestShortLen) @@ -416,50 +432,6 @@ encodeLoop: e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset} index0++ } - - cv = load6432(src, s) - if !canRepeat { - continue - } - - // Check offset 2 - for { - o2 := s - offset2 - if load3232(src, o2) != uint32(cv) { - // Do regular search - break - } - - // Store this, since we have it. - nextHashS := hashLen(cv, bestShortTableBits, bestShortLen) - nextHashL := hashLen(cv, bestLongTableBits, bestLongLen) - - // We have at least 4 byte match. - // No need to check backwards. We come straight from a match - l := 4 + e.matchlen(s+4, o2+4, src) - - e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: e.longTable[nextHashL].offset} - e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: e.table[nextHashS].offset} - seq.matchLen = uint32(l) - zstdMinMatch - seq.litLen = 0 - - // Since litlen is always 0, this is offset 1. - seq.offset = 1 - s += l - nextEmit = s - if debugSequences { - println("sequence", seq, "next s:", s) - } - blk.sequences = append(blk.sequences, seq) - - // Swap offset 1 and 2. - offset1, offset2 = offset2, offset1 - if s >= sLimit { - // Finished - break encodeLoop - } - cv = load6432(src, s) - } } if int(nextEmit) < len(src) { diff --git a/vendor/github.com/klauspost/compress/zstd/enc_better.go b/vendor/github.com/klauspost/compress/zstd/enc_better.go index 602c05e..8582f31 100644 --- a/vendor/github.com/klauspost/compress/zstd/enc_better.go +++ b/vendor/github.com/klauspost/compress/zstd/enc_better.go @@ -62,14 +62,10 @@ func (e *betterFastEncoder) Encode(blk *blockEnc, src []byte) { ) // Protect against e.cur wraparound. - for e.cur >= bufferReset { + for e.cur >= e.bufferReset-int32(len(e.hist)) { if len(e.hist) == 0 { - for i := range e.table[:] { - e.table[i] = tableEntry{} - } - for i := range e.longTable[:] { - e.longTable[i] = prevEntry{} - } + e.table = [betterShortTableSize]tableEntry{} + e.longTable = [betterLongTableSize]prevEntry{} e.cur = e.maxMatchOff break } @@ -156,8 +152,8 @@ encodeLoop: panic("offset0 was 0") } - nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) nextHashL := hashLen(cv, betterLongTableBits, betterLongLen) + nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) candidateL := e.longTable[nextHashL] candidateS := e.table[nextHashS] @@ -416,15 +412,23 @@ encodeLoop: // Try to find a better match by searching for a long match at the end of the current best match if s+matched < sLimit { + // Allow some bytes at the beginning to mismatch. + // Sweet spot is around 3 bytes, but depends on input. + // The skipped bytes are tested in Extend backwards, + // and still picked up as part of the match if they do. + const skipBeginning = 3 + nextHashL := hashLen(load6432(src, s+matched), betterLongTableBits, betterLongLen) - cv := load3232(src, s) + s2 := s + skipBeginning + cv := load3232(src, s2) candidateL := e.longTable[nextHashL] - coffsetL := candidateL.offset - e.cur - matched - if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { + coffsetL := candidateL.offset - e.cur - matched + skipBeginning + if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { // Found a long match, at least 4 bytes. - matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4 + matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4 if matchedNext > matched { t = coffsetL + s = s2 matched = matchedNext if debugMatches { println("long match at end-of-match") @@ -434,12 +438,13 @@ encodeLoop: // Check prev long... if true { - coffsetL = candidateL.prev - e.cur - matched - if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { + coffsetL = candidateL.prev - e.cur - matched + skipBeginning + if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { // Found a long match, at least 4 bytes. - matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4 + matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4 if matchedNext > matched { t = coffsetL + s = s2 matched = matchedNext if debugMatches { println("prev long match at end-of-match") @@ -518,8 +523,8 @@ encodeLoop: } // Store this, since we have it. - nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) nextHashL := hashLen(cv, betterLongTableBits, betterLongLen) + nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) // We have at least 4 byte match. // No need to check backwards. We come straight from a match @@ -578,7 +583,7 @@ func (e *betterFastEncoderDict) Encode(blk *blockEnc, src []byte) { ) // Protect against e.cur wraparound. - for e.cur >= bufferReset { + for e.cur >= e.bufferReset-int32(len(e.hist)) { if len(e.hist) == 0 { for i := range e.table[:] { e.table[i] = tableEntry{} @@ -674,8 +679,8 @@ encodeLoop: panic("offset0 was 0") } - nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) nextHashL := hashLen(cv, betterLongTableBits, betterLongLen) + nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) candidateL := e.longTable[nextHashL] candidateS := e.table[nextHashS] @@ -1047,8 +1052,8 @@ encodeLoop: } // Store this, since we have it. - nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) nextHashL := hashLen(cv, betterLongTableBits, betterLongLen) + nextHashS := hashLen(cv, betterShortTableBits, betterShortLen) // We have at least 4 byte match. // No need to check backwards. We come straight from a match diff --git a/vendor/github.com/klauspost/compress/zstd/enc_dfast.go b/vendor/github.com/klauspost/compress/zstd/enc_dfast.go index d6b3104..a154c18 100644 --- a/vendor/github.com/klauspost/compress/zstd/enc_dfast.go +++ b/vendor/github.com/klauspost/compress/zstd/enc_dfast.go @@ -44,14 +44,10 @@ func (e *doubleFastEncoder) Encode(blk *blockEnc, src []byte) { ) // Protect against e.cur wraparound. - for e.cur >= bufferReset { + for e.cur >= e.bufferReset-int32(len(e.hist)) { if len(e.hist) == 0 { - for i := range e.table[:] { - e.table[i] = tableEntry{} - } - for i := range e.longTable[:] { - e.longTable[i] = tableEntry{} - } + e.table = [dFastShortTableSize]tableEntry{} + e.longTable = [dFastLongTableSize]tableEntry{} e.cur = e.maxMatchOff break } @@ -127,8 +123,8 @@ encodeLoop: panic("offset0 was 0") } - nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen) + nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) candidateL := e.longTable[nextHashL] candidateS := e.table[nextHashS] @@ -388,7 +384,7 @@ func (e *doubleFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) { ) // Protect against e.cur wraparound. - if e.cur >= bufferReset { + if e.cur >= e.bufferReset { for i := range e.table[:] { e.table[i] = tableEntry{} } @@ -439,8 +435,8 @@ encodeLoop: var t int32 for { - nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen) + nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) candidateL := e.longTable[nextHashL] candidateS := e.table[nextHashS] @@ -685,7 +681,7 @@ encodeLoop: } // We do not store history, so we must offset e.cur to avoid false matches for next user. - if e.cur < bufferReset { + if e.cur < e.bufferReset { e.cur += int32(len(src)) } } @@ -700,7 +696,7 @@ func (e *doubleFastEncoderDict) Encode(blk *blockEnc, src []byte) { ) // Protect against e.cur wraparound. - for e.cur >= bufferReset { + for e.cur >= e.bufferReset-int32(len(e.hist)) { if len(e.hist) == 0 { for i := range e.table[:] { e.table[i] = tableEntry{} @@ -785,8 +781,8 @@ encodeLoop: panic("offset0 was 0") } - nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen) + nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) candidateL := e.longTable[nextHashL] candidateS := e.table[nextHashS] @@ -969,7 +965,7 @@ encodeLoop: te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)} te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)} longHash1 := hashLen(cv0, dFastLongTableBits, dFastLongLen) - longHash2 := hashLen(cv0, dFastLongTableBits, dFastLongLen) + longHash2 := hashLen(cv1, dFastLongTableBits, dFastLongLen) e.longTable[longHash1] = te0 e.longTable[longHash2] = te1 e.markLongShardDirty(longHash1) @@ -1002,8 +998,8 @@ encodeLoop: } // Store this, since we have it. - nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen) + nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen) // We have at least 4 byte match. // No need to check backwards. We come straight from a match @@ -1088,7 +1084,7 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) { } } e.lastDictID = d.id - e.allDirty = true + allDirty = true } // Reset table to initial state e.cur = e.maxMatchOff @@ -1103,7 +1099,8 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) { } if allDirty || dirtyShardCnt > dLongTableShardCnt/2 { - copy(e.longTable[:], e.dictLongTable) + //copy(e.longTable[:], e.dictLongTable) + e.longTable = *(*[dFastLongTableSize]tableEntry)(e.dictLongTable) for i := range e.longTableShardDirty { e.longTableShardDirty[i] = false } @@ -1114,7 +1111,9 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) { continue } - copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize]) + // copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize]) + *(*[dLongTableShardSize]tableEntry)(e.longTable[i*dLongTableShardSize:]) = *(*[dLongTableShardSize]tableEntry)(e.dictLongTable[i*dLongTableShardSize:]) + e.longTableShardDirty[i] = false } } diff --git a/vendor/github.com/klauspost/compress/zstd/enc_fast.go b/vendor/github.com/klauspost/compress/zstd/enc_fast.go index f250262..f45a3da 100644 --- a/vendor/github.com/klauspost/compress/zstd/enc_fast.go +++ b/vendor/github.com/klauspost/compress/zstd/enc_fast.go @@ -6,8 +6,6 @@ package zstd import ( "fmt" - "math" - "math/bits" ) const ( @@ -45,7 +43,7 @@ func (e *fastEncoder) Encode(blk *blockEnc, src []byte) { ) // Protect against e.cur wraparound. - for e.cur >= bufferReset { + for e.cur >= e.bufferReset-int32(len(e.hist)) { if len(e.hist) == 0 { for i := range e.table[:] { e.table[i] = tableEntry{} @@ -87,7 +85,7 @@ func (e *fastEncoder) Encode(blk *blockEnc, src []byte) { // TEMPLATE const hashLog = tableBits // seems global, but would be nice to tweak. - const kSearchStrength = 7 + const kSearchStrength = 6 // nextEmit is where in src the next emitLiteral should start from. nextEmit := s @@ -135,21 +133,7 @@ encodeLoop: if canRepeat && repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>16) { // Consider history as well. var seq seq - var length int32 - // length = 4 + e.matchlen(s+6, repIndex+4, src) - { - a := src[s+6:] - b := src[repIndex+4:] - endI := len(a) & (math.MaxInt32 - 7) - length = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - length = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } - + length := 4 + e.matchlen(s+6, repIndex+4, src) seq.matchLen = uint32(length - zstdMinMatch) // We might be able to match backwards. @@ -236,20 +220,7 @@ encodeLoop: } // Extend the 4-byte match as long as possible. - //l := e.matchlen(s+4, t+4, src) + 4 - var l int32 - { - a := src[s+4:] - b := src[t+4:] - endI := len(a) & (math.MaxInt32 - 7) - l = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - l = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + l := e.matchlen(s+4, t+4, src) + 4 // Extend backwards tMin := s - e.maxMatchOff @@ -286,20 +257,7 @@ encodeLoop: if o2 := s - offset2; canRepeat && load3232(src, o2) == uint32(cv) { // We have at least 4 byte match. // No need to check backwards. We come straight from a match - //l := 4 + e.matchlen(s+4, o2+4, src) - var l int32 - { - a := src[s+4:] - b := src[o2+4:] - endI := len(a) & (math.MaxInt32 - 7) - l = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - l = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + l := 4 + e.matchlen(s+4, o2+4, src) // Store this, since we have it. nextHash := hashLen(cv, hashLog, tableFastHashLen) @@ -345,13 +303,13 @@ func (e *fastEncoder) EncodeNoHist(blk *blockEnc, src []byte) { minNonLiteralBlockSize = 1 + 1 + inputMargin ) if debugEncoder { - if len(src) > maxBlockSize { + if len(src) > maxCompressedBlockSize { panic("src too big") } } // Protect against e.cur wraparound. - if e.cur >= bufferReset { + if e.cur >= e.bufferReset { for i := range e.table[:] { e.table[i] = tableEntry{} } @@ -375,7 +333,7 @@ func (e *fastEncoder) EncodeNoHist(blk *blockEnc, src []byte) { // TEMPLATE const hashLog = tableBits // seems global, but would be nice to tweak. - const kSearchStrength = 8 + const kSearchStrength = 6 // nextEmit is where in src the next emitLiteral should start from. nextEmit := s @@ -418,21 +376,7 @@ encodeLoop: if len(blk.sequences) > 2 && load3232(src, repIndex) == uint32(cv>>16) { // Consider history as well. var seq seq - // length := 4 + e.matchlen(s+6, repIndex+4, src) - // length := 4 + int32(matchLen(src[s+6:], src[repIndex+4:])) - var length int32 - { - a := src[s+6:] - b := src[repIndex+4:] - endI := len(a) & (math.MaxInt32 - 7) - length = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - length = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + length := 4 + e.matchlen(s+6, repIndex+4, src) seq.matchLen = uint32(length - zstdMinMatch) @@ -522,21 +466,7 @@ encodeLoop: panic(fmt.Sprintf("t (%d) < 0 ", t)) } // Extend the 4-byte match as long as possible. - //l := e.matchlenNoHist(s+4, t+4, src) + 4 - // l := int32(matchLen(src[s+4:], src[t+4:])) + 4 - var l int32 - { - a := src[s+4:] - b := src[t+4:] - endI := len(a) & (math.MaxInt32 - 7) - l = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - l = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + l := e.matchlen(s+4, t+4, src) + 4 // Extend backwards tMin := s - e.maxMatchOff @@ -573,21 +503,7 @@ encodeLoop: if o2 := s - offset2; len(blk.sequences) > 2 && load3232(src, o2) == uint32(cv) { // We have at least 4 byte match. // No need to check backwards. We come straight from a match - //l := 4 + e.matchlenNoHist(s+4, o2+4, src) - // l := 4 + int32(matchLen(src[s+4:], src[o2+4:])) - var l int32 - { - a := src[s+4:] - b := src[o2+4:] - endI := len(a) & (math.MaxInt32 - 7) - l = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - l = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + l := 4 + e.matchlen(s+4, o2+4, src) // Store this, since we have it. nextHash := hashLen(cv, hashLog, tableFastHashLen) @@ -621,7 +537,7 @@ encodeLoop: println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits) } // We do not store history, so we must offset e.cur to avoid false matches for next user. - if e.cur < bufferReset { + if e.cur < e.bufferReset { e.cur += int32(len(src)) } } @@ -638,11 +554,9 @@ func (e *fastEncoderDict) Encode(blk *blockEnc, src []byte) { return } // Protect against e.cur wraparound. - for e.cur >= bufferReset { + for e.cur >= e.bufferReset-int32(len(e.hist)) { if len(e.hist) == 0 { - for i := range e.table[:] { - e.table[i] = tableEntry{} - } + e.table = [tableSize]tableEntry{} e.cur = e.maxMatchOff break } @@ -730,20 +644,7 @@ encodeLoop: if canRepeat && repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>16) { // Consider history as well. var seq seq - var length int32 - // length = 4 + e.matchlen(s+6, repIndex+4, src) - { - a := src[s+6:] - b := src[repIndex+4:] - endI := len(a) & (math.MaxInt32 - 7) - length = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - length = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + length := 4 + e.matchlen(s+6, repIndex+4, src) seq.matchLen = uint32(length - zstdMinMatch) @@ -831,20 +732,7 @@ encodeLoop: } // Extend the 4-byte match as long as possible. - //l := e.matchlen(s+4, t+4, src) + 4 - var l int32 - { - a := src[s+4:] - b := src[t+4:] - endI := len(a) & (math.MaxInt32 - 7) - l = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - l = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + l := e.matchlen(s+4, t+4, src) + 4 // Extend backwards tMin := s - e.maxMatchOff @@ -881,20 +769,7 @@ encodeLoop: if o2 := s - offset2; canRepeat && load3232(src, o2) == uint32(cv) { // We have at least 4 byte match. // No need to check backwards. We come straight from a match - //l := 4 + e.matchlen(s+4, o2+4, src) - var l int32 - { - a := src[s+4:] - b := src[o2+4:] - endI := len(a) & (math.MaxInt32 - 7) - l = int32(endI) + 4 - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - l = int32(i+bits.TrailingZeros64(diff)>>3) + 4 - break - } - } - } + l := 4 + e.matchlen(s+4, o2+4, src) // Store this, since we have it. nextHash := hashLen(cv, hashLog, tableFastHashLen) @@ -954,13 +829,12 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) { } if true { end := e.maxMatchOff + int32(len(d.content)) - 8 - for i := e.maxMatchOff; i < end; i += 3 { + for i := e.maxMatchOff; i < end; i += 2 { const hashLog = tableBits cv := load6432(d.content, i-e.maxMatchOff) - nextHash := hashLen(cv, hashLog, tableFastHashLen) // 0 -> 5 - nextHash1 := hashLen(cv>>8, hashLog, tableFastHashLen) // 1 -> 6 - nextHash2 := hashLen(cv>>16, hashLog, tableFastHashLen) // 2 -> 7 + nextHash := hashLen(cv, hashLog, tableFastHashLen) // 0 -> 6 + nextHash1 := hashLen(cv>>8, hashLog, tableFastHashLen) // 1 -> 7 e.dictTable[nextHash] = tableEntry{ val: uint32(cv), offset: i, @@ -969,10 +843,6 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) { val: uint32(cv >> 8), offset: i + 1, } - e.dictTable[nextHash2] = tableEntry{ - val: uint32(cv >> 16), - offset: i + 2, - } } } e.lastDictID = d.id @@ -992,7 +862,8 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) { const shardCnt = tableShardCnt const shardSize = tableShardSize if e.allDirty || dirtyShardCnt > shardCnt*4/6 { - copy(e.table[:], e.dictTable) + //copy(e.table[:], e.dictTable) + e.table = *(*[tableSize]tableEntry)(e.dictTable) for i := range e.tableShardDirty { e.tableShardDirty[i] = false } @@ -1004,7 +875,8 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) { continue } - copy(e.table[i*shardSize:(i+1)*shardSize], e.dictTable[i*shardSize:(i+1)*shardSize]) + //copy(e.table[i*shardSize:(i+1)*shardSize], e.dictTable[i*shardSize:(i+1)*shardSize]) + *(*[shardSize]tableEntry)(e.table[i*shardSize:]) = *(*[shardSize]tableEntry)(e.dictTable[i*shardSize:]) e.tableShardDirty[i] = false } e.allDirty = false diff --git a/vendor/github.com/klauspost/compress/zstd/encoder.go b/vendor/github.com/klauspost/compress/zstd/encoder.go index e6e3159..4de0aed 100644 --- a/vendor/github.com/klauspost/compress/zstd/encoder.go +++ b/vendor/github.com/klauspost/compress/zstd/encoder.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "fmt" "io" + "math" rdebug "runtime/debug" "sync" @@ -98,23 +99,25 @@ func (e *Encoder) Reset(w io.Writer) { if cap(s.filling) == 0 { s.filling = make([]byte, 0, e.o.blockSize) } - if cap(s.current) == 0 { - s.current = make([]byte, 0, e.o.blockSize) - } - if cap(s.previous) == 0 { - s.previous = make([]byte, 0, e.o.blockSize) + if e.o.concurrent > 1 { + if cap(s.current) == 0 { + s.current = make([]byte, 0, e.o.blockSize) + } + if cap(s.previous) == 0 { + s.previous = make([]byte, 0, e.o.blockSize) + } + s.current = s.current[:0] + s.previous = s.previous[:0] + if s.writing == nil { + s.writing = &blockEnc{lowMem: e.o.lowMem} + s.writing.init() + } + s.writing.initNewEncode() } if s.encoder == nil { s.encoder = e.o.encoder() } - if s.writing == nil { - s.writing = &blockEnc{lowMem: e.o.lowMem} - s.writing.init() - } - s.writing.initNewEncode() s.filling = s.filling[:0] - s.current = s.current[:0] - s.previous = s.previous[:0] s.encoder.Reset(e.o.dict, false) s.headerWritten = false s.eofWritten = false @@ -258,6 +261,32 @@ func (e *Encoder) nextBlock(final bool) error { return s.err } + // SYNC: + if e.o.concurrent == 1 { + src := s.filling + s.nInput += int64(len(s.filling)) + if debugEncoder { + println("Adding sync block,", len(src), "bytes, final:", final) + } + enc := s.encoder + blk := enc.Block() + blk.reset(nil) + enc.Encode(blk, src) + blk.last = final + if final { + s.eofWritten = true + } + + s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) + if s.err != nil { + return s.err + } + _, s.err = s.w.Write(blk.output) + s.nWritten += int64(len(blk.output)) + s.filling = s.filling[:0] + return s.err + } + // Move blocks forward. s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current s.nInput += int64(len(s.current)) @@ -300,22 +329,8 @@ func (e *Encoder) nextBlock(final bool) error { } s.wWg.Done() }() - err := errIncompressible - // If we got the exact same number of literals as input, - // assume the literals cannot be compressed. - if len(src) != len(blk.literals) || len(src) != e.o.blockSize { - err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) - } - switch err { - case errIncompressible: - if debugEncoder { - println("Storing incompressible block as raw") - } - blk.encodeRaw(src) - // In fast mode, we do not transfer offsets, so we don't have to deal with changing the. - case nil: - default: - s.writeErr = err + s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) + if s.writeErr != nil { return } _, s.writeErr = s.w.Write(blk.output) @@ -486,8 +501,8 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte { // If a non-single block is needed the encoder will reset again. e.encoders <- enc }() - // Use single segments when above minimum window and below 1MB. - single := len(src) < 1<<20 && len(src) > MinWindowSize + // Use single segments when above minimum window and below window size. + single := len(src) <= e.o.windowSize && len(src) > MinWindowSize if e.o.single != nil { single = *e.o.single } @@ -509,7 +524,7 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte { } // If we can do everything in one block, prefer that. - if len(src) <= maxCompressedBlockSize { + if len(src) <= e.o.blockSize { enc.Reset(e.o.dict, true) // Slightly faster with no history and everything in one block. if e.o.crc { @@ -525,25 +540,15 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte { // If we got the exact same number of literals as input, // assume the literals cannot be compressed. - err := errIncompressible oldout := blk.output - if len(blk.literals) != len(src) || len(src) != e.o.blockSize { - // Output directly to dst - blk.output = dst - err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) - } + // Output directly to dst + blk.output = dst - switch err { - case errIncompressible: - if debugEncoder { - println("Storing incompressible block as raw") - } - dst = blk.encodeRawTo(dst, src) - case nil: - dst = blk.output - default: + err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) + if err != nil { panic(err) } + dst = blk.output blk.output = oldout } else { enc.Reset(e.o.dict, false) @@ -562,25 +567,11 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte { if len(src) == 0 { blk.last = true } - err := errIncompressible - // If we got the exact same number of literals as input, - // assume the literals cannot be compressed. - if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize { - err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy) - } - - switch err { - case errIncompressible: - if debugEncoder { - println("Storing incompressible block as raw") - } - dst = blk.encodeRawTo(dst, todo) - blk.popOffsets() - case nil: - dst = append(dst, blk.output...) - default: + err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy) + if err != nil { panic(err) } + dst = append(dst, blk.output...) blk.reset(nil) } } @@ -597,3 +588,37 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte { } return dst } + +// MaxEncodedSize returns the expected maximum +// size of an encoded block or stream. +func (e *Encoder) MaxEncodedSize(size int) int { + frameHeader := 4 + 2 // magic + frame header & window descriptor + if e.o.dict != nil { + frameHeader += 4 + } + // Frame content size: + if size < 256 { + frameHeader++ + } else if size < 65536+256 { + frameHeader += 2 + } else if size < math.MaxInt32 { + frameHeader += 4 + } else { + frameHeader += 8 + } + // Final crc + if e.o.crc { + frameHeader += 4 + } + + // Max overhead is 3 bytes/block. + // There cannot be 0 blocks. + blocks := (size + e.o.blockSize) / e.o.blockSize + + // Combine, add padding. + maxSz := frameHeader + 3*blocks + size + if e.o.pad > 1 { + maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad)) + } + return maxSz +} diff --git a/vendor/github.com/klauspost/compress/zstd/encoder_options.go b/vendor/github.com/klauspost/compress/zstd/encoder_options.go index 7d29e1d..faaf819 100644 --- a/vendor/github.com/klauspost/compress/zstd/encoder_options.go +++ b/vendor/github.com/klauspost/compress/zstd/encoder_options.go @@ -3,6 +3,8 @@ package zstd import ( "errors" "fmt" + "math" + "math/bits" "runtime" "strings" ) @@ -24,6 +26,7 @@ type encoderOptions struct { allLitEntropy bool customWindow bool customALEntropy bool + customBlockSize bool lowMem bool dict *dict } @@ -33,10 +36,10 @@ func (o *encoderOptions) setDefault() { concurrent: runtime.GOMAXPROCS(0), crc: true, single: nil, - blockSize: 1 << 16, + blockSize: maxCompressedBlockSize, windowSize: 8 << 20, level: SpeedDefault, - allLitEntropy: true, + allLitEntropy: false, lowMem: false, } } @@ -46,22 +49,22 @@ func (o encoderOptions) encoder() encoder { switch o.level { case SpeedFastest: if o.dict != nil { - return &fastEncoderDict{fastEncoder: fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), lowMem: o.lowMem}}} + return &fastEncoderDict{fastEncoder: fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), bufferReset: math.MaxInt32 - int32(o.windowSize*2), lowMem: o.lowMem}}} } - return &fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), lowMem: o.lowMem}} + return &fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), bufferReset: math.MaxInt32 - int32(o.windowSize*2), lowMem: o.lowMem}} case SpeedDefault: if o.dict != nil { - return &doubleFastEncoderDict{fastEncoderDict: fastEncoderDict{fastEncoder: fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), lowMem: o.lowMem}}}} + return &doubleFastEncoderDict{fastEncoderDict: fastEncoderDict{fastEncoder: fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), bufferReset: math.MaxInt32 - int32(o.windowSize*2), lowMem: o.lowMem}}}} } - return &doubleFastEncoder{fastEncoder: fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), lowMem: o.lowMem}}} + return &doubleFastEncoder{fastEncoder: fastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), bufferReset: math.MaxInt32 - int32(o.windowSize*2), lowMem: o.lowMem}}} case SpeedBetterCompression: if o.dict != nil { - return &betterFastEncoderDict{betterFastEncoder: betterFastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), lowMem: o.lowMem}}} + return &betterFastEncoderDict{betterFastEncoder: betterFastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), bufferReset: math.MaxInt32 - int32(o.windowSize*2), lowMem: o.lowMem}}} } - return &betterFastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), lowMem: o.lowMem}} + return &betterFastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), bufferReset: math.MaxInt32 - int32(o.windowSize*2), lowMem: o.lowMem}} case SpeedBestCompression: - return &bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), lowMem: o.lowMem}} + return &bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(o.windowSize), bufferReset: math.MaxInt32 - int32(o.windowSize*2), lowMem: o.lowMem}} } panic("unknown compression level") } @@ -75,6 +78,7 @@ func WithEncoderCRC(b bool) EOption { // WithEncoderConcurrency will set the concurrency, // meaning the maximum number of encoders to run concurrently. // The value supplied must be at least 1. +// For streams, setting a value of 1 will disable async compression. // By default this will be set to GOMAXPROCS. func WithEncoderConcurrency(n int) EOption { return func(o *encoderOptions) error { @@ -106,6 +110,7 @@ func WithWindowSize(n int) EOption { o.customWindow = true if o.blockSize > o.windowSize { o.blockSize = o.windowSize + o.customBlockSize = true } return nil } @@ -124,7 +129,7 @@ func WithEncoderPadding(n int) EOption { } // No need to waste our time. if n == 1 { - o.pad = 0 + n = 0 } if n > 1<<30 { return fmt.Errorf("padding must less than 1GB (1<<30 bytes) ") @@ -188,10 +193,9 @@ func EncoderLevelFromZstd(level int) EncoderLevel { return SpeedDefault case level >= 6 && level < 10: return SpeedBetterCompression - case level >= 10: + default: return SpeedBestCompression } - return SpeedDefault } // String provides a string representation of the compression level. @@ -222,6 +226,9 @@ func WithEncoderLevel(l EncoderLevel) EOption { switch o.level { case SpeedFastest: o.windowSize = 4 << 20 + if !o.customBlockSize { + o.blockSize = 1 << 16 + } case SpeedDefault: o.windowSize = 8 << 20 case SpeedBetterCompression: @@ -231,7 +238,7 @@ func WithEncoderLevel(l EncoderLevel) EOption { } } if !o.customALEntropy { - o.allLitEntropy = l > SpeedFastest + o.allLitEntropy = l > SpeedDefault } return nil @@ -278,7 +285,7 @@ func WithNoEntropyCompression(b bool) EOption { // a decoder is allowed to reject a compressed frame which requests a memory size beyond decoder's authorized range. // For broader compatibility, decoders are recommended to support memory sizes of at least 8 MB. // This is only a recommendation, each decoder is free to support higher or lower limits, depending on local limitations. -// If this is not specified, block encodes will automatically choose this based on the input size. +// If this is not specified, block encodes will automatically choose this based on the input size and the window size. // This setting has no effect on streamed encodes. func WithSingleSegment(b bool) EOption { return func(o *encoderOptions) error { @@ -299,7 +306,13 @@ func WithLowerEncoderMem(b bool) EOption { } // WithEncoderDict allows to register a dictionary that will be used for the encode. +// +// The slice dict must be in the [dictionary format] produced by +// "zstd --train" from the Zstandard reference implementation. +// // The encoder *may* choose to use no dictionary instead for certain payloads. +// +// [dictionary format]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format func WithEncoderDict(dict []byte) EOption { return func(o *encoderOptions) error { d, err := loadDict(dict) @@ -310,3 +323,17 @@ func WithEncoderDict(dict []byte) EOption { return nil } } + +// WithEncoderDictRaw registers a dictionary that may be used by the encoder. +// +// The slice content may contain arbitrary data. It will be used as an initial +// history. +func WithEncoderDictRaw(id uint32, content []byte) EOption { + return func(o *encoderOptions) error { + if bits.UintSize > 32 && uint(len(content)) > dictMaxLength { + return fmt.Errorf("dictionary of size %d > 2GiB too large", len(content)) + } + o.dict = &dict{id: id, content: content, offsets: [3]int{1, 4, 8}} + return nil + } +} diff --git a/vendor/github.com/klauspost/compress/zstd/framedec.go b/vendor/github.com/klauspost/compress/zstd/framedec.go index 989c79f..53e160f 100644 --- a/vendor/github.com/klauspost/compress/zstd/framedec.go +++ b/vendor/github.com/klauspost/compress/zstd/framedec.go @@ -5,26 +5,20 @@ package zstd import ( - "bytes" + "encoding/binary" "encoding/hex" "errors" - "hash" "io" - "sync" "github.com/klauspost/compress/zstd/internal/xxhash" ) type frameDec struct { - o decoderOptions - crc hash.Hash64 - offset int64 + o decoderOptions + crc *xxhash.Digest WindowSize uint64 - // In order queue of blocks being decoded. - decoding chan *blockDec - // Frame history passed between blocks history history @@ -34,15 +28,10 @@ type frameDec struct { bBuf byteBuf FrameContentSize uint64 - frameDone sync.WaitGroup - DictionaryID *uint32 + DictionaryID uint32 HasCheckSum bool SingleSegment bool - - // asyncRunning indicates whether the async routine processes input on 'decoding'. - asyncRunningMu sync.Mutex - asyncRunning bool } const ( @@ -54,9 +43,9 @@ const ( MaxWindowSize = 1 << 29 ) -var ( - frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd} - skippableFrameMagic = []byte{0x2a, 0x4d, 0x18} +const ( + frameMagic = "\x28\xb5\x2f\xfd" + skippableFrameMagic = "\x2a\x4d\x18" ) func newFrameDec(o decoderOptions) *frameDec { @@ -84,25 +73,25 @@ func (d *frameDec) reset(br byteBuffer) error { switch err { case io.EOF, io.ErrUnexpectedEOF: return io.EOF - default: - return err case nil: signature[0] = b[0] + default: + return err } // Read the rest, don't allow io.ErrUnexpectedEOF b, err = br.readSmall(3) switch err { case io.EOF: return io.EOF - default: - return err case nil: copy(signature[1:], b) + default: + return err } - if !bytes.Equal(signature[1:4], skippableFrameMagic) || signature[0]&0xf0 != 0x50 { + if string(signature[1:4]) != skippableFrameMagic || signature[0]&0xf0 != 0x50 { if debugDecoder { - println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString(skippableFrameMagic)) + println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString([]byte(skippableFrameMagic))) } // Break if not skippable frame. break @@ -117,7 +106,7 @@ func (d *frameDec) reset(br byteBuffer) error { } n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) println("Skipping frame with", n, "bytes.") - err = br.skipN(int(n)) + err = br.skipN(int64(n)) if err != nil { if debugDecoder { println("Reading discarded frame", err) @@ -125,9 +114,9 @@ func (d *frameDec) reset(br byteBuffer) error { return err } } - if !bytes.Equal(signature[:], frameMagic) { + if string(signature[:]) != frameMagic { if debugDecoder { - println("Got magic numbers: ", signature, "want:", frameMagic) + println("Got magic numbers: ", signature, "want:", []byte(frameMagic)) } return ErrMagicMismatch } @@ -166,7 +155,7 @@ func (d *frameDec) reset(br byteBuffer) error { // Read Dictionary_ID // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id - d.DictionaryID = nil + d.DictionaryID = 0 if size := fhd & 3; size != 0 { if size == 3 { size = 4 @@ -178,7 +167,7 @@ func (d *frameDec) reset(br byteBuffer) error { return err } var id uint32 - switch size { + switch len(b) { case 1: id = uint32(b[0]) case 2: @@ -189,11 +178,7 @@ func (d *frameDec) reset(br byteBuffer) error { if debugDecoder { println("Dict size", size, "ID:", id) } - if id > 0 { - // ID 0 means "sorry, no dictionary anyway". - // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format - d.DictionaryID = &id - } + d.DictionaryID = id } // Read Frame_Content_Size @@ -208,14 +193,14 @@ func (d *frameDec) reset(br byteBuffer) error { default: fcsSize = 1 << v } - d.FrameContentSize = 0 + d.FrameContentSize = fcsUnknown if fcsSize > 0 { b, err := br.readSmall(fcsSize) if err != nil { println("Reading Frame content", err) return err } - switch fcsSize { + switch len(b) { case 1: d.FrameContentSize = uint64(b[0]) case 2: @@ -229,9 +214,10 @@ func (d *frameDec) reset(br byteBuffer) error { d.FrameContentSize = uint64(d1) | (uint64(d2) << 32) } if debugDecoder { - println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize) + println("Read FCS:", d.FrameContentSize) } } + // Move this to shared. d.HasCheckSum = fhd&(1<<2) != 0 if d.HasCheckSum { @@ -241,20 +227,27 @@ func (d *frameDec) reset(br byteBuffer) error { d.crc.Reset() } + if d.WindowSize > d.o.maxWindowSize { + if debugDecoder { + printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize) + } + return ErrWindowSizeExceeded + } + if d.WindowSize == 0 && d.SingleSegment { // We may not need window in this case. d.WindowSize = d.FrameContentSize if d.WindowSize < MinWindowSize { d.WindowSize = MinWindowSize } - } - - if d.WindowSize > uint64(d.o.maxWindowSize) { - if debugDecoder { - printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize) + if d.WindowSize > d.o.maxDecodedSize { + if debugDecoder { + printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize) + } + return ErrDecoderSizeExceeded } - return ErrWindowSizeExceeded } + // The minimum Window_Size is 1 KB. if d.WindowSize < MinWindowSize { if debugDecoder { @@ -263,11 +256,23 @@ func (d *frameDec) reset(br byteBuffer) error { return ErrWindowSizeTooSmall } d.history.windowSize = int(d.WindowSize) - if d.o.lowMem && d.history.windowSize < maxBlockSize { - d.history.maxSize = d.history.windowSize * 2 + if !d.o.lowMem || d.history.windowSize < maxBlockSize { + // Alloc 2x window size if not low-mem, or window size below 2MB. + d.history.allocFrameBuffer = d.history.windowSize * 2 } else { - d.history.maxSize = d.history.windowSize + maxBlockSize + if d.o.lowMem { + // Alloc with 1MB extra. + d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize/2 + } else { + // Alloc with 2MB extra. + d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize + } } + + if debugDecoder { + println("Frame: Dict:", d.DictionaryID, "FrameContentSize:", d.FrameContentSize, "singleseg:", d.SingleSegment, "window:", d.WindowSize, "crc:", d.HasCheckSum) + } + // history contains input - maybe we do something d.rawInput = br return nil @@ -276,209 +281,85 @@ func (d *frameDec) reset(br byteBuffer) error { // next will start decoding the next block from stream. func (d *frameDec) next(block *blockDec) error { if debugDecoder { - printf("decoding new block %p:%p", block, block.data) + println("decoding new block") } err := block.reset(d.rawInput, d.WindowSize) if err != nil { println("block error:", err) // Signal the frame decoder we have a problem. - d.sendErr(block, err) + block.sendErr(err) return err } - block.input <- struct{}{} - if debugDecoder { - println("next block:", block) - } - d.asyncRunningMu.Lock() - defer d.asyncRunningMu.Unlock() - if !d.asyncRunning { - return nil - } - if block.Last { - // We indicate the frame is done by sending io.EOF - d.decoding <- block - return io.EOF - } - d.decoding <- block return nil } -// sendEOF will queue an error block on the frame. -// This will cause the frame decoder to return when it encounters the block. -// Returns true if the decoder was added. -func (d *frameDec) sendErr(block *blockDec, err error) bool { - d.asyncRunningMu.Lock() - defer d.asyncRunningMu.Unlock() - if !d.asyncRunning { - return false - } - - println("sending error", err.Error()) - block.sendErr(err) - d.decoding <- block - return true -} - -// checkCRC will check the checksum if the frame has one. +// checkCRC will check the checksum, assuming the frame has one. // Will return ErrCRCMismatch if crc check failed, otherwise nil. func (d *frameDec) checkCRC() error { - if !d.HasCheckSum { - return nil - } - var tmp [4]byte - got := d.crc.Sum64() - // Flip to match file order. - tmp[0] = byte(got >> 0) - tmp[1] = byte(got >> 8) - tmp[2] = byte(got >> 16) - tmp[3] = byte(got >> 24) - // We can overwrite upper tmp now - want, err := d.rawInput.readSmall(4) + buf, err := d.rawInput.readSmall(4) if err != nil { println("CRC missing?", err) return err } - if !bytes.Equal(tmp[:], want) { + want := binary.LittleEndian.Uint32(buf[:4]) + got := uint32(d.crc.Sum64()) + + if got != want { if debugDecoder { - println("CRC Check Failed:", tmp[:], "!=", want) + printf("CRC check failed: got %08x, want %08x\n", got, want) } return ErrCRCMismatch } if debugDecoder { - println("CRC ok", tmp[:]) + printf("CRC ok %08x\n", got) } return nil } -func (d *frameDec) initAsync() { - if !d.o.lowMem && !d.SingleSegment { - // set max extra size history to 2MB. - d.history.maxSize = d.history.windowSize + maxBlockSize - } - // re-alloc if more than one extra block size. - if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize { - d.history.b = make([]byte, 0, d.history.maxSize) - } - if cap(d.history.b) < d.history.maxSize { - d.history.b = make([]byte, 0, d.history.maxSize) - } - if cap(d.decoding) < d.o.concurrent { - d.decoding = make(chan *blockDec, d.o.concurrent) - } - if debugDecoder { - h := d.history - printf("history init. len: %d, cap: %d", len(h.b), cap(h.b)) - } - d.asyncRunningMu.Lock() - d.asyncRunning = true - d.asyncRunningMu.Unlock() -} - -// startDecoder will start decoding blocks and write them to the writer. -// The decoder will stop as soon as an error occurs or at end of frame. -// When the frame has finished decoding the *bufio.Reader -// containing the remaining input will be sent on frameDec.frameDone. -func (d *frameDec) startDecoder(output chan decodeOutput) { - written := int64(0) - - defer func() { - d.asyncRunningMu.Lock() - d.asyncRunning = false - d.asyncRunningMu.Unlock() - - // Drain the currently decoding. - d.history.error = true - flushdone: - for { - select { - case b := <-d.decoding: - b.history <- &d.history - output <- <-b.result - default: - break flushdone - } - } - println("frame decoder done, signalling done") - d.frameDone.Done() - }() - // Get decoder for first block. - block := <-d.decoding - block.history <- &d.history - for { - var next *blockDec - // Get result - r := <-block.result - if r.err != nil { - println("Result contained error", r.err) - output <- r - return - } - if debugDecoder { - println("got result, from ", d.offset, "to", d.offset+int64(len(r.b))) - d.offset += int64(len(r.b)) - } - if !block.Last { - // Send history to next block - select { - case next = <-d.decoding: - if debugDecoder { - println("Sending ", len(d.history.b), "bytes as history") - } - next.history <- &d.history - default: - // Wait until we have sent the block, so - // other decoders can potentially get the decoder. - next = nil - } - } - - // Add checksum, async to decoding. - if d.HasCheckSum { - n, err := d.crc.Write(r.b) - if err != nil { - r.err = err - if n != len(r.b) { - r.err = io.ErrShortWrite - } - output <- r - return - } - } - written += int64(len(r.b)) - if d.SingleSegment && uint64(written) > d.FrameContentSize { - println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize) - r.err = ErrFrameSizeExceeded - output <- r - return - } - if block.Last { - r.err = d.checkCRC() - output <- r - return - } - output <- r - if next == nil { - // There was no decoder available, we wait for one now that we have sent to the writer. - if debugDecoder { - println("Sending ", len(d.history.b), " bytes as history") - } - next = <-d.decoding - next.history <- &d.history - } - block = next +// consumeCRC skips over the checksum, assuming the frame has one. +func (d *frameDec) consumeCRC() error { + _, err := d.rawInput.readSmall(4) + if err != nil { + println("CRC missing?", err) } + return err } -// runDecoder will create a sync decoder that will decode a block of data. +// runDecoder will run the decoder for the remainder of the frame. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { saved := d.history.b // We use the history for output to avoid copying it. d.history.b = dst + d.history.ignoreBuffer = len(dst) // Store input length, so we only check new data. crcStart := len(dst) + d.history.decoders.maxSyncLen = 0 + if d.o.limitToCap { + d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst)) + } + if d.FrameContentSize != fcsUnknown { + if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen { + d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) + } + if d.history.decoders.maxSyncLen > d.o.maxDecodedSize { + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize) + } + return dst, ErrDecoderSizeExceeded + } + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen) + } + if !d.o.limitToCap && uint64(cap(dst)) < d.history.decoders.maxSyncLen { + // Alloc for output + dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc) + copy(dst2, dst) + dst = dst2 + } + } var err error for { err = dec.reset(d.rawInput, d.WindowSize) @@ -489,30 +370,41 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { println("next block:", dec) } err = dec.decodeBuf(&d.history) - if err != nil || dec.Last { + if err != nil { + break + } + if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize { + println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize) + err = ErrDecoderSizeExceeded break } - if uint64(len(d.history.b)) > d.o.maxDecodedSize { + if d.o.limitToCap && len(d.history.b) > cap(dst) { + println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst)) err = ErrDecoderSizeExceeded break } - if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize { - println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize) + if uint64(len(d.history.b)-crcStart) > d.FrameContentSize { + println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize) err = ErrFrameSizeExceeded break } + if dec.Last { + break + } + if debugDecoder { + println("runDecoder: FrameContentSize", uint64(len(d.history.b)-crcStart), "<=", d.FrameContentSize) + } } dst = d.history.b if err == nil { - if d.HasCheckSum { - var n int - n, err = d.crc.Write(dst[crcStart:]) - if err == nil { - if n != len(dst)-crcStart { - err = io.ErrShortWrite - } else { - err = d.checkCRC() - } + if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize { + err = ErrFrameSizeMismatch + } else if d.HasCheckSum { + if d.o.ignoreChecksum { + err = d.consumeCRC() + } else { + d.crc.Write(dst[crcStart:]) + err = d.checkCRC() } } } diff --git a/vendor/github.com/klauspost/compress/zstd/fse_decoder.go b/vendor/github.com/klauspost/compress/zstd/fse_decoder.go index e6d3d49..2f8860a 100644 --- a/vendor/github.com/klauspost/compress/zstd/fse_decoder.go +++ b/vendor/github.com/klauspost/compress/zstd/fse_decoder.go @@ -5,8 +5,10 @@ package zstd import ( + "encoding/binary" "errors" "fmt" + "io" ) const ( @@ -178,10 +180,32 @@ func (s *fseDecoder) readNCount(b *byteReader, maxSymbol uint16) error { return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<> 3) - // println(s.norm[:s.symbolLen], s.symbolLen) return s.buildDtable() } +func (s *fseDecoder) mustReadFrom(r io.Reader) { + fatalErr := func(err error) { + if err != nil { + panic(err) + } + } + // dt [maxTablesize]decSymbol // Decompression table. + // symbolLen uint16 // Length of active part of the symbol table. + // actualTableLog uint8 // Selected tablelog. + // maxBits uint8 // Maximum number of additional bits + // // used for table creation to avoid allocations. + // stateTable [256]uint16 + // norm [maxSymbolValue + 1]int16 + // preDefined bool + fatalErr(binary.Read(r, binary.LittleEndian, &s.dt)) + fatalErr(binary.Read(r, binary.LittleEndian, &s.symbolLen)) + fatalErr(binary.Read(r, binary.LittleEndian, &s.actualTableLog)) + fatalErr(binary.Read(r, binary.LittleEndian, &s.maxBits)) + fatalErr(binary.Read(r, binary.LittleEndian, &s.stateTable)) + fatalErr(binary.Read(r, binary.LittleEndian, &s.norm)) + fatalErr(binary.Read(r, binary.LittleEndian, &s.preDefined)) +} + // decSymbol contains information about a state entry, // Including the state offset base, the output symbol and // the number of bits to read for the low part of the destination state. @@ -204,18 +228,10 @@ func (d decSymbol) newState() uint16 { return uint16(d >> 16) } -func (d decSymbol) baseline() uint32 { - return uint32(d >> 32) -} - func (d decSymbol) baselineInt() int { return int(d >> 32) } -func (d *decSymbol) set(nbits, addBits uint8, newState uint16, baseline uint32) { - *d = decSymbol(nbits) | (decSymbol(addBits) << 8) | (decSymbol(newState) << 16) | (decSymbol(baseline) << 32) -} - func (d *decSymbol) setNBits(nBits uint8) { const mask = 0xffffffffffffff00 *d = (*d & mask) | decSymbol(nBits) @@ -231,11 +247,6 @@ func (d *decSymbol) setNewState(state uint16) { *d = (*d & mask) | decSymbol(state)<<16 } -func (d *decSymbol) setBaseline(baseline uint32) { - const mask = 0xffffffff - *d = (*d & mask) | decSymbol(baseline)<<32 -} - func (d *decSymbol) setExt(addBits uint8, baseline uint32) { const mask = 0xffff00ff *d = (*d & mask) | (decSymbol(addBits) << 8) | (decSymbol(baseline) << 32) @@ -257,68 +268,6 @@ func (s *fseDecoder) setRLE(symbol decSymbol) { s.dt[0] = symbol } -// buildDtable will build the decoding table. -func (s *fseDecoder) buildDtable() error { - tableSize := uint32(1 << s.actualTableLog) - highThreshold := tableSize - 1 - symbolNext := s.stateTable[:256] - - // Init, lay down lowprob symbols - { - for i, v := range s.norm[:s.symbolLen] { - if v == -1 { - s.dt[highThreshold].setAddBits(uint8(i)) - highThreshold-- - symbolNext[i] = 1 - } else { - symbolNext[i] = uint16(v) - } - } - } - // Spread symbols - { - tableMask := tableSize - 1 - step := tableStep(tableSize) - position := uint32(0) - for ss, v := range s.norm[:s.symbolLen] { - for i := 0; i < int(v); i++ { - s.dt[position].setAddBits(uint8(ss)) - position = (position + step) & tableMask - for position > highThreshold { - // lowprob area - position = (position + step) & tableMask - } - } - } - if position != 0 { - // position must reach all cells once, otherwise normalizedCounter is incorrect - return errors.New("corrupted input (position != 0)") - } - } - - // Build Decoding table - { - tableSize := uint16(1 << s.actualTableLog) - for u, v := range s.dt[:tableSize] { - symbol := v.addBits() - nextState := symbolNext[symbol] - symbolNext[symbol] = nextState + 1 - nBits := s.actualTableLog - byte(highBits(uint32(nextState))) - s.dt[u&maxTableMask].setNBits(nBits) - newState := (nextState << nBits) - tableSize - if newState > tableSize { - return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) - } - if newState == uint16(u) && nBits == 0 { - // Seems weird that this is possible with nbits > 0. - return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) - } - s.dt[u&maxTableMask].setNewState(newState) - } - } - return nil -} - // transform will transform the decoder table into a table usable for // decoding without having to apply the transformation while decoding. // The state will contain the base value and the number of bits to read. @@ -352,34 +301,7 @@ func (s *fseState) init(br *bitReader, tableLog uint8, dt []decSymbol) { s.state = dt[br.getBits(tableLog)] } -// next returns the current symbol and sets the next state. -// At least tablelog bits must be available in the bit reader. -func (s *fseState) next(br *bitReader) { - lowBits := uint16(br.getBits(s.state.nbBits())) - s.state = s.dt[s.state.newState()+lowBits] -} - -// finished returns true if all bits have been read from the bitstream -// and the next state would require reading bits from the input. -func (s *fseState) finished(br *bitReader) bool { - return br.finished() && s.state.nbBits() > 0 -} - -// final returns the current state symbol without decoding the next. -func (s *fseState) final() (int, uint8) { - return s.state.baselineInt(), s.state.addBits() -} - // final returns the current state symbol without decoding the next. func (s decSymbol) final() (int, uint8) { return s.baselineInt(), s.addBits() } - -// nextFast returns the next symbol and sets the next state. -// This can only be used if no symbols are 0 bits. -// At least tablelog bits must be available in the bit reader. -func (s *fseState) nextFast(br *bitReader) (uint32, uint8) { - lowBits := uint16(br.getBitsFast(s.state.nbBits())) - s.state = s.dt[s.state.newState()+lowBits] - return s.state.baseline(), s.state.addBits() -} diff --git a/vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.go b/vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.go new file mode 100644 index 0000000..d04a829 --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.go @@ -0,0 +1,65 @@ +//go:build amd64 && !appengine && !noasm && gc +// +build amd64,!appengine,!noasm,gc + +package zstd + +import ( + "fmt" +) + +type buildDtableAsmContext struct { + // inputs + stateTable *uint16 + norm *int16 + dt *uint64 + + // outputs --- set by the procedure in the case of error; + // for interpretation please see the error handling part below + errParam1 uint64 + errParam2 uint64 +} + +// buildDtable_asm is an x86 assembly implementation of fseDecoder.buildDtable. +// Function returns non-zero exit code on error. +// +//go:noescape +func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int + +// please keep in sync with _generate/gen_fse.go +const ( + errorCorruptedNormalizedCounter = 1 + errorNewStateTooBig = 2 + errorNewStateNoBits = 3 +) + +// buildDtable will build the decoding table. +func (s *fseDecoder) buildDtable() error { + ctx := buildDtableAsmContext{ + stateTable: &s.stateTable[0], + norm: &s.norm[0], + dt: (*uint64)(&s.dt[0]), + } + code := buildDtable_asm(s, &ctx) + + if code != 0 { + switch code { + case errorCorruptedNormalizedCounter: + position := ctx.errParam1 + return fmt.Errorf("corrupted input (position=%d, expected 0)", position) + + case errorNewStateTooBig: + newState := decSymbol(ctx.errParam1) + size := ctx.errParam2 + return fmt.Errorf("newState (%d) outside table size (%d)", newState, size) + + case errorNewStateNoBits: + newState := decSymbol(ctx.errParam1) + oldState := decSymbol(ctx.errParam2) + return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, oldState) + + default: + return fmt.Errorf("buildDtable_asm returned unhandled nonzero code = %d", code) + } + } + return nil +} diff --git a/vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.s b/vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.s new file mode 100644 index 0000000..bcde398 --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/fse_decoder_amd64.s @@ -0,0 +1,126 @@ +// Code generated by command: go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd. DO NOT EDIT. + +//go:build !appengine && !noasm && gc && !noasm + +// func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int +TEXT ·buildDtable_asm(SB), $0-24 + MOVQ ctx+8(FP), CX + MOVQ s+0(FP), DI + + // Load values + MOVBQZX 4098(DI), DX + XORQ AX, AX + BTSQ DX, AX + MOVQ (CX), BX + MOVQ 16(CX), SI + LEAQ -1(AX), R8 + MOVQ 8(CX), CX + MOVWQZX 4096(DI), DI + + // End load values + // Init, lay down lowprob symbols + XORQ R9, R9 + JMP init_main_loop_condition + +init_main_loop: + MOVWQSX (CX)(R9*2), R10 + CMPW R10, $-1 + JNE do_not_update_high_threshold + MOVB R9, 1(SI)(R8*8) + DECQ R8 + MOVQ $0x0000000000000001, R10 + +do_not_update_high_threshold: + MOVW R10, (BX)(R9*2) + INCQ R9 + +init_main_loop_condition: + CMPQ R9, DI + JL init_main_loop + + // Spread symbols + // Calculate table step + MOVQ AX, R9 + SHRQ $0x01, R9 + MOVQ AX, R10 + SHRQ $0x03, R10 + LEAQ 3(R9)(R10*1), R9 + + // Fill add bits values + LEAQ -1(AX), R10 + XORQ R11, R11 + XORQ R12, R12 + JMP spread_main_loop_condition + +spread_main_loop: + XORQ R13, R13 + MOVWQSX (CX)(R12*2), R14 + JMP spread_inner_loop_condition + +spread_inner_loop: + MOVB R12, 1(SI)(R11*8) + +adjust_position: + ADDQ R9, R11 + ANDQ R10, R11 + CMPQ R11, R8 + JG adjust_position + INCQ R13 + +spread_inner_loop_condition: + CMPQ R13, R14 + JL spread_inner_loop + INCQ R12 + +spread_main_loop_condition: + CMPQ R12, DI + JL spread_main_loop + TESTQ R11, R11 + JZ spread_check_ok + MOVQ ctx+8(FP), AX + MOVQ R11, 24(AX) + MOVQ $+1, ret+16(FP) + RET + +spread_check_ok: + // Build Decoding table + XORQ DI, DI + +build_table_main_table: + MOVBQZX 1(SI)(DI*8), CX + MOVWQZX (BX)(CX*2), R8 + LEAQ 1(R8), R9 + MOVW R9, (BX)(CX*2) + MOVQ R8, R9 + BSRQ R9, R9 + MOVQ DX, CX + SUBQ R9, CX + SHLQ CL, R8 + SUBQ AX, R8 + MOVB CL, (SI)(DI*8) + MOVW R8, 2(SI)(DI*8) + CMPQ R8, AX + JLE build_table_check1_ok + MOVQ ctx+8(FP), CX + MOVQ R8, 24(CX) + MOVQ AX, 32(CX) + MOVQ $+2, ret+16(FP) + RET + +build_table_check1_ok: + TESTB CL, CL + JNZ build_table_check2_ok + CMPW R8, DI + JNE build_table_check2_ok + MOVQ ctx+8(FP), AX + MOVQ R8, 24(AX) + MOVQ DI, 32(AX) + MOVQ $+3, ret+16(FP) + RET + +build_table_check2_ok: + INCQ DI + CMPQ DI, AX + JL build_table_main_table + MOVQ $+0, ret+16(FP) + RET diff --git a/vendor/github.com/klauspost/compress/zstd/fse_decoder_generic.go b/vendor/github.com/klauspost/compress/zstd/fse_decoder_generic.go new file mode 100644 index 0000000..332e51f --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/fse_decoder_generic.go @@ -0,0 +1,72 @@ +//go:build !amd64 || appengine || !gc || noasm +// +build !amd64 appengine !gc noasm + +package zstd + +import ( + "errors" + "fmt" +) + +// buildDtable will build the decoding table. +func (s *fseDecoder) buildDtable() error { + tableSize := uint32(1 << s.actualTableLog) + highThreshold := tableSize - 1 + symbolNext := s.stateTable[:256] + + // Init, lay down lowprob symbols + { + for i, v := range s.norm[:s.symbolLen] { + if v == -1 { + s.dt[highThreshold].setAddBits(uint8(i)) + highThreshold-- + symbolNext[i] = 1 + } else { + symbolNext[i] = uint16(v) + } + } + } + + // Spread symbols + { + tableMask := tableSize - 1 + step := tableStep(tableSize) + position := uint32(0) + for ss, v := range s.norm[:s.symbolLen] { + for i := 0; i < int(v); i++ { + s.dt[position].setAddBits(uint8(ss)) + position = (position + step) & tableMask + for position > highThreshold { + // lowprob area + position = (position + step) & tableMask + } + } + } + if position != 0 { + // position must reach all cells once, otherwise normalizedCounter is incorrect + return errors.New("corrupted input (position != 0)") + } + } + + // Build Decoding table + { + tableSize := uint16(1 << s.actualTableLog) + for u, v := range s.dt[:tableSize] { + symbol := v.addBits() + nextState := symbolNext[symbol] + symbolNext[symbol] = nextState + 1 + nBits := s.actualTableLog - byte(highBits(uint32(nextState))) + s.dt[u&maxTableMask].setNBits(nBits) + newState := (nextState << nBits) - tableSize + if newState > tableSize { + return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) + } + if newState == uint16(u) && nBits == 0 { + // Seems weird that this is possible with nbits > 0. + return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) + } + s.dt[u&maxTableMask].setNewState(newState) + } + } + return nil +} diff --git a/vendor/github.com/klauspost/compress/zstd/fse_encoder.go b/vendor/github.com/klauspost/compress/zstd/fse_encoder.go index b4757ee..ab26326 100644 --- a/vendor/github.com/klauspost/compress/zstd/fse_encoder.go +++ b/vendor/github.com/klauspost/compress/zstd/fse_encoder.go @@ -62,9 +62,8 @@ func (s symbolTransform) String() string { // To indicate that you have populated the histogram call HistogramFinished // with the value of the highest populated symbol, as well as the number of entries // in the most populated entry. These are accepted at face value. -// The returned slice will always be length 256. -func (s *fseEncoder) Histogram() []uint32 { - return s.count[:] +func (s *fseEncoder) Histogram() *[256]uint32 { + return &s.count } // HistogramFinished can be called to indicate that the histogram has been populated. @@ -77,21 +76,6 @@ func (s *fseEncoder) HistogramFinished(maxSymbol uint8, maxCount int) { s.clearCount = maxCount != 0 } -// prepare will prepare and allocate scratch tables used for both compression and decompression. -func (s *fseEncoder) prepare() (*fseEncoder, error) { - if s == nil { - s = &fseEncoder{} - } - s.useRLE = false - if s.clearCount && s.maxCount == 0 { - for i := range s.count { - s.count[i] = 0 - } - s.clearCount = false - } - return s, nil -} - // allocCtable will allocate tables needed for compression. // If existing tables a re big enough, they are simply re-used. func (s *fseEncoder) allocCtable() { @@ -710,14 +694,6 @@ func (c *cState) init(bw *bitWriter, ct *cTable, first symbolTransform) { c.state = c.stateTable[lu] } -// encode the output symbol provided and write it to the bitstream. -func (c *cState) encode(symbolTT symbolTransform) { - nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16 - dstState := int32(c.state>>(nbBitsOut&15)) + int32(symbolTT.deltaFindState) - c.bw.addBits16NC(c.state, uint8(nbBitsOut)) - c.state = c.stateTable[dstState] -} - // flush will write the tablelog to the output and flush the remaining full bytes. func (c *cState) flush(tableLog uint8) { c.bw.flush32() diff --git a/vendor/github.com/klauspost/compress/zstd/hash.go b/vendor/github.com/klauspost/compress/zstd/hash.go index cf33f29..5d73c21 100644 --- a/vendor/github.com/klauspost/compress/zstd/hash.go +++ b/vendor/github.com/klauspost/compress/zstd/hash.go @@ -33,9 +33,3 @@ func hashLen(u uint64, length, mls uint8) uint32 { return (uint32(u) * prime4bytes) >> (32 - length) } } - -// hash3 returns the hash of the lower 3 bytes of u to fit in a hash table with h bits. -// Preferably h should be a constant and should always be <32. -func hash3(u uint32, h uint8) uint32 { - return ((u << (32 - 24)) * prime3bytes) >> ((32 - h) & 31) -} diff --git a/vendor/github.com/klauspost/compress/zstd/history.go b/vendor/github.com/klauspost/compress/zstd/history.go index f783e32..0916485 100644 --- a/vendor/github.com/klauspost/compress/zstd/history.go +++ b/vendor/github.com/klauspost/compress/zstd/history.go @@ -10,40 +10,48 @@ import ( // history contains the information transferred between blocks. type history struct { - b []byte - huffTree *huff0.Scratch - recentOffsets [3]int + // Literal decompression + huffTree *huff0.Scratch + + // Sequence decompression decoders sequenceDecs - windowSize int - maxSize int - error bool - dict *dict + recentOffsets [3]int + + // History buffer... + b []byte + + // ignoreBuffer is meant to ignore a number of bytes + // when checking for matches in history + ignoreBuffer int + + windowSize int + allocFrameBuffer int // needed? + error bool + dict *dict } // reset will reset the history to initial state of a frame. // The history must already have been initialized to the desired size. func (h *history) reset() { h.b = h.b[:0] + h.ignoreBuffer = 0 h.error = false h.recentOffsets = [3]int{1, 4, 8} - if f := h.decoders.litLengths.fse; f != nil && !f.preDefined { - fseDecoderPool.Put(f) - } - if f := h.decoders.offsets.fse; f != nil && !f.preDefined { - fseDecoderPool.Put(f) - } - if f := h.decoders.matchLengths.fse; f != nil && !f.preDefined { - fseDecoderPool.Put(f) - } - h.decoders = sequenceDecs{} + h.decoders.freeDecoders() + h.decoders = sequenceDecs{br: h.decoders.br} + h.freeHuffDecoder() + h.huffTree = nil + h.dict = nil + //printf("history created: %+v (l: %d, c: %d)", *h, len(h.b), cap(h.b)) +} + +func (h *history) freeHuffDecoder() { if h.huffTree != nil { if h.dict == nil || h.dict.litEnc != h.huffTree { huffDecoderPool.Put(h.huffTree) + h.huffTree = nil } } - h.huffTree = nil - h.dict = nil - //printf("history created: %+v (l: %d, c: %d)", *h, len(h.b), cap(h.b)) } func (h *history) setDict(dict *dict) { @@ -54,6 +62,7 @@ func (h *history) setDict(dict *dict) { h.decoders.litLengths = dict.llDec h.decoders.offsets = dict.ofDec h.decoders.matchLengths = dict.mlDec + h.decoders.dict = dict.content h.recentOffsets = dict.offsets h.huffTree = dict.litEnc } @@ -83,6 +92,24 @@ func (h *history) append(b []byte) { copy(h.b[h.windowSize-len(b):], b) } +// ensureBlock will ensure there is space for at least one block... +func (h *history) ensureBlock() { + if cap(h.b) < h.allocFrameBuffer { + h.b = make([]byte, 0, h.allocFrameBuffer) + return + } + + avail := cap(h.b) - len(h.b) + if avail >= h.windowSize || avail > maxCompressedBlockSize { + return + } + // Move data down so we only have window size left. + // We know we have less than window size in b at this point. + discard := len(h.b) - h.windowSize + copy(h.b, h.b[discard:]) + h.b = h.b[:h.windowSize] +} + // append bytes to history without ever discarding anything. func (h *history) appendKeep(b []byte) { h.b = append(h.b, b...) diff --git a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/README.md b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/README.md index 69aa3bb..777290d 100644 --- a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/README.md +++ b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/README.md @@ -2,12 +2,7 @@ VENDORED: Go to [github.com/cespare/xxhash](https://github.com/cespare/xxhash) for original package. - -[![GoDoc](https://godoc.org/github.com/cespare/xxhash?status.svg)](https://godoc.org/github.com/cespare/xxhash) -[![Build Status](https://travis-ci.org/cespare/xxhash.svg?branch=master)](https://travis-ci.org/cespare/xxhash) - -xxhash is a Go implementation of the 64-bit -[xxHash](http://cyan4973.github.io/xxHash/) algorithm, XXH64. This is a +xxhash is a Go implementation of the 64-bit [xxHash] algorithm, XXH64. This is a high-quality hashing algorithm that is much faster than anything in the Go standard library. @@ -28,31 +23,49 @@ func (*Digest) WriteString(string) (int, error) func (*Digest) Sum64() uint64 ``` -This implementation provides a fast pure-Go implementation and an even faster -assembly implementation for amd64. +The package is written with optimized pure Go and also contains even faster +assembly implementations for amd64 and arm64. If desired, the `purego` build tag +opts into using the Go code even on those architectures. + +[xxHash]: http://cyan4973.github.io/xxHash/ + +## Compatibility + +This package is in a module and the latest code is in version 2 of the module. +You need a version of Go with at least "minimal module compatibility" to use +github.com/cespare/xxhash/v2: + +* 1.9.7+ for Go 1.9 +* 1.10.3+ for Go 1.10 +* Go 1.11 or later + +I recommend using the latest release of Go. ## Benchmarks Here are some quick benchmarks comparing the pure-Go and assembly implementations of Sum64. -| input size | purego | asm | -| --- | --- | --- | -| 5 B | 979.66 MB/s | 1291.17 MB/s | -| 100 B | 7475.26 MB/s | 7973.40 MB/s | -| 4 KB | 17573.46 MB/s | 17602.65 MB/s | -| 10 MB | 17131.46 MB/s | 17142.16 MB/s | +| input size | purego | asm | +| ---------- | --------- | --------- | +| 4 B | 1.3 GB/s | 1.2 GB/s | +| 16 B | 2.9 GB/s | 3.5 GB/s | +| 100 B | 6.9 GB/s | 8.1 GB/s | +| 4 KB | 11.7 GB/s | 16.7 GB/s | +| 10 MB | 12.0 GB/s | 17.3 GB/s | -These numbers were generated on Ubuntu 18.04 with an Intel i7-8700K CPU using -the following commands under Go 1.11.2: +These numbers were generated on Ubuntu 20.04 with an Intel Xeon Platinum 8252C +CPU using the following commands under Go 1.19.2: ``` -$ go test -tags purego -benchtime 10s -bench '/xxhash,direct,bytes' -$ go test -benchtime 10s -bench '/xxhash,direct,bytes' +benchstat <(go test -tags purego -benchtime 500ms -count 15 -bench 'Sum64$') +benchstat <(go test -benchtime 500ms -count 15 -bench 'Sum64$') ``` ## Projects using this package - [InfluxDB](https://github.com/influxdata/influxdb) - [Prometheus](https://github.com/prometheus/prometheus) +- [VictoriaMetrics](https://github.com/VictoriaMetrics/VictoriaMetrics) - [FreeCache](https://github.com/coocood/freecache) +- [FastCache](https://github.com/VictoriaMetrics/fastcache) diff --git a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash.go b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash.go index 2c112a0..fc40c82 100644 --- a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash.go +++ b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash.go @@ -18,19 +18,11 @@ const ( prime5 uint64 = 2870177450012600261 ) -// NOTE(caleb): I'm using both consts and vars of the primes. Using consts where -// possible in the Go code is worth a small (but measurable) performance boost -// by avoiding some MOVQs. Vars are needed for the asm and also are useful for -// convenience in the Go code in a few places where we need to intentionally -// avoid constant arithmetic (e.g., v1 := prime1 + prime2 fails because the -// result overflows a uint64). -var ( - prime1v = prime1 - prime2v = prime2 - prime3v = prime3 - prime4v = prime4 - prime5v = prime5 -) +// Store the primes in an array as well. +// +// The consts are used when possible in Go code to avoid MOVs but we need a +// contiguous array of the assembly code. +var primes = [...]uint64{prime1, prime2, prime3, prime4, prime5} // Digest implements hash.Hash64. type Digest struct { @@ -52,10 +44,10 @@ func New() *Digest { // Reset clears the Digest's state so that it can be reused. func (d *Digest) Reset() { - d.v1 = prime1v + prime2 + d.v1 = primes[0] + prime2 d.v2 = prime2 d.v3 = 0 - d.v4 = -prime1v + d.v4 = -primes[0] d.total = 0 d.n = 0 } @@ -71,21 +63,23 @@ func (d *Digest) Write(b []byte) (n int, err error) { n = len(b) d.total += uint64(n) + memleft := d.mem[d.n&(len(d.mem)-1):] + if d.n+n < 32 { // This new data doesn't even fill the current block. - copy(d.mem[d.n:], b) + copy(memleft, b) d.n += n return } if d.n > 0 { // Finish off the partial block. - copy(d.mem[d.n:], b) + c := copy(memleft, b) d.v1 = round(d.v1, u64(d.mem[0:8])) d.v2 = round(d.v2, u64(d.mem[8:16])) d.v3 = round(d.v3, u64(d.mem[16:24])) d.v4 = round(d.v4, u64(d.mem[24:32])) - b = b[32-d.n:] + b = b[c:] d.n = 0 } @@ -135,21 +129,20 @@ func (d *Digest) Sum64() uint64 { h += d.total - i, end := 0, d.n - for ; i+8 <= end; i += 8 { - k1 := round(0, u64(d.mem[i:i+8])) + b := d.mem[:d.n&(len(d.mem)-1)] + for ; len(b) >= 8; b = b[8:] { + k1 := round(0, u64(b[:8])) h ^= k1 h = rol27(h)*prime1 + prime4 } - if i+4 <= end { - h ^= uint64(u32(d.mem[i:i+4])) * prime1 + if len(b) >= 4 { + h ^= uint64(u32(b[:4])) * prime1 h = rol23(h)*prime2 + prime3 - i += 4 + b = b[4:] } - for i < end { - h ^= uint64(d.mem[i]) * prime5 + for ; len(b) > 0; b = b[1:] { + h ^= uint64(b[0]) * prime5 h = rol11(h) * prime1 - i++ } h ^= h >> 33 diff --git a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.go b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.go deleted file mode 100644 index 0ae847f..0000000 --- a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !appengine && gc && !purego -// +build !appengine,gc,!purego - -package xxhash - -// Sum64 computes the 64-bit xxHash digest of b. -// -//go:noescape -func Sum64(b []byte) uint64 - -//go:noescape -func writeBlocks(d *Digest, b []byte) int diff --git a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.s b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.s index be8db5b..ddb63aa 100644 --- a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.s +++ b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_amd64.s @@ -1,215 +1,210 @@ +//go:build !appengine && gc && !purego && !noasm // +build !appengine // +build gc // +build !purego +// +build !noasm #include "textflag.h" -// Register allocation: -// AX h -// SI pointer to advance through b -// DX n -// BX loop end -// R8 v1, k1 -// R9 v2 -// R10 v3 -// R11 v4 -// R12 tmp -// R13 prime1v -// R14 prime2v -// DI prime4v - -// round reads from and advances the buffer pointer in SI. -// It assumes that R13 has prime1v and R14 has prime2v. -#define round(r) \ - MOVQ (SI), R12 \ - ADDQ $8, SI \ - IMULQ R14, R12 \ - ADDQ R12, r \ - ROLQ $31, r \ - IMULQ R13, r - -// mergeRound applies a merge round on the two registers acc and val. -// It assumes that R13 has prime1v, R14 has prime2v, and DI has prime4v. -#define mergeRound(acc, val) \ - IMULQ R14, val \ - ROLQ $31, val \ - IMULQ R13, val \ - XORQ val, acc \ - IMULQ R13, acc \ - ADDQ DI, acc +// Registers: +#define h AX +#define d AX +#define p SI // pointer to advance through b +#define n DX +#define end BX // loop end +#define v1 R8 +#define v2 R9 +#define v3 R10 +#define v4 R11 +#define x R12 +#define prime1 R13 +#define prime2 R14 +#define prime4 DI + +#define round(acc, x) \ + IMULQ prime2, x \ + ADDQ x, acc \ + ROLQ $31, acc \ + IMULQ prime1, acc + +// round0 performs the operation x = round(0, x). +#define round0(x) \ + IMULQ prime2, x \ + ROLQ $31, x \ + IMULQ prime1, x + +// mergeRound applies a merge round on the two registers acc and x. +// It assumes that prime1, prime2, and prime4 have been loaded. +#define mergeRound(acc, x) \ + round0(x) \ + XORQ x, acc \ + IMULQ prime1, acc \ + ADDQ prime4, acc + +// blockLoop processes as many 32-byte blocks as possible, +// updating v1, v2, v3, and v4. It assumes that there is at least one block +// to process. +#define blockLoop() \ +loop: \ + MOVQ +0(p), x \ + round(v1, x) \ + MOVQ +8(p), x \ + round(v2, x) \ + MOVQ +16(p), x \ + round(v3, x) \ + MOVQ +24(p), x \ + round(v4, x) \ + ADDQ $32, p \ + CMPQ p, end \ + JLE loop // func Sum64(b []byte) uint64 -TEXT ·Sum64(SB), NOSPLIT, $0-32 +TEXT ·Sum64(SB), NOSPLIT|NOFRAME, $0-32 // Load fixed primes. - MOVQ ·prime1v(SB), R13 - MOVQ ·prime2v(SB), R14 - MOVQ ·prime4v(SB), DI + MOVQ ·primes+0(SB), prime1 + MOVQ ·primes+8(SB), prime2 + MOVQ ·primes+24(SB), prime4 // Load slice. - MOVQ b_base+0(FP), SI - MOVQ b_len+8(FP), DX - LEAQ (SI)(DX*1), BX + MOVQ b_base+0(FP), p + MOVQ b_len+8(FP), n + LEAQ (p)(n*1), end // The first loop limit will be len(b)-32. - SUBQ $32, BX + SUBQ $32, end // Check whether we have at least one block. - CMPQ DX, $32 + CMPQ n, $32 JLT noBlocks // Set up initial state (v1, v2, v3, v4). - MOVQ R13, R8 - ADDQ R14, R8 - MOVQ R14, R9 - XORQ R10, R10 - XORQ R11, R11 - SUBQ R13, R11 - - // Loop until SI > BX. -blockLoop: - round(R8) - round(R9) - round(R10) - round(R11) - - CMPQ SI, BX - JLE blockLoop - - MOVQ R8, AX - ROLQ $1, AX - MOVQ R9, R12 - ROLQ $7, R12 - ADDQ R12, AX - MOVQ R10, R12 - ROLQ $12, R12 - ADDQ R12, AX - MOVQ R11, R12 - ROLQ $18, R12 - ADDQ R12, AX - - mergeRound(AX, R8) - mergeRound(AX, R9) - mergeRound(AX, R10) - mergeRound(AX, R11) + MOVQ prime1, v1 + ADDQ prime2, v1 + MOVQ prime2, v2 + XORQ v3, v3 + XORQ v4, v4 + SUBQ prime1, v4 + + blockLoop() + + MOVQ v1, h + ROLQ $1, h + MOVQ v2, x + ROLQ $7, x + ADDQ x, h + MOVQ v3, x + ROLQ $12, x + ADDQ x, h + MOVQ v4, x + ROLQ $18, x + ADDQ x, h + + mergeRound(h, v1) + mergeRound(h, v2) + mergeRound(h, v3) + mergeRound(h, v4) JMP afterBlocks noBlocks: - MOVQ ·prime5v(SB), AX + MOVQ ·primes+32(SB), h afterBlocks: - ADDQ DX, AX - - // Right now BX has len(b)-32, and we want to loop until SI > len(b)-8. - ADDQ $24, BX - - CMPQ SI, BX - JG fourByte - -wordLoop: - // Calculate k1. - MOVQ (SI), R8 - ADDQ $8, SI - IMULQ R14, R8 - ROLQ $31, R8 - IMULQ R13, R8 - - XORQ R8, AX - ROLQ $27, AX - IMULQ R13, AX - ADDQ DI, AX - - CMPQ SI, BX - JLE wordLoop - -fourByte: - ADDQ $4, BX - CMPQ SI, BX - JG singles - - MOVL (SI), R8 - ADDQ $4, SI - IMULQ R13, R8 - XORQ R8, AX - - ROLQ $23, AX - IMULQ R14, AX - ADDQ ·prime3v(SB), AX - -singles: - ADDQ $4, BX - CMPQ SI, BX + ADDQ n, h + + ADDQ $24, end + CMPQ p, end + JG try4 + +loop8: + MOVQ (p), x + ADDQ $8, p + round0(x) + XORQ x, h + ROLQ $27, h + IMULQ prime1, h + ADDQ prime4, h + + CMPQ p, end + JLE loop8 + +try4: + ADDQ $4, end + CMPQ p, end + JG try1 + + MOVL (p), x + ADDQ $4, p + IMULQ prime1, x + XORQ x, h + + ROLQ $23, h + IMULQ prime2, h + ADDQ ·primes+16(SB), h + +try1: + ADDQ $4, end + CMPQ p, end JGE finalize -singlesLoop: - MOVBQZX (SI), R12 - ADDQ $1, SI - IMULQ ·prime5v(SB), R12 - XORQ R12, AX +loop1: + MOVBQZX (p), x + ADDQ $1, p + IMULQ ·primes+32(SB), x + XORQ x, h + ROLQ $11, h + IMULQ prime1, h - ROLQ $11, AX - IMULQ R13, AX - - CMPQ SI, BX - JL singlesLoop + CMPQ p, end + JL loop1 finalize: - MOVQ AX, R12 - SHRQ $33, R12 - XORQ R12, AX - IMULQ R14, AX - MOVQ AX, R12 - SHRQ $29, R12 - XORQ R12, AX - IMULQ ·prime3v(SB), AX - MOVQ AX, R12 - SHRQ $32, R12 - XORQ R12, AX - - MOVQ AX, ret+24(FP) + MOVQ h, x + SHRQ $33, x + XORQ x, h + IMULQ prime2, h + MOVQ h, x + SHRQ $29, x + XORQ x, h + IMULQ ·primes+16(SB), h + MOVQ h, x + SHRQ $32, x + XORQ x, h + + MOVQ h, ret+24(FP) RET -// writeBlocks uses the same registers as above except that it uses AX to store -// the d pointer. - // func writeBlocks(d *Digest, b []byte) int -TEXT ·writeBlocks(SB), NOSPLIT, $0-40 +TEXT ·writeBlocks(SB), NOSPLIT|NOFRAME, $0-40 // Load fixed primes needed for round. - MOVQ ·prime1v(SB), R13 - MOVQ ·prime2v(SB), R14 + MOVQ ·primes+0(SB), prime1 + MOVQ ·primes+8(SB), prime2 // Load slice. - MOVQ b_base+8(FP), SI - MOVQ b_len+16(FP), DX - LEAQ (SI)(DX*1), BX - SUBQ $32, BX + MOVQ b_base+8(FP), p + MOVQ b_len+16(FP), n + LEAQ (p)(n*1), end + SUBQ $32, end // Load vN from d. - MOVQ d+0(FP), AX - MOVQ 0(AX), R8 // v1 - MOVQ 8(AX), R9 // v2 - MOVQ 16(AX), R10 // v3 - MOVQ 24(AX), R11 // v4 + MOVQ s+0(FP), d + MOVQ 0(d), v1 + MOVQ 8(d), v2 + MOVQ 16(d), v3 + MOVQ 24(d), v4 // We don't need to check the loop condition here; this function is // always called with at least one block of data to process. -blockLoop: - round(R8) - round(R9) - round(R10) - round(R11) - - CMPQ SI, BX - JLE blockLoop + blockLoop() // Copy vN back to d. - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) - - // The number of bytes written is SI minus the old base pointer. - SUBQ b_base+8(FP), SI - MOVQ SI, ret+32(FP) + MOVQ v1, 0(d) + MOVQ v2, 8(d) + MOVQ v3, 16(d) + MOVQ v4, 24(d) + + // The number of bytes written is p minus the old base pointer. + SUBQ b_base+8(FP), p + MOVQ p, ret+32(FP) RET diff --git a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_arm64.s b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_arm64.s new file mode 100644 index 0000000..17901e0 --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_arm64.s @@ -0,0 +1,184 @@ +//go:build !appengine && gc && !purego && !noasm +// +build !appengine +// +build gc +// +build !purego +// +build !noasm + +#include "textflag.h" + +// Registers: +#define digest R1 +#define h R2 // return value +#define p R3 // input pointer +#define n R4 // input length +#define nblocks R5 // n / 32 +#define prime1 R7 +#define prime2 R8 +#define prime3 R9 +#define prime4 R10 +#define prime5 R11 +#define v1 R12 +#define v2 R13 +#define v3 R14 +#define v4 R15 +#define x1 R20 +#define x2 R21 +#define x3 R22 +#define x4 R23 + +#define round(acc, x) \ + MADD prime2, acc, x, acc \ + ROR $64-31, acc \ + MUL prime1, acc + +// round0 performs the operation x = round(0, x). +#define round0(x) \ + MUL prime2, x \ + ROR $64-31, x \ + MUL prime1, x + +#define mergeRound(acc, x) \ + round0(x) \ + EOR x, acc \ + MADD acc, prime4, prime1, acc + +// blockLoop processes as many 32-byte blocks as possible, +// updating v1, v2, v3, and v4. It assumes that n >= 32. +#define blockLoop() \ + LSR $5, n, nblocks \ + PCALIGN $16 \ + loop: \ + LDP.P 16(p), (x1, x2) \ + LDP.P 16(p), (x3, x4) \ + round(v1, x1) \ + round(v2, x2) \ + round(v3, x3) \ + round(v4, x4) \ + SUB $1, nblocks \ + CBNZ nblocks, loop + +// func Sum64(b []byte) uint64 +TEXT ·Sum64(SB), NOSPLIT|NOFRAME, $0-32 + LDP b_base+0(FP), (p, n) + + LDP ·primes+0(SB), (prime1, prime2) + LDP ·primes+16(SB), (prime3, prime4) + MOVD ·primes+32(SB), prime5 + + CMP $32, n + CSEL LT, prime5, ZR, h // if n < 32 { h = prime5 } else { h = 0 } + BLT afterLoop + + ADD prime1, prime2, v1 + MOVD prime2, v2 + MOVD $0, v3 + NEG prime1, v4 + + blockLoop() + + ROR $64-1, v1, x1 + ROR $64-7, v2, x2 + ADD x1, x2 + ROR $64-12, v3, x3 + ROR $64-18, v4, x4 + ADD x3, x4 + ADD x2, x4, h + + mergeRound(h, v1) + mergeRound(h, v2) + mergeRound(h, v3) + mergeRound(h, v4) + +afterLoop: + ADD n, h + + TBZ $4, n, try8 + LDP.P 16(p), (x1, x2) + + round0(x1) + + // NOTE: here and below, sequencing the EOR after the ROR (using a + // rotated register) is worth a small but measurable speedup for small + // inputs. + ROR $64-27, h + EOR x1 @> 64-27, h, h + MADD h, prime4, prime1, h + + round0(x2) + ROR $64-27, h + EOR x2 @> 64-27, h, h + MADD h, prime4, prime1, h + +try8: + TBZ $3, n, try4 + MOVD.P 8(p), x1 + + round0(x1) + ROR $64-27, h + EOR x1 @> 64-27, h, h + MADD h, prime4, prime1, h + +try4: + TBZ $2, n, try2 + MOVWU.P 4(p), x2 + + MUL prime1, x2 + ROR $64-23, h + EOR x2 @> 64-23, h, h + MADD h, prime3, prime2, h + +try2: + TBZ $1, n, try1 + MOVHU.P 2(p), x3 + AND $255, x3, x1 + LSR $8, x3, x2 + + MUL prime5, x1 + ROR $64-11, h + EOR x1 @> 64-11, h, h + MUL prime1, h + + MUL prime5, x2 + ROR $64-11, h + EOR x2 @> 64-11, h, h + MUL prime1, h + +try1: + TBZ $0, n, finalize + MOVBU (p), x4 + + MUL prime5, x4 + ROR $64-11, h + EOR x4 @> 64-11, h, h + MUL prime1, h + +finalize: + EOR h >> 33, h + MUL prime2, h + EOR h >> 29, h + MUL prime3, h + EOR h >> 32, h + + MOVD h, ret+24(FP) + RET + +// func writeBlocks(d *Digest, b []byte) int +TEXT ·writeBlocks(SB), NOSPLIT|NOFRAME, $0-40 + LDP ·primes+0(SB), (prime1, prime2) + + // Load state. Assume v[1-4] are stored contiguously. + MOVD d+0(FP), digest + LDP 0(digest), (v1, v2) + LDP 16(digest), (v3, v4) + + LDP b_base+8(FP), (p, n) + + blockLoop() + + // Store updated state. + STP (v1, v2), 0(digest) + STP (v3, v4), 16(digest) + + BIC $31, n + MOVD n, ret+32(FP) + RET diff --git a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_asm.go b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_asm.go new file mode 100644 index 0000000..d4221ed --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_asm.go @@ -0,0 +1,16 @@ +//go:build (amd64 || arm64) && !appengine && gc && !purego && !noasm +// +build amd64 arm64 +// +build !appengine +// +build gc +// +build !purego +// +build !noasm + +package xxhash + +// Sum64 computes the 64-bit xxHash digest of b. +// +//go:noescape +func Sum64(b []byte) uint64 + +//go:noescape +func writeBlocks(s *Digest, b []byte) int diff --git a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_other.go b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_other.go index 1f52f29..0be16ce 100644 --- a/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_other.go +++ b/vendor/github.com/klauspost/compress/zstd/internal/xxhash/xxhash_other.go @@ -1,5 +1,5 @@ -//go:build !amd64 || appengine || !gc || purego -// +build !amd64 appengine !gc purego +//go:build (!amd64 && !arm64) || appengine || !gc || purego || noasm +// +build !amd64,!arm64 appengine !gc purego noasm package xxhash @@ -15,10 +15,10 @@ func Sum64(b []byte) uint64 { var h uint64 if n >= 32 { - v1 := prime1v + prime2 + v1 := primes[0] + prime2 v2 := prime2 v3 := uint64(0) - v4 := -prime1v + v4 := -primes[0] for len(b) >= 32 { v1 = round(v1, u64(b[0:8:len(b)])) v2 = round(v2, u64(b[8:16:len(b)])) @@ -37,19 +37,18 @@ func Sum64(b []byte) uint64 { h += uint64(n) - i, end := 0, len(b) - for ; i+8 <= end; i += 8 { - k1 := round(0, u64(b[i:i+8:len(b)])) + for ; len(b) >= 8; b = b[8:] { + k1 := round(0, u64(b[:8])) h ^= k1 h = rol27(h)*prime1 + prime4 } - if i+4 <= end { - h ^= uint64(u32(b[i:i+4:len(b)])) * prime1 + if len(b) >= 4 { + h ^= uint64(u32(b[:4])) * prime1 h = rol23(h)*prime2 + prime3 - i += 4 + b = b[4:] } - for ; i < end; i++ { - h ^= uint64(b[i]) * prime5 + for ; len(b) > 0; b = b[1:] { + h ^= uint64(b[0]) * prime5 h = rol11(h) * prime1 } diff --git a/vendor/github.com/klauspost/compress/zstd/matchlen_amd64.go b/vendor/github.com/klauspost/compress/zstd/matchlen_amd64.go new file mode 100644 index 0000000..f41932b --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/matchlen_amd64.go @@ -0,0 +1,16 @@ +//go:build amd64 && !appengine && !noasm && gc +// +build amd64,!appengine,!noasm,gc + +// Copyright 2019+ Klaus Post. All rights reserved. +// License information can be found in the LICENSE file. + +package zstd + +// matchLen returns how many bytes match in a and b +// +// It assumes that: +// +// len(a) <= len(b) and len(a) > 0 +// +//go:noescape +func matchLen(a []byte, b []byte) int diff --git a/vendor/github.com/klauspost/compress/zstd/matchlen_amd64.s b/vendor/github.com/klauspost/compress/zstd/matchlen_amd64.s new file mode 100644 index 0000000..9a7655c --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/matchlen_amd64.s @@ -0,0 +1,68 @@ +// Copied from S2 implementation. + +//go:build !appengine && !noasm && gc && !noasm + +#include "textflag.h" + +// func matchLen(a []byte, b []byte) int +// Requires: BMI +TEXT ·matchLen(SB), NOSPLIT, $0-56 + MOVQ a_base+0(FP), AX + MOVQ b_base+24(FP), CX + MOVQ a_len+8(FP), DX + + // matchLen + XORL SI, SI + CMPL DX, $0x08 + JB matchlen_match4_standalone + +matchlen_loopback_standalone: + MOVQ (AX)(SI*1), BX + XORQ (CX)(SI*1), BX + TESTQ BX, BX + JZ matchlen_loop_standalone + +#ifdef GOAMD64_v3 + TZCNTQ BX, BX +#else + BSFQ BX, BX +#endif + SARQ $0x03, BX + LEAL (SI)(BX*1), SI + JMP gen_match_len_end + +matchlen_loop_standalone: + LEAL -8(DX), DX + LEAL 8(SI), SI + CMPL DX, $0x08 + JAE matchlen_loopback_standalone + +matchlen_match4_standalone: + CMPL DX, $0x04 + JB matchlen_match2_standalone + MOVL (AX)(SI*1), BX + CMPL (CX)(SI*1), BX + JNE matchlen_match2_standalone + LEAL -4(DX), DX + LEAL 4(SI), SI + +matchlen_match2_standalone: + CMPL DX, $0x02 + JB matchlen_match1_standalone + MOVW (AX)(SI*1), BX + CMPW (CX)(SI*1), BX + JNE matchlen_match1_standalone + LEAL -2(DX), DX + LEAL 2(SI), SI + +matchlen_match1_standalone: + CMPL DX, $0x01 + JB gen_match_len_end + MOVB (AX)(SI*1), BL + CMPB (CX)(SI*1), BL + JNE gen_match_len_end + INCL SI + +gen_match_len_end: + MOVQ SI, ret+48(FP) + RET diff --git a/vendor/github.com/klauspost/compress/zstd/matchlen_generic.go b/vendor/github.com/klauspost/compress/zstd/matchlen_generic.go new file mode 100644 index 0000000..57b9c31 --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/matchlen_generic.go @@ -0,0 +1,33 @@ +//go:build !amd64 || appengine || !gc || noasm +// +build !amd64 appengine !gc noasm + +// Copyright 2019+ Klaus Post. All rights reserved. +// License information can be found in the LICENSE file. + +package zstd + +import ( + "encoding/binary" + "math/bits" +) + +// matchLen returns the maximum common prefix length of a and b. +// a must be the shortest of the two. +func matchLen(a, b []byte) (n int) { + for ; len(a) >= 8 && len(b) >= 8; a, b = a[8:], b[8:] { + diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b) + if diff != 0 { + return n + bits.TrailingZeros64(diff)>>3 + } + n += 8 + } + + for i := range a { + if a[i] != b[i] { + break + } + n++ + } + return n + +} diff --git a/vendor/github.com/klauspost/compress/zstd/seqdec.go b/vendor/github.com/klauspost/compress/zstd/seqdec.go index 1dd39e6..9405fcf 100644 --- a/vendor/github.com/klauspost/compress/zstd/seqdec.go +++ b/vendor/github.com/klauspost/compress/zstd/seqdec.go @@ -20,6 +20,10 @@ type seq struct { llCode, mlCode, ofCode uint8 } +type seqVals struct { + ll, ml, mo int +} + func (s seq) String() string { if s.offset <= 3 { if s.offset == 0 { @@ -61,16 +65,19 @@ type sequenceDecs struct { offsets sequenceDec matchLengths sequenceDec prevOffset [3]int - hist []byte dict []byte literals []byte out []byte + nSeqs int + br *bitReader + seqSize int windowSize int maxBits uint8 + maxSyncLen uint64 } // initialize all 3 decoders from the stream input. -func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out []byte) error { +func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) error { if err := s.litLengths.init(br); err != nil { return errors.New("litLengths:" + err.Error()) } @@ -80,8 +87,7 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out [] if err := s.matchLengths.init(br); err != nil { return errors.New("matchLengths:" + err.Error()) } - s.literals = literals - s.hist = hist.b + s.br = br s.prevOffset = hist.recentOffsets s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits s.windowSize = hist.windowSize @@ -93,16 +99,149 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out [] return nil } +func (s *sequenceDecs) freeDecoders() { + if f := s.litLengths.fse; f != nil && !f.preDefined { + fseDecoderPool.Put(f) + s.litLengths.fse = nil + } + if f := s.offsets.fse; f != nil && !f.preDefined { + fseDecoderPool.Put(f) + s.offsets.fse = nil + } + if f := s.matchLengths.fse; f != nil && !f.preDefined { + fseDecoderPool.Put(f) + s.matchLengths.fse = nil + } +} + +// execute will execute the decoded sequence with the provided history. +// The sequence must be evaluated before being sent. +func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error { + if len(s.dict) == 0 { + return s.executeSimple(seqs, hist) + } + + // Ensure we have enough output size... + if len(s.out)+s.seqSize > cap(s.out) { + addBytes := s.seqSize + len(s.out) + s.out = append(s.out, make([]byte, addBytes)...) + s.out = s.out[:len(s.out)-addBytes] + } + + if debugDecoder { + printf("Execute %d seqs with hist %d, dict %d, literals: %d into %d bytes\n", len(seqs), len(hist), len(s.dict), len(s.literals), s.seqSize) + } + + var t = len(s.out) + out := s.out[:t+s.seqSize] + + for _, seq := range seqs { + // Add literals + copy(out[t:], s.literals[:seq.ll]) + t += seq.ll + s.literals = s.literals[seq.ll:] + + // Copy from dictionary... + if seq.mo > t+len(hist) || seq.mo > s.windowSize { + if len(s.dict) == 0 { + return fmt.Errorf("match offset (%d) bigger than current history (%d)", seq.mo, t+len(hist)) + } + + // we may be in dictionary. + dictO := len(s.dict) - (seq.mo - (t + len(hist))) + if dictO < 0 || dictO >= len(s.dict) { + return fmt.Errorf("match offset (%d) bigger than current history+dict (%d)", seq.mo, t+len(hist)+len(s.dict)) + } + end := dictO + seq.ml + if end > len(s.dict) { + n := len(s.dict) - dictO + copy(out[t:], s.dict[dictO:]) + t += n + seq.ml -= n + } else { + copy(out[t:], s.dict[dictO:end]) + t += end - dictO + continue + } + } + + // Copy from history. + if v := seq.mo - t; v > 0 { + // v is the start position in history from end. + start := len(hist) - v + if seq.ml > v { + // Some goes into current block. + // Copy remainder of history + copy(out[t:], hist[start:]) + t += v + seq.ml -= v + } else { + copy(out[t:], hist[start:start+seq.ml]) + t += seq.ml + continue + } + } + // We must be in current buffer now + if seq.ml > 0 { + start := t - seq.mo + if seq.ml <= t-start { + // No overlap + copy(out[t:], out[start:start+seq.ml]) + t += seq.ml + continue + } else { + // Overlapping copy + // Extend destination slice and copy one byte at the time. + src := out[start : start+seq.ml] + dst := out[t:] + dst = dst[:len(src)] + t += len(src) + // Destination is the space we just added. + for i := range src { + dst[i] = src[i] + } + } + } + } + + // Add final literals + copy(out[t:], s.literals) + if debugDecoder { + t += len(s.literals) + if t != len(out) { + panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize)) + } + } + s.out = out + + return nil +} + // decode sequences from the stream with the provided history. -func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { +func (s *sequenceDecs) decodeSync(hist []byte) error { + supported, err := s.decodeSyncSimple(hist) + if supported { + return err + } + + br := s.br + seqs := s.nSeqs startSize := len(s.out) // Grab full sizes tables, to avoid bounds checks. llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state + out := s.out + maxBlockSize := maxCompressedBlockSize + if s.windowSize < maxBlockSize { + maxBlockSize = s.windowSize + } + if debugDecoder { + println("decodeSync: decoding", seqs, "sequences", br.remain(), "bits remain on stream") + } for i := seqs - 1; i >= 0; i-- { if br.overread() { - printf("reading sequence %d, exceeded available data\n", seqs-i) + printf("reading sequence %d, exceeded available data. Overread by %d\n", seqs-i, -br.remain()) return io.ErrUnexpectedEOF } var ll, mo, ml int @@ -151,7 +290,7 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { if temp == 0 { // 0 is not valid; input is corrupted; force offset to 1 - println("temp was 0") + println("WARNING: temp was 0") temp = 1 } @@ -176,51 +315,49 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { if ll > len(s.literals) { return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, len(s.literals)) } - size := ll + ml + len(s.out) + size := ll + ml + len(out) if size-startSize > maxBlockSize { - return fmt.Errorf("output (%d) bigger than max block size", size) + return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize) } - if size > cap(s.out) { + if size > cap(out) { // Not enough size, which can happen under high volume block streaming conditions // but could be if destination slice is too small for sync operations. // over-allocating here can create a large amount of GC pressure so we try to keep // it as contained as possible - used := len(s.out) - startSize + used := len(out) - startSize addBytes := 256 + ll + ml + used>>2 // Clamp to max block size. if used+addBytes > maxBlockSize { addBytes = maxBlockSize - used } - s.out = append(s.out, make([]byte, addBytes)...) - s.out = s.out[:len(s.out)-addBytes] + out = append(out, make([]byte, addBytes)...) + out = out[:len(out)-addBytes] } if ml > maxMatchLen { return fmt.Errorf("match len (%d) bigger than max allowed length", ml) } // Add literals - s.out = append(s.out, s.literals[:ll]...) + out = append(out, s.literals[:ll]...) s.literals = s.literals[ll:] - out := s.out if mo == 0 && ml > 0 { return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml) } - if mo > len(s.out)+len(hist) || mo > s.windowSize { + if mo > len(out)+len(hist) || mo > s.windowSize { if len(s.dict) == 0 { - return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist)) + return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize) } // we may be in dictionary. - dictO := len(s.dict) - (mo - (len(s.out) + len(hist))) + dictO := len(s.dict) - (mo - (len(out) + len(hist))) if dictO < 0 || dictO >= len(s.dict) { - return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist)) + return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(out)+len(hist)-startSize) } end := dictO + ml if end > len(s.dict) { out = append(out, s.dict[dictO:]...) - mo -= len(s.dict) - dictO ml -= len(s.dict) - dictO } else { out = append(out, s.dict[dictO:end]...) @@ -231,26 +368,25 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { // Copy from history. // TODO: Blocks without history could be made to ignore this completely. - if v := mo - len(s.out); v > 0 { + if v := mo - len(out); v > 0 { // v is the start position in history from end. - start := len(s.hist) - v + start := len(hist) - v if ml > v { // Some goes into current block. // Copy remainder of history - out = append(out, s.hist[start:]...) - mo -= v + out = append(out, hist[start:]...) ml -= v } else { - out = append(out, s.hist[start:start+ml]...) + out = append(out, hist[start:start+ml]...) ml = 0 } } // We must be in current buffer now if ml > 0 { - start := len(s.out) - mo - if ml <= len(s.out)-start { + start := len(out) - mo + if ml <= len(out)-start { // No overlap - out = append(out, s.out[start:start+ml]...) + out = append(out, out[start:start+ml]...) } else { // Overlapping copy // Extend destination slice and copy one byte at the time. @@ -264,7 +400,6 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { } } } - s.out = out if i == 0 { // This is the last sequence, so we shouldn't update state. break @@ -278,7 +413,8 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { mlState = mlTable[mlState.newState()&maxTableMask] ofState = ofTable[ofState.newState()&maxTableMask] } else { - bits := br.getBitsFast(nBits) + bits := br.get32BitsFast(nBits) + lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31)) llState = llTable[(llState.newState()+lowBits)&maxTableMask] @@ -291,19 +427,13 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { } } - // Add final literals - s.out = append(s.out, s.literals...) - return nil -} + if size := len(s.literals) + len(out) - startSize; size > maxBlockSize { + return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize) + } -// update states, at least 27 bits must be available. -func (s *sequenceDecs) update(br *bitReader) { - // Max 8 bits - s.litLengths.state.next(br) - // Max 9 bits - s.matchLengths.state.next(br) - // Max 8 bits - s.offsets.state.next(br) + // Add final literals + s.out = append(out, s.literals...) + return br.close() } var bitMask [16]uint16 @@ -314,87 +444,6 @@ func init() { } } -// update states, at least 27 bits must be available. -func (s *sequenceDecs) updateAlt(br *bitReader) { - // Update all 3 states at once. Approx 20% faster. - a, b, c := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state - - nBits := a.nbBits() + b.nbBits() + c.nbBits() - if nBits == 0 { - s.litLengths.state.state = s.litLengths.state.dt[a.newState()] - s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()] - s.offsets.state.state = s.offsets.state.dt[c.newState()] - return - } - bits := br.getBitsFast(nBits) - lowBits := uint16(bits >> ((c.nbBits() + b.nbBits()) & 31)) - s.litLengths.state.state = s.litLengths.state.dt[a.newState()+lowBits] - - lowBits = uint16(bits >> (c.nbBits() & 31)) - lowBits &= bitMask[b.nbBits()&15] - s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()+lowBits] - - lowBits = uint16(bits) & bitMask[c.nbBits()&15] - s.offsets.state.state = s.offsets.state.dt[c.newState()+lowBits] -} - -// nextFast will return new states when there are at least 4 unused bytes left on the stream when done. -func (s *sequenceDecs) nextFast(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) { - // Final will not read from stream. - ll, llB := llState.final() - ml, mlB := mlState.final() - mo, moB := ofState.final() - - // extra bits are stored in reverse order. - br.fillFast() - mo += br.getBits(moB) - if s.maxBits > 32 { - br.fillFast() - } - ml += br.getBits(mlB) - ll += br.getBits(llB) - - if moB > 1 { - s.prevOffset[2] = s.prevOffset[1] - s.prevOffset[1] = s.prevOffset[0] - s.prevOffset[0] = mo - return - } - // mo = s.adjustOffset(mo, ll, moB) - // Inlined for rather big speedup - if ll == 0 { - // There is an exception though, when current sequence's literals_length = 0. - // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2, - // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte. - mo++ - } - - if mo == 0 { - mo = s.prevOffset[0] - return - } - var temp int - if mo == 3 { - temp = s.prevOffset[0] - 1 - } else { - temp = s.prevOffset[mo] - } - - if temp == 0 { - // 0 is not valid; input is corrupted; force offset to 1 - println("temp was 0") - temp = 1 - } - - if mo != 1 { - s.prevOffset[2] = s.prevOffset[1] - } - s.prevOffset[1] = s.prevOffset[0] - s.prevOffset[0] = temp - mo = temp - return -} - func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) { // Final will not read from stream. ll, llB := llState.final() @@ -457,36 +506,3 @@ func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int { s.prevOffset[0] = temp return temp } - -// mergeHistory will merge history. -func (s *sequenceDecs) mergeHistory(hist *sequenceDecs) (*sequenceDecs, error) { - for i := uint(0); i < 3; i++ { - var sNew, sHist *sequenceDec - switch i { - default: - // same as "case 0": - sNew = &s.litLengths - sHist = &hist.litLengths - case 1: - sNew = &s.offsets - sHist = &hist.offsets - case 2: - sNew = &s.matchLengths - sHist = &hist.matchLengths - } - if sNew.repeat { - if sHist.fse == nil { - return nil, fmt.Errorf("sequence stream %d, repeat requested, but no history", i) - } - continue - } - if sNew.fse == nil { - return nil, fmt.Errorf("sequence stream %d, no fse found", i) - } - if sHist.fse != nil && !sHist.fse.preDefined { - fseDecoderPool.Put(sHist.fse) - } - sHist.fse = sNew.fse - } - return hist, nil -} diff --git a/vendor/github.com/klauspost/compress/zstd/seqdec_amd64.go b/vendor/github.com/klauspost/compress/zstd/seqdec_amd64.go new file mode 100644 index 0000000..8adabd8 --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/seqdec_amd64.go @@ -0,0 +1,394 @@ +//go:build amd64 && !appengine && !noasm && gc +// +build amd64,!appengine,!noasm,gc + +package zstd + +import ( + "fmt" + "io" + + "github.com/klauspost/compress/internal/cpuinfo" +) + +type decodeSyncAsmContext struct { + llTable []decSymbol + mlTable []decSymbol + ofTable []decSymbol + llState uint64 + mlState uint64 + ofState uint64 + iteration int + litRemain int + out []byte + outPosition int + literals []byte + litPosition int + history []byte + windowSize int + ll int // set on error (not for all errors, please refer to _generate/gen.go) + ml int // set on error (not for all errors, please refer to _generate/gen.go) + mo int // set on error (not for all errors, please refer to _generate/gen.go) +} + +// sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm. +// +// Please refer to seqdec_generic.go for the reference implementation. +// +//go:noescape +func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int + +// sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions. +// +//go:noescape +func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int + +// sequenceDecs_decodeSync_safe_amd64 does the same as above, but does not write more than output buffer. +// +//go:noescape +func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int + +// sequenceDecs_decodeSync_safe_bmi2 does the same as above, but does not write more than output buffer. +// +//go:noescape +func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int + +// decode sequences from the stream with the provided history but without a dictionary. +func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) { + if len(s.dict) > 0 { + return false, nil + } + if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSize { + return false, nil + } + + // FIXME: Using unsafe memory copies leads to rare, random crashes + // with fuzz testing. It is therefore disabled for now. + const useSafe = true + /* + useSafe := false + if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSizeAlloc { + useSafe = true + } + if s.maxSyncLen > 0 && cap(s.out)-len(s.out)-compressedBlockOverAlloc < int(s.maxSyncLen) { + useSafe = true + } + if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc { + useSafe = true + } + */ + + br := s.br + + maxBlockSize := maxCompressedBlockSize + if s.windowSize < maxBlockSize { + maxBlockSize = s.windowSize + } + + ctx := decodeSyncAsmContext{ + llTable: s.litLengths.fse.dt[:maxTablesize], + mlTable: s.matchLengths.fse.dt[:maxTablesize], + ofTable: s.offsets.fse.dt[:maxTablesize], + llState: uint64(s.litLengths.state.state), + mlState: uint64(s.matchLengths.state.state), + ofState: uint64(s.offsets.state.state), + iteration: s.nSeqs - 1, + litRemain: len(s.literals), + out: s.out, + outPosition: len(s.out), + literals: s.literals, + windowSize: s.windowSize, + history: hist, + } + + s.seqSize = 0 + startSize := len(s.out) + + var errCode int + if cpuinfo.HasBMI2() { + if useSafe { + errCode = sequenceDecs_decodeSync_safe_bmi2(s, br, &ctx) + } else { + errCode = sequenceDecs_decodeSync_bmi2(s, br, &ctx) + } + } else { + if useSafe { + errCode = sequenceDecs_decodeSync_safe_amd64(s, br, &ctx) + } else { + errCode = sequenceDecs_decodeSync_amd64(s, br, &ctx) + } + } + switch errCode { + case noError: + break + + case errorMatchLenOfsMismatch: + return true, fmt.Errorf("zero matchoff and matchlen (%d) > 0", ctx.ml) + + case errorMatchLenTooBig: + return true, fmt.Errorf("match len (%d) bigger than max allowed length", ctx.ml) + + case errorMatchOffTooBig: + return true, fmt.Errorf("match offset (%d) bigger than current history (%d)", + ctx.mo, ctx.outPosition+len(hist)-startSize) + + case errorNotEnoughLiterals: + return true, fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", + ctx.ll, ctx.litRemain+ctx.ll) + + case errorOverread: + return true, io.ErrUnexpectedEOF + + case errorNotEnoughSpace: + size := ctx.outPosition + ctx.ll + ctx.ml + if debugDecoder { + println("msl:", s.maxSyncLen, "cap", cap(s.out), "bef:", startSize, "sz:", size-startSize, "mbs:", maxBlockSize, "outsz:", cap(s.out)-startSize) + } + return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize) + + default: + return true, fmt.Errorf("sequenceDecs_decode returned erronous code %d", errCode) + } + + s.seqSize += ctx.litRemain + if s.seqSize > maxBlockSize { + return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize) + } + err := br.close() + if err != nil { + printf("Closing sequences: %v, %+v\n", err, *br) + return true, err + } + + s.literals = s.literals[ctx.litPosition:] + t := ctx.outPosition + s.out = s.out[:t] + + // Add final literals + s.out = append(s.out, s.literals...) + if debugDecoder { + t += len(s.literals) + if t != len(s.out) { + panic(fmt.Errorf("length mismatch, want %d, got %d", len(s.out), t)) + } + } + + return true, nil +} + +// -------------------------------------------------------------------------------- + +type decodeAsmContext struct { + llTable []decSymbol + mlTable []decSymbol + ofTable []decSymbol + llState uint64 + mlState uint64 + ofState uint64 + iteration int + seqs []seqVals + litRemain int +} + +const noError = 0 + +// error reported when mo == 0 && ml > 0 +const errorMatchLenOfsMismatch = 1 + +// error reported when ml > maxMatchLen +const errorMatchLenTooBig = 2 + +// error reported when mo > available history or mo > s.windowSize +const errorMatchOffTooBig = 3 + +// error reported when the sum of literal lengths exeeceds the literal buffer size +const errorNotEnoughLiterals = 4 + +// error reported when capacity of `out` is too small +const errorNotEnoughSpace = 5 + +// error reported when bits are overread. +const errorOverread = 6 + +// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm. +// +// Please refer to seqdec_generic.go for the reference implementation. +// +//go:noescape +func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int + +// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm. +// +// Please refer to seqdec_generic.go for the reference implementation. +// +//go:noescape +func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int + +// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions. +// +//go:noescape +func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int + +// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions. +// +//go:noescape +func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int + +// decode sequences from the stream without the provided history. +func (s *sequenceDecs) decode(seqs []seqVals) error { + br := s.br + + maxBlockSize := maxCompressedBlockSize + if s.windowSize < maxBlockSize { + maxBlockSize = s.windowSize + } + + ctx := decodeAsmContext{ + llTable: s.litLengths.fse.dt[:maxTablesize], + mlTable: s.matchLengths.fse.dt[:maxTablesize], + ofTable: s.offsets.fse.dt[:maxTablesize], + llState: uint64(s.litLengths.state.state), + mlState: uint64(s.matchLengths.state.state), + ofState: uint64(s.offsets.state.state), + seqs: seqs, + iteration: len(seqs) - 1, + litRemain: len(s.literals), + } + + if debugDecoder { + println("decode: decoding", len(seqs), "sequences", br.remain(), "bits remain on stream") + } + + s.seqSize = 0 + lte56bits := s.maxBits+s.offsets.fse.actualTableLog+s.matchLengths.fse.actualTableLog+s.litLengths.fse.actualTableLog <= 56 + var errCode int + if cpuinfo.HasBMI2() { + if lte56bits { + errCode = sequenceDecs_decode_56_bmi2(s, br, &ctx) + } else { + errCode = sequenceDecs_decode_bmi2(s, br, &ctx) + } + } else { + if lte56bits { + errCode = sequenceDecs_decode_56_amd64(s, br, &ctx) + } else { + errCode = sequenceDecs_decode_amd64(s, br, &ctx) + } + } + if errCode != 0 { + i := len(seqs) - ctx.iteration - 1 + switch errCode { + case errorMatchLenOfsMismatch: + ml := ctx.seqs[i].ml + return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml) + + case errorMatchLenTooBig: + ml := ctx.seqs[i].ml + return fmt.Errorf("match len (%d) bigger than max allowed length", ml) + + case errorNotEnoughLiterals: + ll := ctx.seqs[i].ll + return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, ctx.litRemain+ll) + case errorOverread: + return io.ErrUnexpectedEOF + } + + return fmt.Errorf("sequenceDecs_decode_amd64 returned erronous code %d", errCode) + } + + if ctx.litRemain < 0 { + return fmt.Errorf("literal count is too big: total available %d, total requested %d", + len(s.literals), len(s.literals)-ctx.litRemain) + } + + s.seqSize += ctx.litRemain + if s.seqSize > maxBlockSize { + return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize) + } + if debugDecoder { + println("decode: ", br.remain(), "bits remain on stream. code:", errCode) + } + err := br.close() + if err != nil { + printf("Closing sequences: %v, %+v\n", err, *br) + } + return err +} + +// -------------------------------------------------------------------------------- + +type executeAsmContext struct { + seqs []seqVals + seqIndex int + out []byte + history []byte + literals []byte + outPosition int + litPosition int + windowSize int +} + +// sequenceDecs_executeSimple_amd64 implements the main loop of sequenceDecs.executeSimple in x86 asm. +// +// Returns false if a match offset is too big. +// +// Please refer to seqdec_generic.go for the reference implementation. +// +//go:noescape +func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool + +// Same as above, but with safe memcopies +// +//go:noescape +func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool + +// executeSimple handles cases when dictionary is not used. +func (s *sequenceDecs) executeSimple(seqs []seqVals, hist []byte) error { + // Ensure we have enough output size... + if len(s.out)+s.seqSize+compressedBlockOverAlloc > cap(s.out) { + addBytes := s.seqSize + len(s.out) + compressedBlockOverAlloc + s.out = append(s.out, make([]byte, addBytes)...) + s.out = s.out[:len(s.out)-addBytes] + } + + if debugDecoder { + printf("Execute %d seqs with literals: %d into %d bytes\n", len(seqs), len(s.literals), s.seqSize) + } + + var t = len(s.out) + out := s.out[:t+s.seqSize] + + ctx := executeAsmContext{ + seqs: seqs, + seqIndex: 0, + out: out, + history: hist, + outPosition: t, + litPosition: 0, + literals: s.literals, + windowSize: s.windowSize, + } + var ok bool + if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc { + ok = sequenceDecs_executeSimple_safe_amd64(&ctx) + } else { + ok = sequenceDecs_executeSimple_amd64(&ctx) + } + if !ok { + return fmt.Errorf("match offset (%d) bigger than current history (%d)", + seqs[ctx.seqIndex].mo, ctx.outPosition+len(hist)) + } + s.literals = s.literals[ctx.litPosition:] + t = ctx.outPosition + + // Add final literals + copy(out[t:], s.literals) + if debugDecoder { + t += len(s.literals) + if t != len(out) { + panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize)) + } + } + s.out = out + + return nil +} diff --git a/vendor/github.com/klauspost/compress/zstd/seqdec_amd64.s b/vendor/github.com/klauspost/compress/zstd/seqdec_amd64.s new file mode 100644 index 0000000..b6f4ba6 --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/seqdec_amd64.s @@ -0,0 +1,4175 @@ +// Code generated by command: go run gen.go -out ../seqdec_amd64.s -pkg=zstd. DO NOT EDIT. + +//go:build !appengine && !noasm && gc && !noasm + +// func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int +// Requires: CMOV +TEXT ·sequenceDecs_decode_amd64(SB), $8-32 + MOVQ br+8(FP), AX + MOVQ 32(AX), DX + MOVBQZX 40(AX), BX + MOVQ 24(AX), SI + MOVQ (AX), AX + ADDQ SI, AX + MOVQ AX, (SP) + MOVQ ctx+16(FP), AX + MOVQ 72(AX), DI + MOVQ 80(AX), R8 + MOVQ 88(AX), R9 + MOVQ 104(AX), R10 + MOVQ s+0(FP), AX + MOVQ 144(AX), R11 + MOVQ 152(AX), R12 + MOVQ 160(AX), R13 + +sequenceDecs_decode_amd64_main_loop: + MOVQ (SP), R14 + + // Fill bitreader to have enough for the offset and match length. + CMPQ SI, $0x08 + JL sequenceDecs_decode_amd64_fill_byte_by_byte + MOVQ BX, AX + SHRQ $0x03, AX + SUBQ AX, R14 + MOVQ (R14), DX + SUBQ AX, SI + ANDQ $0x07, BX + JMP sequenceDecs_decode_amd64_fill_end + +sequenceDecs_decode_amd64_fill_byte_by_byte: + CMPQ SI, $0x00 + JLE sequenceDecs_decode_amd64_fill_check_overread + CMPQ BX, $0x07 + JLE sequenceDecs_decode_amd64_fill_end + SHLQ $0x08, DX + SUBQ $0x01, R14 + SUBQ $0x01, SI + SUBQ $0x08, BX + MOVBQZX (R14), AX + ORQ AX, DX + JMP sequenceDecs_decode_amd64_fill_byte_by_byte + +sequenceDecs_decode_amd64_fill_check_overread: + CMPQ BX, $0x40 + JA error_overread + +sequenceDecs_decode_amd64_fill_end: + // Update offset + MOVQ R9, AX + MOVQ BX, CX + MOVQ DX, R15 + SHLQ CL, R15 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decode_amd64_of_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decode_amd64_of_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decode_amd64_of_update_zero + NEGQ CX + SHRQ CL, R15 + ADDQ R15, AX + +sequenceDecs_decode_amd64_of_update_zero: + MOVQ AX, 16(R10) + + // Update match length + MOVQ R8, AX + MOVQ BX, CX + MOVQ DX, R15 + SHLQ CL, R15 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decode_amd64_ml_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decode_amd64_ml_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decode_amd64_ml_update_zero + NEGQ CX + SHRQ CL, R15 + ADDQ R15, AX + +sequenceDecs_decode_amd64_ml_update_zero: + MOVQ AX, 8(R10) + + // Fill bitreader to have enough for the remaining + CMPQ SI, $0x08 + JL sequenceDecs_decode_amd64_fill_2_byte_by_byte + MOVQ BX, AX + SHRQ $0x03, AX + SUBQ AX, R14 + MOVQ (R14), DX + SUBQ AX, SI + ANDQ $0x07, BX + JMP sequenceDecs_decode_amd64_fill_2_end + +sequenceDecs_decode_amd64_fill_2_byte_by_byte: + CMPQ SI, $0x00 + JLE sequenceDecs_decode_amd64_fill_2_check_overread + CMPQ BX, $0x07 + JLE sequenceDecs_decode_amd64_fill_2_end + SHLQ $0x08, DX + SUBQ $0x01, R14 + SUBQ $0x01, SI + SUBQ $0x08, BX + MOVBQZX (R14), AX + ORQ AX, DX + JMP sequenceDecs_decode_amd64_fill_2_byte_by_byte + +sequenceDecs_decode_amd64_fill_2_check_overread: + CMPQ BX, $0x40 + JA error_overread + +sequenceDecs_decode_amd64_fill_2_end: + // Update literal length + MOVQ DI, AX + MOVQ BX, CX + MOVQ DX, R15 + SHLQ CL, R15 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decode_amd64_ll_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decode_amd64_ll_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decode_amd64_ll_update_zero + NEGQ CX + SHRQ CL, R15 + ADDQ R15, AX + +sequenceDecs_decode_amd64_ll_update_zero: + MOVQ AX, (R10) + + // Fill bitreader for state updates + MOVQ R14, (SP) + MOVQ R9, AX + SHRQ $0x08, AX + MOVBQZX AL, AX + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_amd64_skip_update + + // Update Literal Length State + MOVBQZX DI, R14 + SHRQ $0x10, DI + MOVWQZX DI, DI + LEAQ (BX)(R14*1), CX + MOVQ DX, R15 + MOVQ CX, BX + ROLQ CL, R15 + MOVL $0x00000001, BP + MOVB R14, CL + SHLL CL, BP + DECL BP + ANDQ BP, R15 + ADDQ R15, DI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(DI*8), DI + + // Update Match Length State + MOVBQZX R8, R14 + SHRQ $0x10, R8 + MOVWQZX R8, R8 + LEAQ (BX)(R14*1), CX + MOVQ DX, R15 + MOVQ CX, BX + ROLQ CL, R15 + MOVL $0x00000001, BP + MOVB R14, CL + SHLL CL, BP + DECL BP + ANDQ BP, R15 + ADDQ R15, R8 + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Offset State + MOVBQZX R9, R14 + SHRQ $0x10, R9 + MOVWQZX R9, R9 + LEAQ (BX)(R14*1), CX + MOVQ DX, R15 + MOVQ CX, BX + ROLQ CL, R15 + MOVL $0x00000001, BP + MOVB R14, CL + SHLL CL, BP + DECL BP + ANDQ BP, R15 + ADDQ R15, R9 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R9*8), R9 + +sequenceDecs_decode_amd64_skip_update: + // Adjust offset + MOVQ 16(R10), CX + CMPQ AX, $0x01 + JBE sequenceDecs_decode_amd64_adjust_offsetB_1_or_0 + MOVQ R12, R13 + MOVQ R11, R12 + MOVQ CX, R11 + JMP sequenceDecs_decode_amd64_after_adjust + +sequenceDecs_decode_amd64_adjust_offsetB_1_or_0: + CMPQ (R10), $0x00000000 + JNE sequenceDecs_decode_amd64_adjust_offset_maybezero + INCQ CX + JMP sequenceDecs_decode_amd64_adjust_offset_nonzero + +sequenceDecs_decode_amd64_adjust_offset_maybezero: + TESTQ CX, CX + JNZ sequenceDecs_decode_amd64_adjust_offset_nonzero + MOVQ R11, CX + JMP sequenceDecs_decode_amd64_after_adjust + +sequenceDecs_decode_amd64_adjust_offset_nonzero: + CMPQ CX, $0x01 + JB sequenceDecs_decode_amd64_adjust_zero + JEQ sequenceDecs_decode_amd64_adjust_one + CMPQ CX, $0x02 + JA sequenceDecs_decode_amd64_adjust_three + JMP sequenceDecs_decode_amd64_adjust_two + +sequenceDecs_decode_amd64_adjust_zero: + MOVQ R11, AX + JMP sequenceDecs_decode_amd64_adjust_test_temp_valid + +sequenceDecs_decode_amd64_adjust_one: + MOVQ R12, AX + JMP sequenceDecs_decode_amd64_adjust_test_temp_valid + +sequenceDecs_decode_amd64_adjust_two: + MOVQ R13, AX + JMP sequenceDecs_decode_amd64_adjust_test_temp_valid + +sequenceDecs_decode_amd64_adjust_three: + LEAQ -1(R11), AX + +sequenceDecs_decode_amd64_adjust_test_temp_valid: + TESTQ AX, AX + JNZ sequenceDecs_decode_amd64_adjust_temp_valid + MOVQ $0x00000001, AX + +sequenceDecs_decode_amd64_adjust_temp_valid: + CMPQ CX, $0x01 + CMOVQNE R12, R13 + MOVQ R11, R12 + MOVQ AX, R11 + MOVQ AX, CX + +sequenceDecs_decode_amd64_after_adjust: + MOVQ CX, 16(R10) + + // Check values + MOVQ 8(R10), AX + MOVQ (R10), R14 + LEAQ (AX)(R14*1), R15 + MOVQ s+0(FP), BP + ADDQ R15, 256(BP) + MOVQ ctx+16(FP), R15 + SUBQ R14, 128(R15) + JS error_not_enough_literals + CMPQ AX, $0x00020002 + JA sequenceDecs_decode_amd64_error_match_len_too_big + TESTQ CX, CX + JNZ sequenceDecs_decode_amd64_match_len_ofs_ok + TESTQ AX, AX + JNZ sequenceDecs_decode_amd64_error_match_len_ofs_mismatch + +sequenceDecs_decode_amd64_match_len_ofs_ok: + ADDQ $0x18, R10 + MOVQ ctx+16(FP), AX + DECQ 96(AX) + JNS sequenceDecs_decode_amd64_main_loop + MOVQ s+0(FP), AX + MOVQ R11, 144(AX) + MOVQ R12, 152(AX) + MOVQ R13, 160(AX) + MOVQ br+8(FP), AX + MOVQ DX, 32(AX) + MOVB BL, 40(AX) + MOVQ SI, 24(AX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decode_amd64_error_match_len_ofs_mismatch: + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decode_amd64_error_match_len_too_big: + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + +// func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int +// Requires: CMOV +TEXT ·sequenceDecs_decode_56_amd64(SB), $8-32 + MOVQ br+8(FP), AX + MOVQ 32(AX), DX + MOVBQZX 40(AX), BX + MOVQ 24(AX), SI + MOVQ (AX), AX + ADDQ SI, AX + MOVQ AX, (SP) + MOVQ ctx+16(FP), AX + MOVQ 72(AX), DI + MOVQ 80(AX), R8 + MOVQ 88(AX), R9 + MOVQ 104(AX), R10 + MOVQ s+0(FP), AX + MOVQ 144(AX), R11 + MOVQ 152(AX), R12 + MOVQ 160(AX), R13 + +sequenceDecs_decode_56_amd64_main_loop: + MOVQ (SP), R14 + + // Fill bitreader to have enough for the offset and match length. + CMPQ SI, $0x08 + JL sequenceDecs_decode_56_amd64_fill_byte_by_byte + MOVQ BX, AX + SHRQ $0x03, AX + SUBQ AX, R14 + MOVQ (R14), DX + SUBQ AX, SI + ANDQ $0x07, BX + JMP sequenceDecs_decode_56_amd64_fill_end + +sequenceDecs_decode_56_amd64_fill_byte_by_byte: + CMPQ SI, $0x00 + JLE sequenceDecs_decode_56_amd64_fill_check_overread + CMPQ BX, $0x07 + JLE sequenceDecs_decode_56_amd64_fill_end + SHLQ $0x08, DX + SUBQ $0x01, R14 + SUBQ $0x01, SI + SUBQ $0x08, BX + MOVBQZX (R14), AX + ORQ AX, DX + JMP sequenceDecs_decode_56_amd64_fill_byte_by_byte + +sequenceDecs_decode_56_amd64_fill_check_overread: + CMPQ BX, $0x40 + JA error_overread + +sequenceDecs_decode_56_amd64_fill_end: + // Update offset + MOVQ R9, AX + MOVQ BX, CX + MOVQ DX, R15 + SHLQ CL, R15 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decode_56_amd64_of_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decode_56_amd64_of_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decode_56_amd64_of_update_zero + NEGQ CX + SHRQ CL, R15 + ADDQ R15, AX + +sequenceDecs_decode_56_amd64_of_update_zero: + MOVQ AX, 16(R10) + + // Update match length + MOVQ R8, AX + MOVQ BX, CX + MOVQ DX, R15 + SHLQ CL, R15 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decode_56_amd64_ml_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decode_56_amd64_ml_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decode_56_amd64_ml_update_zero + NEGQ CX + SHRQ CL, R15 + ADDQ R15, AX + +sequenceDecs_decode_56_amd64_ml_update_zero: + MOVQ AX, 8(R10) + + // Update literal length + MOVQ DI, AX + MOVQ BX, CX + MOVQ DX, R15 + SHLQ CL, R15 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decode_56_amd64_ll_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decode_56_amd64_ll_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decode_56_amd64_ll_update_zero + NEGQ CX + SHRQ CL, R15 + ADDQ R15, AX + +sequenceDecs_decode_56_amd64_ll_update_zero: + MOVQ AX, (R10) + + // Fill bitreader for state updates + MOVQ R14, (SP) + MOVQ R9, AX + SHRQ $0x08, AX + MOVBQZX AL, AX + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_56_amd64_skip_update + + // Update Literal Length State + MOVBQZX DI, R14 + SHRQ $0x10, DI + MOVWQZX DI, DI + LEAQ (BX)(R14*1), CX + MOVQ DX, R15 + MOVQ CX, BX + ROLQ CL, R15 + MOVL $0x00000001, BP + MOVB R14, CL + SHLL CL, BP + DECL BP + ANDQ BP, R15 + ADDQ R15, DI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(DI*8), DI + + // Update Match Length State + MOVBQZX R8, R14 + SHRQ $0x10, R8 + MOVWQZX R8, R8 + LEAQ (BX)(R14*1), CX + MOVQ DX, R15 + MOVQ CX, BX + ROLQ CL, R15 + MOVL $0x00000001, BP + MOVB R14, CL + SHLL CL, BP + DECL BP + ANDQ BP, R15 + ADDQ R15, R8 + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Offset State + MOVBQZX R9, R14 + SHRQ $0x10, R9 + MOVWQZX R9, R9 + LEAQ (BX)(R14*1), CX + MOVQ DX, R15 + MOVQ CX, BX + ROLQ CL, R15 + MOVL $0x00000001, BP + MOVB R14, CL + SHLL CL, BP + DECL BP + ANDQ BP, R15 + ADDQ R15, R9 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R9*8), R9 + +sequenceDecs_decode_56_amd64_skip_update: + // Adjust offset + MOVQ 16(R10), CX + CMPQ AX, $0x01 + JBE sequenceDecs_decode_56_amd64_adjust_offsetB_1_or_0 + MOVQ R12, R13 + MOVQ R11, R12 + MOVQ CX, R11 + JMP sequenceDecs_decode_56_amd64_after_adjust + +sequenceDecs_decode_56_amd64_adjust_offsetB_1_or_0: + CMPQ (R10), $0x00000000 + JNE sequenceDecs_decode_56_amd64_adjust_offset_maybezero + INCQ CX + JMP sequenceDecs_decode_56_amd64_adjust_offset_nonzero + +sequenceDecs_decode_56_amd64_adjust_offset_maybezero: + TESTQ CX, CX + JNZ sequenceDecs_decode_56_amd64_adjust_offset_nonzero + MOVQ R11, CX + JMP sequenceDecs_decode_56_amd64_after_adjust + +sequenceDecs_decode_56_amd64_adjust_offset_nonzero: + CMPQ CX, $0x01 + JB sequenceDecs_decode_56_amd64_adjust_zero + JEQ sequenceDecs_decode_56_amd64_adjust_one + CMPQ CX, $0x02 + JA sequenceDecs_decode_56_amd64_adjust_three + JMP sequenceDecs_decode_56_amd64_adjust_two + +sequenceDecs_decode_56_amd64_adjust_zero: + MOVQ R11, AX + JMP sequenceDecs_decode_56_amd64_adjust_test_temp_valid + +sequenceDecs_decode_56_amd64_adjust_one: + MOVQ R12, AX + JMP sequenceDecs_decode_56_amd64_adjust_test_temp_valid + +sequenceDecs_decode_56_amd64_adjust_two: + MOVQ R13, AX + JMP sequenceDecs_decode_56_amd64_adjust_test_temp_valid + +sequenceDecs_decode_56_amd64_adjust_three: + LEAQ -1(R11), AX + +sequenceDecs_decode_56_amd64_adjust_test_temp_valid: + TESTQ AX, AX + JNZ sequenceDecs_decode_56_amd64_adjust_temp_valid + MOVQ $0x00000001, AX + +sequenceDecs_decode_56_amd64_adjust_temp_valid: + CMPQ CX, $0x01 + CMOVQNE R12, R13 + MOVQ R11, R12 + MOVQ AX, R11 + MOVQ AX, CX + +sequenceDecs_decode_56_amd64_after_adjust: + MOVQ CX, 16(R10) + + // Check values + MOVQ 8(R10), AX + MOVQ (R10), R14 + LEAQ (AX)(R14*1), R15 + MOVQ s+0(FP), BP + ADDQ R15, 256(BP) + MOVQ ctx+16(FP), R15 + SUBQ R14, 128(R15) + JS error_not_enough_literals + CMPQ AX, $0x00020002 + JA sequenceDecs_decode_56_amd64_error_match_len_too_big + TESTQ CX, CX + JNZ sequenceDecs_decode_56_amd64_match_len_ofs_ok + TESTQ AX, AX + JNZ sequenceDecs_decode_56_amd64_error_match_len_ofs_mismatch + +sequenceDecs_decode_56_amd64_match_len_ofs_ok: + ADDQ $0x18, R10 + MOVQ ctx+16(FP), AX + DECQ 96(AX) + JNS sequenceDecs_decode_56_amd64_main_loop + MOVQ s+0(FP), AX + MOVQ R11, 144(AX) + MOVQ R12, 152(AX) + MOVQ R13, 160(AX) + MOVQ br+8(FP), AX + MOVQ DX, 32(AX) + MOVB BL, 40(AX) + MOVQ SI, 24(AX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decode_56_amd64_error_match_len_ofs_mismatch: + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decode_56_amd64_error_match_len_too_big: + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + +// func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int +// Requires: BMI, BMI2, CMOV +TEXT ·sequenceDecs_decode_bmi2(SB), $8-32 + MOVQ br+8(FP), CX + MOVQ 32(CX), AX + MOVBQZX 40(CX), DX + MOVQ 24(CX), BX + MOVQ (CX), CX + ADDQ BX, CX + MOVQ CX, (SP) + MOVQ ctx+16(FP), CX + MOVQ 72(CX), SI + MOVQ 80(CX), DI + MOVQ 88(CX), R8 + MOVQ 104(CX), R9 + MOVQ s+0(FP), CX + MOVQ 144(CX), R10 + MOVQ 152(CX), R11 + MOVQ 160(CX), R12 + +sequenceDecs_decode_bmi2_main_loop: + MOVQ (SP), R13 + + // Fill bitreader to have enough for the offset and match length. + CMPQ BX, $0x08 + JL sequenceDecs_decode_bmi2_fill_byte_by_byte + MOVQ DX, CX + SHRQ $0x03, CX + SUBQ CX, R13 + MOVQ (R13), AX + SUBQ CX, BX + ANDQ $0x07, DX + JMP sequenceDecs_decode_bmi2_fill_end + +sequenceDecs_decode_bmi2_fill_byte_by_byte: + CMPQ BX, $0x00 + JLE sequenceDecs_decode_bmi2_fill_check_overread + CMPQ DX, $0x07 + JLE sequenceDecs_decode_bmi2_fill_end + SHLQ $0x08, AX + SUBQ $0x01, R13 + SUBQ $0x01, BX + SUBQ $0x08, DX + MOVBQZX (R13), CX + ORQ CX, AX + JMP sequenceDecs_decode_bmi2_fill_byte_by_byte + +sequenceDecs_decode_bmi2_fill_check_overread: + CMPQ DX, $0x40 + JA error_overread + +sequenceDecs_decode_bmi2_fill_end: + // Update offset + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R14 + MOVQ AX, R15 + LEAQ (DX)(R14*1), CX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + MOVQ CX, DX + MOVQ R8, CX + SHRQ $0x20, CX + ADDQ R15, CX + MOVQ CX, 16(R9) + + // Update match length + MOVQ $0x00000808, CX + BEXTRQ CX, DI, R14 + MOVQ AX, R15 + LEAQ (DX)(R14*1), CX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + MOVQ CX, DX + MOVQ DI, CX + SHRQ $0x20, CX + ADDQ R15, CX + MOVQ CX, 8(R9) + + // Fill bitreader to have enough for the remaining + CMPQ BX, $0x08 + JL sequenceDecs_decode_bmi2_fill_2_byte_by_byte + MOVQ DX, CX + SHRQ $0x03, CX + SUBQ CX, R13 + MOVQ (R13), AX + SUBQ CX, BX + ANDQ $0x07, DX + JMP sequenceDecs_decode_bmi2_fill_2_end + +sequenceDecs_decode_bmi2_fill_2_byte_by_byte: + CMPQ BX, $0x00 + JLE sequenceDecs_decode_bmi2_fill_2_check_overread + CMPQ DX, $0x07 + JLE sequenceDecs_decode_bmi2_fill_2_end + SHLQ $0x08, AX + SUBQ $0x01, R13 + SUBQ $0x01, BX + SUBQ $0x08, DX + MOVBQZX (R13), CX + ORQ CX, AX + JMP sequenceDecs_decode_bmi2_fill_2_byte_by_byte + +sequenceDecs_decode_bmi2_fill_2_check_overread: + CMPQ DX, $0x40 + JA error_overread + +sequenceDecs_decode_bmi2_fill_2_end: + // Update literal length + MOVQ $0x00000808, CX + BEXTRQ CX, SI, R14 + MOVQ AX, R15 + LEAQ (DX)(R14*1), CX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + MOVQ CX, DX + MOVQ SI, CX + SHRQ $0x20, CX + ADDQ R15, CX + MOVQ CX, (R9) + + // Fill bitreader for state updates + MOVQ R13, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R13 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_bmi2_skip_update + LEAQ (SI)(DI*1), R14 + ADDQ R8, R14 + MOVBQZX R14, R14 + LEAQ (DX)(R14*1), CX + MOVQ AX, R15 + MOVQ CX, DX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + + // Update Offset State + BZHIQ R8, R15, CX + SHRXQ R8, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Match Length State + BZHIQ DI, R15, CX + SHRXQ DI, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, DI, DI + ADDQ CX, DI + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(DI*8), DI + + // Update Literal Length State + BZHIQ SI, R15, CX + MOVQ $0x00001010, R14 + BEXTRQ R14, SI, SI + ADDQ CX, SI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI + +sequenceDecs_decode_bmi2_skip_update: + // Adjust offset + MOVQ 16(R9), CX + CMPQ R13, $0x01 + JBE sequenceDecs_decode_bmi2_adjust_offsetB_1_or_0 + MOVQ R11, R12 + MOVQ R10, R11 + MOVQ CX, R10 + JMP sequenceDecs_decode_bmi2_after_adjust + +sequenceDecs_decode_bmi2_adjust_offsetB_1_or_0: + CMPQ (R9), $0x00000000 + JNE sequenceDecs_decode_bmi2_adjust_offset_maybezero + INCQ CX + JMP sequenceDecs_decode_bmi2_adjust_offset_nonzero + +sequenceDecs_decode_bmi2_adjust_offset_maybezero: + TESTQ CX, CX + JNZ sequenceDecs_decode_bmi2_adjust_offset_nonzero + MOVQ R10, CX + JMP sequenceDecs_decode_bmi2_after_adjust + +sequenceDecs_decode_bmi2_adjust_offset_nonzero: + CMPQ CX, $0x01 + JB sequenceDecs_decode_bmi2_adjust_zero + JEQ sequenceDecs_decode_bmi2_adjust_one + CMPQ CX, $0x02 + JA sequenceDecs_decode_bmi2_adjust_three + JMP sequenceDecs_decode_bmi2_adjust_two + +sequenceDecs_decode_bmi2_adjust_zero: + MOVQ R10, R13 + JMP sequenceDecs_decode_bmi2_adjust_test_temp_valid + +sequenceDecs_decode_bmi2_adjust_one: + MOVQ R11, R13 + JMP sequenceDecs_decode_bmi2_adjust_test_temp_valid + +sequenceDecs_decode_bmi2_adjust_two: + MOVQ R12, R13 + JMP sequenceDecs_decode_bmi2_adjust_test_temp_valid + +sequenceDecs_decode_bmi2_adjust_three: + LEAQ -1(R10), R13 + +sequenceDecs_decode_bmi2_adjust_test_temp_valid: + TESTQ R13, R13 + JNZ sequenceDecs_decode_bmi2_adjust_temp_valid + MOVQ $0x00000001, R13 + +sequenceDecs_decode_bmi2_adjust_temp_valid: + CMPQ CX, $0x01 + CMOVQNE R11, R12 + MOVQ R10, R11 + MOVQ R13, R10 + MOVQ R13, CX + +sequenceDecs_decode_bmi2_after_adjust: + MOVQ CX, 16(R9) + + // Check values + MOVQ 8(R9), R13 + MOVQ (R9), R14 + LEAQ (R13)(R14*1), R15 + MOVQ s+0(FP), BP + ADDQ R15, 256(BP) + MOVQ ctx+16(FP), R15 + SUBQ R14, 128(R15) + JS error_not_enough_literals + CMPQ R13, $0x00020002 + JA sequenceDecs_decode_bmi2_error_match_len_too_big + TESTQ CX, CX + JNZ sequenceDecs_decode_bmi2_match_len_ofs_ok + TESTQ R13, R13 + JNZ sequenceDecs_decode_bmi2_error_match_len_ofs_mismatch + +sequenceDecs_decode_bmi2_match_len_ofs_ok: + ADDQ $0x18, R9 + MOVQ ctx+16(FP), CX + DECQ 96(CX) + JNS sequenceDecs_decode_bmi2_main_loop + MOVQ s+0(FP), CX + MOVQ R10, 144(CX) + MOVQ R11, 152(CX) + MOVQ R12, 160(CX) + MOVQ br+8(FP), CX + MOVQ AX, 32(CX) + MOVB DL, 40(CX) + MOVQ BX, 24(CX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decode_bmi2_error_match_len_ofs_mismatch: + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decode_bmi2_error_match_len_too_big: + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + +// func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int +// Requires: BMI, BMI2, CMOV +TEXT ·sequenceDecs_decode_56_bmi2(SB), $8-32 + MOVQ br+8(FP), CX + MOVQ 32(CX), AX + MOVBQZX 40(CX), DX + MOVQ 24(CX), BX + MOVQ (CX), CX + ADDQ BX, CX + MOVQ CX, (SP) + MOVQ ctx+16(FP), CX + MOVQ 72(CX), SI + MOVQ 80(CX), DI + MOVQ 88(CX), R8 + MOVQ 104(CX), R9 + MOVQ s+0(FP), CX + MOVQ 144(CX), R10 + MOVQ 152(CX), R11 + MOVQ 160(CX), R12 + +sequenceDecs_decode_56_bmi2_main_loop: + MOVQ (SP), R13 + + // Fill bitreader to have enough for the offset and match length. + CMPQ BX, $0x08 + JL sequenceDecs_decode_56_bmi2_fill_byte_by_byte + MOVQ DX, CX + SHRQ $0x03, CX + SUBQ CX, R13 + MOVQ (R13), AX + SUBQ CX, BX + ANDQ $0x07, DX + JMP sequenceDecs_decode_56_bmi2_fill_end + +sequenceDecs_decode_56_bmi2_fill_byte_by_byte: + CMPQ BX, $0x00 + JLE sequenceDecs_decode_56_bmi2_fill_check_overread + CMPQ DX, $0x07 + JLE sequenceDecs_decode_56_bmi2_fill_end + SHLQ $0x08, AX + SUBQ $0x01, R13 + SUBQ $0x01, BX + SUBQ $0x08, DX + MOVBQZX (R13), CX + ORQ CX, AX + JMP sequenceDecs_decode_56_bmi2_fill_byte_by_byte + +sequenceDecs_decode_56_bmi2_fill_check_overread: + CMPQ DX, $0x40 + JA error_overread + +sequenceDecs_decode_56_bmi2_fill_end: + // Update offset + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R14 + MOVQ AX, R15 + LEAQ (DX)(R14*1), CX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + MOVQ CX, DX + MOVQ R8, CX + SHRQ $0x20, CX + ADDQ R15, CX + MOVQ CX, 16(R9) + + // Update match length + MOVQ $0x00000808, CX + BEXTRQ CX, DI, R14 + MOVQ AX, R15 + LEAQ (DX)(R14*1), CX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + MOVQ CX, DX + MOVQ DI, CX + SHRQ $0x20, CX + ADDQ R15, CX + MOVQ CX, 8(R9) + + // Update literal length + MOVQ $0x00000808, CX + BEXTRQ CX, SI, R14 + MOVQ AX, R15 + LEAQ (DX)(R14*1), CX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + MOVQ CX, DX + MOVQ SI, CX + SHRQ $0x20, CX + ADDQ R15, CX + MOVQ CX, (R9) + + // Fill bitreader for state updates + MOVQ R13, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R13 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decode_56_bmi2_skip_update + LEAQ (SI)(DI*1), R14 + ADDQ R8, R14 + MOVBQZX R14, R14 + LEAQ (DX)(R14*1), CX + MOVQ AX, R15 + MOVQ CX, DX + ROLQ CL, R15 + BZHIQ R14, R15, R15 + + // Update Offset State + BZHIQ R8, R15, CX + SHRXQ R8, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Match Length State + BZHIQ DI, R15, CX + SHRXQ DI, R15, R15 + MOVQ $0x00001010, R14 + BEXTRQ R14, DI, DI + ADDQ CX, DI + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(DI*8), DI + + // Update Literal Length State + BZHIQ SI, R15, CX + MOVQ $0x00001010, R14 + BEXTRQ R14, SI, SI + ADDQ CX, SI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI + +sequenceDecs_decode_56_bmi2_skip_update: + // Adjust offset + MOVQ 16(R9), CX + CMPQ R13, $0x01 + JBE sequenceDecs_decode_56_bmi2_adjust_offsetB_1_or_0 + MOVQ R11, R12 + MOVQ R10, R11 + MOVQ CX, R10 + JMP sequenceDecs_decode_56_bmi2_after_adjust + +sequenceDecs_decode_56_bmi2_adjust_offsetB_1_or_0: + CMPQ (R9), $0x00000000 + JNE sequenceDecs_decode_56_bmi2_adjust_offset_maybezero + INCQ CX + JMP sequenceDecs_decode_56_bmi2_adjust_offset_nonzero + +sequenceDecs_decode_56_bmi2_adjust_offset_maybezero: + TESTQ CX, CX + JNZ sequenceDecs_decode_56_bmi2_adjust_offset_nonzero + MOVQ R10, CX + JMP sequenceDecs_decode_56_bmi2_after_adjust + +sequenceDecs_decode_56_bmi2_adjust_offset_nonzero: + CMPQ CX, $0x01 + JB sequenceDecs_decode_56_bmi2_adjust_zero + JEQ sequenceDecs_decode_56_bmi2_adjust_one + CMPQ CX, $0x02 + JA sequenceDecs_decode_56_bmi2_adjust_three + JMP sequenceDecs_decode_56_bmi2_adjust_two + +sequenceDecs_decode_56_bmi2_adjust_zero: + MOVQ R10, R13 + JMP sequenceDecs_decode_56_bmi2_adjust_test_temp_valid + +sequenceDecs_decode_56_bmi2_adjust_one: + MOVQ R11, R13 + JMP sequenceDecs_decode_56_bmi2_adjust_test_temp_valid + +sequenceDecs_decode_56_bmi2_adjust_two: + MOVQ R12, R13 + JMP sequenceDecs_decode_56_bmi2_adjust_test_temp_valid + +sequenceDecs_decode_56_bmi2_adjust_three: + LEAQ -1(R10), R13 + +sequenceDecs_decode_56_bmi2_adjust_test_temp_valid: + TESTQ R13, R13 + JNZ sequenceDecs_decode_56_bmi2_adjust_temp_valid + MOVQ $0x00000001, R13 + +sequenceDecs_decode_56_bmi2_adjust_temp_valid: + CMPQ CX, $0x01 + CMOVQNE R11, R12 + MOVQ R10, R11 + MOVQ R13, R10 + MOVQ R13, CX + +sequenceDecs_decode_56_bmi2_after_adjust: + MOVQ CX, 16(R9) + + // Check values + MOVQ 8(R9), R13 + MOVQ (R9), R14 + LEAQ (R13)(R14*1), R15 + MOVQ s+0(FP), BP + ADDQ R15, 256(BP) + MOVQ ctx+16(FP), R15 + SUBQ R14, 128(R15) + JS error_not_enough_literals + CMPQ R13, $0x00020002 + JA sequenceDecs_decode_56_bmi2_error_match_len_too_big + TESTQ CX, CX + JNZ sequenceDecs_decode_56_bmi2_match_len_ofs_ok + TESTQ R13, R13 + JNZ sequenceDecs_decode_56_bmi2_error_match_len_ofs_mismatch + +sequenceDecs_decode_56_bmi2_match_len_ofs_ok: + ADDQ $0x18, R9 + MOVQ ctx+16(FP), CX + DECQ 96(CX) + JNS sequenceDecs_decode_56_bmi2_main_loop + MOVQ s+0(FP), CX + MOVQ R10, 144(CX) + MOVQ R11, 152(CX) + MOVQ R12, 160(CX) + MOVQ br+8(FP), CX + MOVQ AX, 32(CX) + MOVB DL, 40(CX) + MOVQ BX, 24(CX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decode_56_bmi2_error_match_len_ofs_mismatch: + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decode_56_bmi2_error_match_len_too_big: + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + +// func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool +// Requires: SSE +TEXT ·sequenceDecs_executeSimple_amd64(SB), $8-9 + MOVQ ctx+0(FP), R10 + MOVQ 8(R10), CX + TESTQ CX, CX + JZ empty_seqs + MOVQ (R10), AX + MOVQ 24(R10), DX + MOVQ 32(R10), BX + MOVQ 80(R10), SI + MOVQ 104(R10), DI + MOVQ 120(R10), R8 + MOVQ 56(R10), R9 + MOVQ 64(R10), R10 + ADDQ R10, R9 + + // seqsBase += 24 * seqIndex + LEAQ (DX)(DX*2), R11 + SHLQ $0x03, R11 + ADDQ R11, AX + + // outBase += outPosition + ADDQ DI, BX + +main_loop: + MOVQ (AX), R11 + MOVQ 16(AX), R12 + MOVQ 8(AX), R13 + + // Copy literals + TESTQ R11, R11 + JZ check_offset + XORQ R14, R14 + +copy_1: + MOVUPS (SI)(R14*1), X0 + MOVUPS X0, (BX)(R14*1) + ADDQ $0x10, R14 + CMPQ R14, R11 + JB copy_1 + ADDQ R11, SI + ADDQ R11, BX + ADDQ R11, DI + + // Malformed input if seq.mo > t+len(hist) || seq.mo > s.windowSize) +check_offset: + LEAQ (DI)(R10*1), R11 + CMPQ R12, R11 + JG error_match_off_too_big + CMPQ R12, R8 + JG error_match_off_too_big + + // Copy match from history + MOVQ R12, R11 + SUBQ DI, R11 + JLS copy_match + MOVQ R9, R14 + SUBQ R11, R14 + CMPQ R13, R11 + JG copy_all_from_history + MOVQ R13, R11 + SUBQ $0x10, R11 + JB copy_4_small + +copy_4_loop: + MOVUPS (R14), X0 + MOVUPS X0, (BX) + ADDQ $0x10, R14 + ADDQ $0x10, BX + SUBQ $0x10, R11 + JAE copy_4_loop + LEAQ 16(R14)(R11*1), R14 + LEAQ 16(BX)(R11*1), BX + MOVUPS -16(R14), X0 + MOVUPS X0, -16(BX) + JMP copy_4_end + +copy_4_small: + CMPQ R13, $0x03 + JE copy_4_move_3 + CMPQ R13, $0x08 + JB copy_4_move_4through7 + JMP copy_4_move_8through16 + +copy_4_move_3: + MOVW (R14), R11 + MOVB 2(R14), R12 + MOVW R11, (BX) + MOVB R12, 2(BX) + ADDQ R13, R14 + ADDQ R13, BX + JMP copy_4_end + +copy_4_move_4through7: + MOVL (R14), R11 + MOVL -4(R14)(R13*1), R12 + MOVL R11, (BX) + MOVL R12, -4(BX)(R13*1) + ADDQ R13, R14 + ADDQ R13, BX + JMP copy_4_end + +copy_4_move_8through16: + MOVQ (R14), R11 + MOVQ -8(R14)(R13*1), R12 + MOVQ R11, (BX) + MOVQ R12, -8(BX)(R13*1) + ADDQ R13, R14 + ADDQ R13, BX + +copy_4_end: + ADDQ R13, DI + ADDQ $0x18, AX + INCQ DX + CMPQ DX, CX + JB main_loop + JMP loop_finished + +copy_all_from_history: + MOVQ R11, R15 + SUBQ $0x10, R15 + JB copy_5_small + +copy_5_loop: + MOVUPS (R14), X0 + MOVUPS X0, (BX) + ADDQ $0x10, R14 + ADDQ $0x10, BX + SUBQ $0x10, R15 + JAE copy_5_loop + LEAQ 16(R14)(R15*1), R14 + LEAQ 16(BX)(R15*1), BX + MOVUPS -16(R14), X0 + MOVUPS X0, -16(BX) + JMP copy_5_end + +copy_5_small: + CMPQ R11, $0x03 + JE copy_5_move_3 + JB copy_5_move_1or2 + CMPQ R11, $0x08 + JB copy_5_move_4through7 + JMP copy_5_move_8through16 + +copy_5_move_1or2: + MOVB (R14), R15 + MOVB -1(R14)(R11*1), BP + MOVB R15, (BX) + MOVB BP, -1(BX)(R11*1) + ADDQ R11, R14 + ADDQ R11, BX + JMP copy_5_end + +copy_5_move_3: + MOVW (R14), R15 + MOVB 2(R14), BP + MOVW R15, (BX) + MOVB BP, 2(BX) + ADDQ R11, R14 + ADDQ R11, BX + JMP copy_5_end + +copy_5_move_4through7: + MOVL (R14), R15 + MOVL -4(R14)(R11*1), BP + MOVL R15, (BX) + MOVL BP, -4(BX)(R11*1) + ADDQ R11, R14 + ADDQ R11, BX + JMP copy_5_end + +copy_5_move_8through16: + MOVQ (R14), R15 + MOVQ -8(R14)(R11*1), BP + MOVQ R15, (BX) + MOVQ BP, -8(BX)(R11*1) + ADDQ R11, R14 + ADDQ R11, BX + +copy_5_end: + ADDQ R11, DI + SUBQ R11, R13 + + // Copy match from the current buffer +copy_match: + MOVQ BX, R11 + SUBQ R12, R11 + + // ml <= mo + CMPQ R13, R12 + JA copy_overlapping_match + + // Copy non-overlapping match + ADDQ R13, DI + MOVQ BX, R12 + ADDQ R13, BX + +copy_2: + MOVUPS (R11), X0 + MOVUPS X0, (R12) + ADDQ $0x10, R11 + ADDQ $0x10, R12 + SUBQ $0x10, R13 + JHI copy_2 + JMP handle_loop + + // Copy overlapping match +copy_overlapping_match: + ADDQ R13, DI + +copy_slow_3: + MOVB (R11), R12 + MOVB R12, (BX) + INCQ R11 + INCQ BX + DECQ R13 + JNZ copy_slow_3 + +handle_loop: + ADDQ $0x18, AX + INCQ DX + CMPQ DX, CX + JB main_loop + +loop_finished: + // Return value + MOVB $0x01, ret+8(FP) + + // Update the context + MOVQ ctx+0(FP), AX + MOVQ DX, 24(AX) + MOVQ DI, 104(AX) + SUBQ 80(AX), SI + MOVQ SI, 112(AX) + RET + +error_match_off_too_big: + // Return value + MOVB $0x00, ret+8(FP) + + // Update the context + MOVQ ctx+0(FP), AX + MOVQ DX, 24(AX) + MOVQ DI, 104(AX) + SUBQ 80(AX), SI + MOVQ SI, 112(AX) + RET + +empty_seqs: + // Return value + MOVB $0x01, ret+8(FP) + RET + +// func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool +// Requires: SSE +TEXT ·sequenceDecs_executeSimple_safe_amd64(SB), $8-9 + MOVQ ctx+0(FP), R10 + MOVQ 8(R10), CX + TESTQ CX, CX + JZ empty_seqs + MOVQ (R10), AX + MOVQ 24(R10), DX + MOVQ 32(R10), BX + MOVQ 80(R10), SI + MOVQ 104(R10), DI + MOVQ 120(R10), R8 + MOVQ 56(R10), R9 + MOVQ 64(R10), R10 + ADDQ R10, R9 + + // seqsBase += 24 * seqIndex + LEAQ (DX)(DX*2), R11 + SHLQ $0x03, R11 + ADDQ R11, AX + + // outBase += outPosition + ADDQ DI, BX + +main_loop: + MOVQ (AX), R11 + MOVQ 16(AX), R12 + MOVQ 8(AX), R13 + + // Copy literals + TESTQ R11, R11 + JZ check_offset + MOVQ R11, R14 + SUBQ $0x10, R14 + JB copy_1_small + +copy_1_loop: + MOVUPS (SI), X0 + MOVUPS X0, (BX) + ADDQ $0x10, SI + ADDQ $0x10, BX + SUBQ $0x10, R14 + JAE copy_1_loop + LEAQ 16(SI)(R14*1), SI + LEAQ 16(BX)(R14*1), BX + MOVUPS -16(SI), X0 + MOVUPS X0, -16(BX) + JMP copy_1_end + +copy_1_small: + CMPQ R11, $0x03 + JE copy_1_move_3 + JB copy_1_move_1or2 + CMPQ R11, $0x08 + JB copy_1_move_4through7 + JMP copy_1_move_8through16 + +copy_1_move_1or2: + MOVB (SI), R14 + MOVB -1(SI)(R11*1), R15 + MOVB R14, (BX) + MOVB R15, -1(BX)(R11*1) + ADDQ R11, SI + ADDQ R11, BX + JMP copy_1_end + +copy_1_move_3: + MOVW (SI), R14 + MOVB 2(SI), R15 + MOVW R14, (BX) + MOVB R15, 2(BX) + ADDQ R11, SI + ADDQ R11, BX + JMP copy_1_end + +copy_1_move_4through7: + MOVL (SI), R14 + MOVL -4(SI)(R11*1), R15 + MOVL R14, (BX) + MOVL R15, -4(BX)(R11*1) + ADDQ R11, SI + ADDQ R11, BX + JMP copy_1_end + +copy_1_move_8through16: + MOVQ (SI), R14 + MOVQ -8(SI)(R11*1), R15 + MOVQ R14, (BX) + MOVQ R15, -8(BX)(R11*1) + ADDQ R11, SI + ADDQ R11, BX + +copy_1_end: + ADDQ R11, DI + + // Malformed input if seq.mo > t+len(hist) || seq.mo > s.windowSize) +check_offset: + LEAQ (DI)(R10*1), R11 + CMPQ R12, R11 + JG error_match_off_too_big + CMPQ R12, R8 + JG error_match_off_too_big + + // Copy match from history + MOVQ R12, R11 + SUBQ DI, R11 + JLS copy_match + MOVQ R9, R14 + SUBQ R11, R14 + CMPQ R13, R11 + JG copy_all_from_history + MOVQ R13, R11 + SUBQ $0x10, R11 + JB copy_4_small + +copy_4_loop: + MOVUPS (R14), X0 + MOVUPS X0, (BX) + ADDQ $0x10, R14 + ADDQ $0x10, BX + SUBQ $0x10, R11 + JAE copy_4_loop + LEAQ 16(R14)(R11*1), R14 + LEAQ 16(BX)(R11*1), BX + MOVUPS -16(R14), X0 + MOVUPS X0, -16(BX) + JMP copy_4_end + +copy_4_small: + CMPQ R13, $0x03 + JE copy_4_move_3 + CMPQ R13, $0x08 + JB copy_4_move_4through7 + JMP copy_4_move_8through16 + +copy_4_move_3: + MOVW (R14), R11 + MOVB 2(R14), R12 + MOVW R11, (BX) + MOVB R12, 2(BX) + ADDQ R13, R14 + ADDQ R13, BX + JMP copy_4_end + +copy_4_move_4through7: + MOVL (R14), R11 + MOVL -4(R14)(R13*1), R12 + MOVL R11, (BX) + MOVL R12, -4(BX)(R13*1) + ADDQ R13, R14 + ADDQ R13, BX + JMP copy_4_end + +copy_4_move_8through16: + MOVQ (R14), R11 + MOVQ -8(R14)(R13*1), R12 + MOVQ R11, (BX) + MOVQ R12, -8(BX)(R13*1) + ADDQ R13, R14 + ADDQ R13, BX + +copy_4_end: + ADDQ R13, DI + ADDQ $0x18, AX + INCQ DX + CMPQ DX, CX + JB main_loop + JMP loop_finished + +copy_all_from_history: + MOVQ R11, R15 + SUBQ $0x10, R15 + JB copy_5_small + +copy_5_loop: + MOVUPS (R14), X0 + MOVUPS X0, (BX) + ADDQ $0x10, R14 + ADDQ $0x10, BX + SUBQ $0x10, R15 + JAE copy_5_loop + LEAQ 16(R14)(R15*1), R14 + LEAQ 16(BX)(R15*1), BX + MOVUPS -16(R14), X0 + MOVUPS X0, -16(BX) + JMP copy_5_end + +copy_5_small: + CMPQ R11, $0x03 + JE copy_5_move_3 + JB copy_5_move_1or2 + CMPQ R11, $0x08 + JB copy_5_move_4through7 + JMP copy_5_move_8through16 + +copy_5_move_1or2: + MOVB (R14), R15 + MOVB -1(R14)(R11*1), BP + MOVB R15, (BX) + MOVB BP, -1(BX)(R11*1) + ADDQ R11, R14 + ADDQ R11, BX + JMP copy_5_end + +copy_5_move_3: + MOVW (R14), R15 + MOVB 2(R14), BP + MOVW R15, (BX) + MOVB BP, 2(BX) + ADDQ R11, R14 + ADDQ R11, BX + JMP copy_5_end + +copy_5_move_4through7: + MOVL (R14), R15 + MOVL -4(R14)(R11*1), BP + MOVL R15, (BX) + MOVL BP, -4(BX)(R11*1) + ADDQ R11, R14 + ADDQ R11, BX + JMP copy_5_end + +copy_5_move_8through16: + MOVQ (R14), R15 + MOVQ -8(R14)(R11*1), BP + MOVQ R15, (BX) + MOVQ BP, -8(BX)(R11*1) + ADDQ R11, R14 + ADDQ R11, BX + +copy_5_end: + ADDQ R11, DI + SUBQ R11, R13 + + // Copy match from the current buffer +copy_match: + MOVQ BX, R11 + SUBQ R12, R11 + + // ml <= mo + CMPQ R13, R12 + JA copy_overlapping_match + + // Copy non-overlapping match + ADDQ R13, DI + MOVQ R13, R12 + SUBQ $0x10, R12 + JB copy_2_small + +copy_2_loop: + MOVUPS (R11), X0 + MOVUPS X0, (BX) + ADDQ $0x10, R11 + ADDQ $0x10, BX + SUBQ $0x10, R12 + JAE copy_2_loop + LEAQ 16(R11)(R12*1), R11 + LEAQ 16(BX)(R12*1), BX + MOVUPS -16(R11), X0 + MOVUPS X0, -16(BX) + JMP copy_2_end + +copy_2_small: + CMPQ R13, $0x03 + JE copy_2_move_3 + JB copy_2_move_1or2 + CMPQ R13, $0x08 + JB copy_2_move_4through7 + JMP copy_2_move_8through16 + +copy_2_move_1or2: + MOVB (R11), R12 + MOVB -1(R11)(R13*1), R14 + MOVB R12, (BX) + MOVB R14, -1(BX)(R13*1) + ADDQ R13, R11 + ADDQ R13, BX + JMP copy_2_end + +copy_2_move_3: + MOVW (R11), R12 + MOVB 2(R11), R14 + MOVW R12, (BX) + MOVB R14, 2(BX) + ADDQ R13, R11 + ADDQ R13, BX + JMP copy_2_end + +copy_2_move_4through7: + MOVL (R11), R12 + MOVL -4(R11)(R13*1), R14 + MOVL R12, (BX) + MOVL R14, -4(BX)(R13*1) + ADDQ R13, R11 + ADDQ R13, BX + JMP copy_2_end + +copy_2_move_8through16: + MOVQ (R11), R12 + MOVQ -8(R11)(R13*1), R14 + MOVQ R12, (BX) + MOVQ R14, -8(BX)(R13*1) + ADDQ R13, R11 + ADDQ R13, BX + +copy_2_end: + JMP handle_loop + + // Copy overlapping match +copy_overlapping_match: + ADDQ R13, DI + +copy_slow_3: + MOVB (R11), R12 + MOVB R12, (BX) + INCQ R11 + INCQ BX + DECQ R13 + JNZ copy_slow_3 + +handle_loop: + ADDQ $0x18, AX + INCQ DX + CMPQ DX, CX + JB main_loop + +loop_finished: + // Return value + MOVB $0x01, ret+8(FP) + + // Update the context + MOVQ ctx+0(FP), AX + MOVQ DX, 24(AX) + MOVQ DI, 104(AX) + SUBQ 80(AX), SI + MOVQ SI, 112(AX) + RET + +error_match_off_too_big: + // Return value + MOVB $0x00, ret+8(FP) + + // Update the context + MOVQ ctx+0(FP), AX + MOVQ DX, 24(AX) + MOVQ DI, 104(AX) + SUBQ 80(AX), SI + MOVQ SI, 112(AX) + RET + +empty_seqs: + // Return value + MOVB $0x01, ret+8(FP) + RET + +// func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int +// Requires: CMOV, SSE +TEXT ·sequenceDecs_decodeSync_amd64(SB), $64-32 + MOVQ br+8(FP), AX + MOVQ 32(AX), DX + MOVBQZX 40(AX), BX + MOVQ 24(AX), SI + MOVQ (AX), AX + ADDQ SI, AX + MOVQ AX, (SP) + MOVQ ctx+16(FP), AX + MOVQ 72(AX), DI + MOVQ 80(AX), R8 + MOVQ 88(AX), R9 + XORQ CX, CX + MOVQ CX, 8(SP) + MOVQ CX, 16(SP) + MOVQ CX, 24(SP) + MOVQ 112(AX), R10 + MOVQ 128(AX), CX + MOVQ CX, 32(SP) + MOVQ 144(AX), R11 + MOVQ 136(AX), R12 + MOVQ 200(AX), CX + MOVQ CX, 56(SP) + MOVQ 176(AX), CX + MOVQ CX, 48(SP) + MOVQ 184(AX), AX + MOVQ AX, 40(SP) + MOVQ 40(SP), AX + ADDQ AX, 48(SP) + + // Calculate poiter to s.out[cap(s.out)] (a past-end pointer) + ADDQ R10, 32(SP) + + // outBase += outPosition + ADDQ R12, R10 + +sequenceDecs_decodeSync_amd64_main_loop: + MOVQ (SP), R13 + + // Fill bitreader to have enough for the offset and match length. + CMPQ SI, $0x08 + JL sequenceDecs_decodeSync_amd64_fill_byte_by_byte + MOVQ BX, AX + SHRQ $0x03, AX + SUBQ AX, R13 + MOVQ (R13), DX + SUBQ AX, SI + ANDQ $0x07, BX + JMP sequenceDecs_decodeSync_amd64_fill_end + +sequenceDecs_decodeSync_amd64_fill_byte_by_byte: + CMPQ SI, $0x00 + JLE sequenceDecs_decodeSync_amd64_fill_check_overread + CMPQ BX, $0x07 + JLE sequenceDecs_decodeSync_amd64_fill_end + SHLQ $0x08, DX + SUBQ $0x01, R13 + SUBQ $0x01, SI + SUBQ $0x08, BX + MOVBQZX (R13), AX + ORQ AX, DX + JMP sequenceDecs_decodeSync_amd64_fill_byte_by_byte + +sequenceDecs_decodeSync_amd64_fill_check_overread: + CMPQ BX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_amd64_fill_end: + // Update offset + MOVQ R9, AX + MOVQ BX, CX + MOVQ DX, R14 + SHLQ CL, R14 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decodeSync_amd64_of_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decodeSync_amd64_of_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decodeSync_amd64_of_update_zero + NEGQ CX + SHRQ CL, R14 + ADDQ R14, AX + +sequenceDecs_decodeSync_amd64_of_update_zero: + MOVQ AX, 8(SP) + + // Update match length + MOVQ R8, AX + MOVQ BX, CX + MOVQ DX, R14 + SHLQ CL, R14 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decodeSync_amd64_ml_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decodeSync_amd64_ml_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decodeSync_amd64_ml_update_zero + NEGQ CX + SHRQ CL, R14 + ADDQ R14, AX + +sequenceDecs_decodeSync_amd64_ml_update_zero: + MOVQ AX, 16(SP) + + // Fill bitreader to have enough for the remaining + CMPQ SI, $0x08 + JL sequenceDecs_decodeSync_amd64_fill_2_byte_by_byte + MOVQ BX, AX + SHRQ $0x03, AX + SUBQ AX, R13 + MOVQ (R13), DX + SUBQ AX, SI + ANDQ $0x07, BX + JMP sequenceDecs_decodeSync_amd64_fill_2_end + +sequenceDecs_decodeSync_amd64_fill_2_byte_by_byte: + CMPQ SI, $0x00 + JLE sequenceDecs_decodeSync_amd64_fill_2_check_overread + CMPQ BX, $0x07 + JLE sequenceDecs_decodeSync_amd64_fill_2_end + SHLQ $0x08, DX + SUBQ $0x01, R13 + SUBQ $0x01, SI + SUBQ $0x08, BX + MOVBQZX (R13), AX + ORQ AX, DX + JMP sequenceDecs_decodeSync_amd64_fill_2_byte_by_byte + +sequenceDecs_decodeSync_amd64_fill_2_check_overread: + CMPQ BX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_amd64_fill_2_end: + // Update literal length + MOVQ DI, AX + MOVQ BX, CX + MOVQ DX, R14 + SHLQ CL, R14 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decodeSync_amd64_ll_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decodeSync_amd64_ll_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decodeSync_amd64_ll_update_zero + NEGQ CX + SHRQ CL, R14 + ADDQ R14, AX + +sequenceDecs_decodeSync_amd64_ll_update_zero: + MOVQ AX, 24(SP) + + // Fill bitreader for state updates + MOVQ R13, (SP) + MOVQ R9, AX + SHRQ $0x08, AX + MOVBQZX AL, AX + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decodeSync_amd64_skip_update + + // Update Literal Length State + MOVBQZX DI, R13 + SHRQ $0x10, DI + MOVWQZX DI, DI + LEAQ (BX)(R13*1), CX + MOVQ DX, R14 + MOVQ CX, BX + ROLQ CL, R14 + MOVL $0x00000001, R15 + MOVB R13, CL + SHLL CL, R15 + DECL R15 + ANDQ R15, R14 + ADDQ R14, DI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(DI*8), DI + + // Update Match Length State + MOVBQZX R8, R13 + SHRQ $0x10, R8 + MOVWQZX R8, R8 + LEAQ (BX)(R13*1), CX + MOVQ DX, R14 + MOVQ CX, BX + ROLQ CL, R14 + MOVL $0x00000001, R15 + MOVB R13, CL + SHLL CL, R15 + DECL R15 + ANDQ R15, R14 + ADDQ R14, R8 + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Offset State + MOVBQZX R9, R13 + SHRQ $0x10, R9 + MOVWQZX R9, R9 + LEAQ (BX)(R13*1), CX + MOVQ DX, R14 + MOVQ CX, BX + ROLQ CL, R14 + MOVL $0x00000001, R15 + MOVB R13, CL + SHLL CL, R15 + DECL R15 + ANDQ R15, R14 + ADDQ R14, R9 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R9*8), R9 + +sequenceDecs_decodeSync_amd64_skip_update: + // Adjust offset + MOVQ s+0(FP), CX + MOVQ 8(SP), R13 + CMPQ AX, $0x01 + JBE sequenceDecs_decodeSync_amd64_adjust_offsetB_1_or_0 + MOVUPS 144(CX), X0 + MOVQ R13, 144(CX) + MOVUPS X0, 152(CX) + JMP sequenceDecs_decodeSync_amd64_after_adjust + +sequenceDecs_decodeSync_amd64_adjust_offsetB_1_or_0: + CMPQ 24(SP), $0x00000000 + JNE sequenceDecs_decodeSync_amd64_adjust_offset_maybezero + INCQ R13 + JMP sequenceDecs_decodeSync_amd64_adjust_offset_nonzero + +sequenceDecs_decodeSync_amd64_adjust_offset_maybezero: + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_amd64_adjust_offset_nonzero + MOVQ 144(CX), R13 + JMP sequenceDecs_decodeSync_amd64_after_adjust + +sequenceDecs_decodeSync_amd64_adjust_offset_nonzero: + MOVQ R13, AX + XORQ R14, R14 + MOVQ $-1, R15 + CMPQ R13, $0x03 + CMOVQEQ R14, AX + CMOVQEQ R15, R14 + ADDQ 144(CX)(AX*8), R14 + JNZ sequenceDecs_decodeSync_amd64_adjust_temp_valid + MOVQ $0x00000001, R14 + +sequenceDecs_decodeSync_amd64_adjust_temp_valid: + CMPQ R13, $0x01 + JZ sequenceDecs_decodeSync_amd64_adjust_skip + MOVQ 152(CX), AX + MOVQ AX, 160(CX) + +sequenceDecs_decodeSync_amd64_adjust_skip: + MOVQ 144(CX), AX + MOVQ AX, 152(CX) + MOVQ R14, 144(CX) + MOVQ R14, R13 + +sequenceDecs_decodeSync_amd64_after_adjust: + MOVQ R13, 8(SP) + + // Check values + MOVQ 16(SP), AX + MOVQ 24(SP), CX + LEAQ (AX)(CX*1), R14 + MOVQ s+0(FP), R15 + ADDQ R14, 256(R15) + MOVQ ctx+16(FP), R14 + SUBQ CX, 104(R14) + JS error_not_enough_literals + CMPQ AX, $0x00020002 + JA sequenceDecs_decodeSync_amd64_error_match_len_too_big + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_amd64_match_len_ofs_ok + TESTQ AX, AX + JNZ sequenceDecs_decodeSync_amd64_error_match_len_ofs_mismatch + +sequenceDecs_decodeSync_amd64_match_len_ofs_ok: + MOVQ 24(SP), AX + MOVQ 8(SP), CX + MOVQ 16(SP), R13 + + // Check if we have enough space in s.out + LEAQ (AX)(R13*1), R14 + ADDQ R10, R14 + CMPQ R14, 32(SP) + JA error_not_enough_space + + // Copy literals + TESTQ AX, AX + JZ check_offset + XORQ R14, R14 + +copy_1: + MOVUPS (R11)(R14*1), X0 + MOVUPS X0, (R10)(R14*1) + ADDQ $0x10, R14 + CMPQ R14, AX + JB copy_1 + ADDQ AX, R11 + ADDQ AX, R10 + ADDQ AX, R12 + + // Malformed input if seq.mo > t+len(hist) || seq.mo > s.windowSize) +check_offset: + MOVQ R12, AX + ADDQ 40(SP), AX + CMPQ CX, AX + JG error_match_off_too_big + CMPQ CX, 56(SP) + JG error_match_off_too_big + + // Copy match from history + MOVQ CX, AX + SUBQ R12, AX + JLS copy_match + MOVQ 48(SP), R14 + SUBQ AX, R14 + CMPQ R13, AX + JG copy_all_from_history + MOVQ R13, AX + SUBQ $0x10, AX + JB copy_4_small + +copy_4_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R10) + ADDQ $0x10, R14 + ADDQ $0x10, R10 + SUBQ $0x10, AX + JAE copy_4_loop + LEAQ 16(R14)(AX*1), R14 + LEAQ 16(R10)(AX*1), R10 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R10) + JMP copy_4_end + +copy_4_small: + CMPQ R13, $0x03 + JE copy_4_move_3 + CMPQ R13, $0x08 + JB copy_4_move_4through7 + JMP copy_4_move_8through16 + +copy_4_move_3: + MOVW (R14), AX + MOVB 2(R14), CL + MOVW AX, (R10) + MOVB CL, 2(R10) + ADDQ R13, R14 + ADDQ R13, R10 + JMP copy_4_end + +copy_4_move_4through7: + MOVL (R14), AX + MOVL -4(R14)(R13*1), CX + MOVL AX, (R10) + MOVL CX, -4(R10)(R13*1) + ADDQ R13, R14 + ADDQ R13, R10 + JMP copy_4_end + +copy_4_move_8through16: + MOVQ (R14), AX + MOVQ -8(R14)(R13*1), CX + MOVQ AX, (R10) + MOVQ CX, -8(R10)(R13*1) + ADDQ R13, R14 + ADDQ R13, R10 + +copy_4_end: + ADDQ R13, R12 + JMP handle_loop + JMP loop_finished + +copy_all_from_history: + MOVQ AX, R15 + SUBQ $0x10, R15 + JB copy_5_small + +copy_5_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R10) + ADDQ $0x10, R14 + ADDQ $0x10, R10 + SUBQ $0x10, R15 + JAE copy_5_loop + LEAQ 16(R14)(R15*1), R14 + LEAQ 16(R10)(R15*1), R10 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R10) + JMP copy_5_end + +copy_5_small: + CMPQ AX, $0x03 + JE copy_5_move_3 + JB copy_5_move_1or2 + CMPQ AX, $0x08 + JB copy_5_move_4through7 + JMP copy_5_move_8through16 + +copy_5_move_1or2: + MOVB (R14), R15 + MOVB -1(R14)(AX*1), BP + MOVB R15, (R10) + MOVB BP, -1(R10)(AX*1) + ADDQ AX, R14 + ADDQ AX, R10 + JMP copy_5_end + +copy_5_move_3: + MOVW (R14), R15 + MOVB 2(R14), BP + MOVW R15, (R10) + MOVB BP, 2(R10) + ADDQ AX, R14 + ADDQ AX, R10 + JMP copy_5_end + +copy_5_move_4through7: + MOVL (R14), R15 + MOVL -4(R14)(AX*1), BP + MOVL R15, (R10) + MOVL BP, -4(R10)(AX*1) + ADDQ AX, R14 + ADDQ AX, R10 + JMP copy_5_end + +copy_5_move_8through16: + MOVQ (R14), R15 + MOVQ -8(R14)(AX*1), BP + MOVQ R15, (R10) + MOVQ BP, -8(R10)(AX*1) + ADDQ AX, R14 + ADDQ AX, R10 + +copy_5_end: + ADDQ AX, R12 + SUBQ AX, R13 + + // Copy match from the current buffer +copy_match: + MOVQ R10, AX + SUBQ CX, AX + + // ml <= mo + CMPQ R13, CX + JA copy_overlapping_match + + // Copy non-overlapping match + ADDQ R13, R12 + MOVQ R10, CX + ADDQ R13, R10 + +copy_2: + MOVUPS (AX), X0 + MOVUPS X0, (CX) + ADDQ $0x10, AX + ADDQ $0x10, CX + SUBQ $0x10, R13 + JHI copy_2 + JMP handle_loop + + // Copy overlapping match +copy_overlapping_match: + ADDQ R13, R12 + +copy_slow_3: + MOVB (AX), CL + MOVB CL, (R10) + INCQ AX + INCQ R10 + DECQ R13 + JNZ copy_slow_3 + +handle_loop: + MOVQ ctx+16(FP), AX + DECQ 96(AX) + JNS sequenceDecs_decodeSync_amd64_main_loop + +loop_finished: + MOVQ br+8(FP), AX + MOVQ DX, 32(AX) + MOVB BL, 40(AX) + MOVQ SI, 24(AX) + + // Update the context + MOVQ ctx+16(FP), AX + MOVQ R12, 136(AX) + MOVQ 144(AX), CX + SUBQ CX, R11 + MOVQ R11, 168(AX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decodeSync_amd64_error_match_len_ofs_mismatch: + MOVQ 16(SP), AX + MOVQ ctx+16(FP), CX + MOVQ AX, 216(CX) + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decodeSync_amd64_error_match_len_too_big: + MOVQ ctx+16(FP), AX + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error +error_match_off_too_big: + MOVQ ctx+16(FP), AX + MOVQ 8(SP), CX + MOVQ CX, 224(AX) + MOVQ R12, 136(AX) + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + + // Return with not enough output space error +error_not_enough_space: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ R12, 136(AX) + MOVQ $0x00000005, ret+24(FP) + RET + +// func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int +// Requires: BMI, BMI2, CMOV, SSE +TEXT ·sequenceDecs_decodeSync_bmi2(SB), $64-32 + MOVQ br+8(FP), CX + MOVQ 32(CX), AX + MOVBQZX 40(CX), DX + MOVQ 24(CX), BX + MOVQ (CX), CX + ADDQ BX, CX + MOVQ CX, (SP) + MOVQ ctx+16(FP), CX + MOVQ 72(CX), SI + MOVQ 80(CX), DI + MOVQ 88(CX), R8 + XORQ R9, R9 + MOVQ R9, 8(SP) + MOVQ R9, 16(SP) + MOVQ R9, 24(SP) + MOVQ 112(CX), R9 + MOVQ 128(CX), R10 + MOVQ R10, 32(SP) + MOVQ 144(CX), R10 + MOVQ 136(CX), R11 + MOVQ 200(CX), R12 + MOVQ R12, 56(SP) + MOVQ 176(CX), R12 + MOVQ R12, 48(SP) + MOVQ 184(CX), CX + MOVQ CX, 40(SP) + MOVQ 40(SP), CX + ADDQ CX, 48(SP) + + // Calculate poiter to s.out[cap(s.out)] (a past-end pointer) + ADDQ R9, 32(SP) + + // outBase += outPosition + ADDQ R11, R9 + +sequenceDecs_decodeSync_bmi2_main_loop: + MOVQ (SP), R12 + + // Fill bitreader to have enough for the offset and match length. + CMPQ BX, $0x08 + JL sequenceDecs_decodeSync_bmi2_fill_byte_by_byte + MOVQ DX, CX + SHRQ $0x03, CX + SUBQ CX, R12 + MOVQ (R12), AX + SUBQ CX, BX + ANDQ $0x07, DX + JMP sequenceDecs_decodeSync_bmi2_fill_end + +sequenceDecs_decodeSync_bmi2_fill_byte_by_byte: + CMPQ BX, $0x00 + JLE sequenceDecs_decodeSync_bmi2_fill_check_overread + CMPQ DX, $0x07 + JLE sequenceDecs_decodeSync_bmi2_fill_end + SHLQ $0x08, AX + SUBQ $0x01, R12 + SUBQ $0x01, BX + SUBQ $0x08, DX + MOVBQZX (R12), CX + ORQ CX, AX + JMP sequenceDecs_decodeSync_bmi2_fill_byte_by_byte + +sequenceDecs_decodeSync_bmi2_fill_check_overread: + CMPQ DX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_bmi2_fill_end: + // Update offset + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R13 + MOVQ AX, R14 + LEAQ (DX)(R13*1), CX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + MOVQ CX, DX + MOVQ R8, CX + SHRQ $0x20, CX + ADDQ R14, CX + MOVQ CX, 8(SP) + + // Update match length + MOVQ $0x00000808, CX + BEXTRQ CX, DI, R13 + MOVQ AX, R14 + LEAQ (DX)(R13*1), CX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + MOVQ CX, DX + MOVQ DI, CX + SHRQ $0x20, CX + ADDQ R14, CX + MOVQ CX, 16(SP) + + // Fill bitreader to have enough for the remaining + CMPQ BX, $0x08 + JL sequenceDecs_decodeSync_bmi2_fill_2_byte_by_byte + MOVQ DX, CX + SHRQ $0x03, CX + SUBQ CX, R12 + MOVQ (R12), AX + SUBQ CX, BX + ANDQ $0x07, DX + JMP sequenceDecs_decodeSync_bmi2_fill_2_end + +sequenceDecs_decodeSync_bmi2_fill_2_byte_by_byte: + CMPQ BX, $0x00 + JLE sequenceDecs_decodeSync_bmi2_fill_2_check_overread + CMPQ DX, $0x07 + JLE sequenceDecs_decodeSync_bmi2_fill_2_end + SHLQ $0x08, AX + SUBQ $0x01, R12 + SUBQ $0x01, BX + SUBQ $0x08, DX + MOVBQZX (R12), CX + ORQ CX, AX + JMP sequenceDecs_decodeSync_bmi2_fill_2_byte_by_byte + +sequenceDecs_decodeSync_bmi2_fill_2_check_overread: + CMPQ DX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_bmi2_fill_2_end: + // Update literal length + MOVQ $0x00000808, CX + BEXTRQ CX, SI, R13 + MOVQ AX, R14 + LEAQ (DX)(R13*1), CX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + MOVQ CX, DX + MOVQ SI, CX + SHRQ $0x20, CX + ADDQ R14, CX + MOVQ CX, 24(SP) + + // Fill bitreader for state updates + MOVQ R12, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R12 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decodeSync_bmi2_skip_update + LEAQ (SI)(DI*1), R13 + ADDQ R8, R13 + MOVBQZX R13, R13 + LEAQ (DX)(R13*1), CX + MOVQ AX, R14 + MOVQ CX, DX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + + // Update Offset State + BZHIQ R8, R14, CX + SHRXQ R8, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Match Length State + BZHIQ DI, R14, CX + SHRXQ DI, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, DI, DI + ADDQ CX, DI + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(DI*8), DI + + // Update Literal Length State + BZHIQ SI, R14, CX + MOVQ $0x00001010, R13 + BEXTRQ R13, SI, SI + ADDQ CX, SI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI + +sequenceDecs_decodeSync_bmi2_skip_update: + // Adjust offset + MOVQ s+0(FP), CX + MOVQ 8(SP), R13 + CMPQ R12, $0x01 + JBE sequenceDecs_decodeSync_bmi2_adjust_offsetB_1_or_0 + MOVUPS 144(CX), X0 + MOVQ R13, 144(CX) + MOVUPS X0, 152(CX) + JMP sequenceDecs_decodeSync_bmi2_after_adjust + +sequenceDecs_decodeSync_bmi2_adjust_offsetB_1_or_0: + CMPQ 24(SP), $0x00000000 + JNE sequenceDecs_decodeSync_bmi2_adjust_offset_maybezero + INCQ R13 + JMP sequenceDecs_decodeSync_bmi2_adjust_offset_nonzero + +sequenceDecs_decodeSync_bmi2_adjust_offset_maybezero: + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_bmi2_adjust_offset_nonzero + MOVQ 144(CX), R13 + JMP sequenceDecs_decodeSync_bmi2_after_adjust + +sequenceDecs_decodeSync_bmi2_adjust_offset_nonzero: + MOVQ R13, R12 + XORQ R14, R14 + MOVQ $-1, R15 + CMPQ R13, $0x03 + CMOVQEQ R14, R12 + CMOVQEQ R15, R14 + ADDQ 144(CX)(R12*8), R14 + JNZ sequenceDecs_decodeSync_bmi2_adjust_temp_valid + MOVQ $0x00000001, R14 + +sequenceDecs_decodeSync_bmi2_adjust_temp_valid: + CMPQ R13, $0x01 + JZ sequenceDecs_decodeSync_bmi2_adjust_skip + MOVQ 152(CX), R12 + MOVQ R12, 160(CX) + +sequenceDecs_decodeSync_bmi2_adjust_skip: + MOVQ 144(CX), R12 + MOVQ R12, 152(CX) + MOVQ R14, 144(CX) + MOVQ R14, R13 + +sequenceDecs_decodeSync_bmi2_after_adjust: + MOVQ R13, 8(SP) + + // Check values + MOVQ 16(SP), CX + MOVQ 24(SP), R12 + LEAQ (CX)(R12*1), R14 + MOVQ s+0(FP), R15 + ADDQ R14, 256(R15) + MOVQ ctx+16(FP), R14 + SUBQ R12, 104(R14) + JS error_not_enough_literals + CMPQ CX, $0x00020002 + JA sequenceDecs_decodeSync_bmi2_error_match_len_too_big + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_bmi2_match_len_ofs_ok + TESTQ CX, CX + JNZ sequenceDecs_decodeSync_bmi2_error_match_len_ofs_mismatch + +sequenceDecs_decodeSync_bmi2_match_len_ofs_ok: + MOVQ 24(SP), CX + MOVQ 8(SP), R12 + MOVQ 16(SP), R13 + + // Check if we have enough space in s.out + LEAQ (CX)(R13*1), R14 + ADDQ R9, R14 + CMPQ R14, 32(SP) + JA error_not_enough_space + + // Copy literals + TESTQ CX, CX + JZ check_offset + XORQ R14, R14 + +copy_1: + MOVUPS (R10)(R14*1), X0 + MOVUPS X0, (R9)(R14*1) + ADDQ $0x10, R14 + CMPQ R14, CX + JB copy_1 + ADDQ CX, R10 + ADDQ CX, R9 + ADDQ CX, R11 + + // Malformed input if seq.mo > t+len(hist) || seq.mo > s.windowSize) +check_offset: + MOVQ R11, CX + ADDQ 40(SP), CX + CMPQ R12, CX + JG error_match_off_too_big + CMPQ R12, 56(SP) + JG error_match_off_too_big + + // Copy match from history + MOVQ R12, CX + SUBQ R11, CX + JLS copy_match + MOVQ 48(SP), R14 + SUBQ CX, R14 + CMPQ R13, CX + JG copy_all_from_history + MOVQ R13, CX + SUBQ $0x10, CX + JB copy_4_small + +copy_4_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R9) + ADDQ $0x10, R14 + ADDQ $0x10, R9 + SUBQ $0x10, CX + JAE copy_4_loop + LEAQ 16(R14)(CX*1), R14 + LEAQ 16(R9)(CX*1), R9 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R9) + JMP copy_4_end + +copy_4_small: + CMPQ R13, $0x03 + JE copy_4_move_3 + CMPQ R13, $0x08 + JB copy_4_move_4through7 + JMP copy_4_move_8through16 + +copy_4_move_3: + MOVW (R14), CX + MOVB 2(R14), R12 + MOVW CX, (R9) + MOVB R12, 2(R9) + ADDQ R13, R14 + ADDQ R13, R9 + JMP copy_4_end + +copy_4_move_4through7: + MOVL (R14), CX + MOVL -4(R14)(R13*1), R12 + MOVL CX, (R9) + MOVL R12, -4(R9)(R13*1) + ADDQ R13, R14 + ADDQ R13, R9 + JMP copy_4_end + +copy_4_move_8through16: + MOVQ (R14), CX + MOVQ -8(R14)(R13*1), R12 + MOVQ CX, (R9) + MOVQ R12, -8(R9)(R13*1) + ADDQ R13, R14 + ADDQ R13, R9 + +copy_4_end: + ADDQ R13, R11 + JMP handle_loop + JMP loop_finished + +copy_all_from_history: + MOVQ CX, R15 + SUBQ $0x10, R15 + JB copy_5_small + +copy_5_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R9) + ADDQ $0x10, R14 + ADDQ $0x10, R9 + SUBQ $0x10, R15 + JAE copy_5_loop + LEAQ 16(R14)(R15*1), R14 + LEAQ 16(R9)(R15*1), R9 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R9) + JMP copy_5_end + +copy_5_small: + CMPQ CX, $0x03 + JE copy_5_move_3 + JB copy_5_move_1or2 + CMPQ CX, $0x08 + JB copy_5_move_4through7 + JMP copy_5_move_8through16 + +copy_5_move_1or2: + MOVB (R14), R15 + MOVB -1(R14)(CX*1), BP + MOVB R15, (R9) + MOVB BP, -1(R9)(CX*1) + ADDQ CX, R14 + ADDQ CX, R9 + JMP copy_5_end + +copy_5_move_3: + MOVW (R14), R15 + MOVB 2(R14), BP + MOVW R15, (R9) + MOVB BP, 2(R9) + ADDQ CX, R14 + ADDQ CX, R9 + JMP copy_5_end + +copy_5_move_4through7: + MOVL (R14), R15 + MOVL -4(R14)(CX*1), BP + MOVL R15, (R9) + MOVL BP, -4(R9)(CX*1) + ADDQ CX, R14 + ADDQ CX, R9 + JMP copy_5_end + +copy_5_move_8through16: + MOVQ (R14), R15 + MOVQ -8(R14)(CX*1), BP + MOVQ R15, (R9) + MOVQ BP, -8(R9)(CX*1) + ADDQ CX, R14 + ADDQ CX, R9 + +copy_5_end: + ADDQ CX, R11 + SUBQ CX, R13 + + // Copy match from the current buffer +copy_match: + MOVQ R9, CX + SUBQ R12, CX + + // ml <= mo + CMPQ R13, R12 + JA copy_overlapping_match + + // Copy non-overlapping match + ADDQ R13, R11 + MOVQ R9, R12 + ADDQ R13, R9 + +copy_2: + MOVUPS (CX), X0 + MOVUPS X0, (R12) + ADDQ $0x10, CX + ADDQ $0x10, R12 + SUBQ $0x10, R13 + JHI copy_2 + JMP handle_loop + + // Copy overlapping match +copy_overlapping_match: + ADDQ R13, R11 + +copy_slow_3: + MOVB (CX), R12 + MOVB R12, (R9) + INCQ CX + INCQ R9 + DECQ R13 + JNZ copy_slow_3 + +handle_loop: + MOVQ ctx+16(FP), CX + DECQ 96(CX) + JNS sequenceDecs_decodeSync_bmi2_main_loop + +loop_finished: + MOVQ br+8(FP), CX + MOVQ AX, 32(CX) + MOVB DL, 40(CX) + MOVQ BX, 24(CX) + + // Update the context + MOVQ ctx+16(FP), AX + MOVQ R11, 136(AX) + MOVQ 144(AX), CX + SUBQ CX, R10 + MOVQ R10, 168(AX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decodeSync_bmi2_error_match_len_ofs_mismatch: + MOVQ 16(SP), AX + MOVQ ctx+16(FP), CX + MOVQ AX, 216(CX) + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decodeSync_bmi2_error_match_len_too_big: + MOVQ ctx+16(FP), AX + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error +error_match_off_too_big: + MOVQ ctx+16(FP), AX + MOVQ 8(SP), CX + MOVQ CX, 224(AX) + MOVQ R11, 136(AX) + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + + // Return with not enough output space error +error_not_enough_space: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ R11, 136(AX) + MOVQ $0x00000005, ret+24(FP) + RET + +// func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int +// Requires: CMOV, SSE +TEXT ·sequenceDecs_decodeSync_safe_amd64(SB), $64-32 + MOVQ br+8(FP), AX + MOVQ 32(AX), DX + MOVBQZX 40(AX), BX + MOVQ 24(AX), SI + MOVQ (AX), AX + ADDQ SI, AX + MOVQ AX, (SP) + MOVQ ctx+16(FP), AX + MOVQ 72(AX), DI + MOVQ 80(AX), R8 + MOVQ 88(AX), R9 + XORQ CX, CX + MOVQ CX, 8(SP) + MOVQ CX, 16(SP) + MOVQ CX, 24(SP) + MOVQ 112(AX), R10 + MOVQ 128(AX), CX + MOVQ CX, 32(SP) + MOVQ 144(AX), R11 + MOVQ 136(AX), R12 + MOVQ 200(AX), CX + MOVQ CX, 56(SP) + MOVQ 176(AX), CX + MOVQ CX, 48(SP) + MOVQ 184(AX), AX + MOVQ AX, 40(SP) + MOVQ 40(SP), AX + ADDQ AX, 48(SP) + + // Calculate poiter to s.out[cap(s.out)] (a past-end pointer) + ADDQ R10, 32(SP) + + // outBase += outPosition + ADDQ R12, R10 + +sequenceDecs_decodeSync_safe_amd64_main_loop: + MOVQ (SP), R13 + + // Fill bitreader to have enough for the offset and match length. + CMPQ SI, $0x08 + JL sequenceDecs_decodeSync_safe_amd64_fill_byte_by_byte + MOVQ BX, AX + SHRQ $0x03, AX + SUBQ AX, R13 + MOVQ (R13), DX + SUBQ AX, SI + ANDQ $0x07, BX + JMP sequenceDecs_decodeSync_safe_amd64_fill_end + +sequenceDecs_decodeSync_safe_amd64_fill_byte_by_byte: + CMPQ SI, $0x00 + JLE sequenceDecs_decodeSync_safe_amd64_fill_check_overread + CMPQ BX, $0x07 + JLE sequenceDecs_decodeSync_safe_amd64_fill_end + SHLQ $0x08, DX + SUBQ $0x01, R13 + SUBQ $0x01, SI + SUBQ $0x08, BX + MOVBQZX (R13), AX + ORQ AX, DX + JMP sequenceDecs_decodeSync_safe_amd64_fill_byte_by_byte + +sequenceDecs_decodeSync_safe_amd64_fill_check_overread: + CMPQ BX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_safe_amd64_fill_end: + // Update offset + MOVQ R9, AX + MOVQ BX, CX + MOVQ DX, R14 + SHLQ CL, R14 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decodeSync_safe_amd64_of_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decodeSync_safe_amd64_of_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decodeSync_safe_amd64_of_update_zero + NEGQ CX + SHRQ CL, R14 + ADDQ R14, AX + +sequenceDecs_decodeSync_safe_amd64_of_update_zero: + MOVQ AX, 8(SP) + + // Update match length + MOVQ R8, AX + MOVQ BX, CX + MOVQ DX, R14 + SHLQ CL, R14 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decodeSync_safe_amd64_ml_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decodeSync_safe_amd64_ml_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decodeSync_safe_amd64_ml_update_zero + NEGQ CX + SHRQ CL, R14 + ADDQ R14, AX + +sequenceDecs_decodeSync_safe_amd64_ml_update_zero: + MOVQ AX, 16(SP) + + // Fill bitreader to have enough for the remaining + CMPQ SI, $0x08 + JL sequenceDecs_decodeSync_safe_amd64_fill_2_byte_by_byte + MOVQ BX, AX + SHRQ $0x03, AX + SUBQ AX, R13 + MOVQ (R13), DX + SUBQ AX, SI + ANDQ $0x07, BX + JMP sequenceDecs_decodeSync_safe_amd64_fill_2_end + +sequenceDecs_decodeSync_safe_amd64_fill_2_byte_by_byte: + CMPQ SI, $0x00 + JLE sequenceDecs_decodeSync_safe_amd64_fill_2_check_overread + CMPQ BX, $0x07 + JLE sequenceDecs_decodeSync_safe_amd64_fill_2_end + SHLQ $0x08, DX + SUBQ $0x01, R13 + SUBQ $0x01, SI + SUBQ $0x08, BX + MOVBQZX (R13), AX + ORQ AX, DX + JMP sequenceDecs_decodeSync_safe_amd64_fill_2_byte_by_byte + +sequenceDecs_decodeSync_safe_amd64_fill_2_check_overread: + CMPQ BX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_safe_amd64_fill_2_end: + // Update literal length + MOVQ DI, AX + MOVQ BX, CX + MOVQ DX, R14 + SHLQ CL, R14 + MOVB AH, CL + SHRQ $0x20, AX + TESTQ CX, CX + JZ sequenceDecs_decodeSync_safe_amd64_ll_update_zero + ADDQ CX, BX + CMPQ BX, $0x40 + JA sequenceDecs_decodeSync_safe_amd64_ll_update_zero + CMPQ CX, $0x40 + JAE sequenceDecs_decodeSync_safe_amd64_ll_update_zero + NEGQ CX + SHRQ CL, R14 + ADDQ R14, AX + +sequenceDecs_decodeSync_safe_amd64_ll_update_zero: + MOVQ AX, 24(SP) + + // Fill bitreader for state updates + MOVQ R13, (SP) + MOVQ R9, AX + SHRQ $0x08, AX + MOVBQZX AL, AX + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decodeSync_safe_amd64_skip_update + + // Update Literal Length State + MOVBQZX DI, R13 + SHRQ $0x10, DI + MOVWQZX DI, DI + LEAQ (BX)(R13*1), CX + MOVQ DX, R14 + MOVQ CX, BX + ROLQ CL, R14 + MOVL $0x00000001, R15 + MOVB R13, CL + SHLL CL, R15 + DECL R15 + ANDQ R15, R14 + ADDQ R14, DI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(DI*8), DI + + // Update Match Length State + MOVBQZX R8, R13 + SHRQ $0x10, R8 + MOVWQZX R8, R8 + LEAQ (BX)(R13*1), CX + MOVQ DX, R14 + MOVQ CX, BX + ROLQ CL, R14 + MOVL $0x00000001, R15 + MOVB R13, CL + SHLL CL, R15 + DECL R15 + ANDQ R15, R14 + ADDQ R14, R8 + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Offset State + MOVBQZX R9, R13 + SHRQ $0x10, R9 + MOVWQZX R9, R9 + LEAQ (BX)(R13*1), CX + MOVQ DX, R14 + MOVQ CX, BX + ROLQ CL, R14 + MOVL $0x00000001, R15 + MOVB R13, CL + SHLL CL, R15 + DECL R15 + ANDQ R15, R14 + ADDQ R14, R9 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R9*8), R9 + +sequenceDecs_decodeSync_safe_amd64_skip_update: + // Adjust offset + MOVQ s+0(FP), CX + MOVQ 8(SP), R13 + CMPQ AX, $0x01 + JBE sequenceDecs_decodeSync_safe_amd64_adjust_offsetB_1_or_0 + MOVUPS 144(CX), X0 + MOVQ R13, 144(CX) + MOVUPS X0, 152(CX) + JMP sequenceDecs_decodeSync_safe_amd64_after_adjust + +sequenceDecs_decodeSync_safe_amd64_adjust_offsetB_1_or_0: + CMPQ 24(SP), $0x00000000 + JNE sequenceDecs_decodeSync_safe_amd64_adjust_offset_maybezero + INCQ R13 + JMP sequenceDecs_decodeSync_safe_amd64_adjust_offset_nonzero + +sequenceDecs_decodeSync_safe_amd64_adjust_offset_maybezero: + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_safe_amd64_adjust_offset_nonzero + MOVQ 144(CX), R13 + JMP sequenceDecs_decodeSync_safe_amd64_after_adjust + +sequenceDecs_decodeSync_safe_amd64_adjust_offset_nonzero: + MOVQ R13, AX + XORQ R14, R14 + MOVQ $-1, R15 + CMPQ R13, $0x03 + CMOVQEQ R14, AX + CMOVQEQ R15, R14 + ADDQ 144(CX)(AX*8), R14 + JNZ sequenceDecs_decodeSync_safe_amd64_adjust_temp_valid + MOVQ $0x00000001, R14 + +sequenceDecs_decodeSync_safe_amd64_adjust_temp_valid: + CMPQ R13, $0x01 + JZ sequenceDecs_decodeSync_safe_amd64_adjust_skip + MOVQ 152(CX), AX + MOVQ AX, 160(CX) + +sequenceDecs_decodeSync_safe_amd64_adjust_skip: + MOVQ 144(CX), AX + MOVQ AX, 152(CX) + MOVQ R14, 144(CX) + MOVQ R14, R13 + +sequenceDecs_decodeSync_safe_amd64_after_adjust: + MOVQ R13, 8(SP) + + // Check values + MOVQ 16(SP), AX + MOVQ 24(SP), CX + LEAQ (AX)(CX*1), R14 + MOVQ s+0(FP), R15 + ADDQ R14, 256(R15) + MOVQ ctx+16(FP), R14 + SUBQ CX, 104(R14) + JS error_not_enough_literals + CMPQ AX, $0x00020002 + JA sequenceDecs_decodeSync_safe_amd64_error_match_len_too_big + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_safe_amd64_match_len_ofs_ok + TESTQ AX, AX + JNZ sequenceDecs_decodeSync_safe_amd64_error_match_len_ofs_mismatch + +sequenceDecs_decodeSync_safe_amd64_match_len_ofs_ok: + MOVQ 24(SP), AX + MOVQ 8(SP), CX + MOVQ 16(SP), R13 + + // Check if we have enough space in s.out + LEAQ (AX)(R13*1), R14 + ADDQ R10, R14 + CMPQ R14, 32(SP) + JA error_not_enough_space + + // Copy literals + TESTQ AX, AX + JZ check_offset + MOVQ AX, R14 + SUBQ $0x10, R14 + JB copy_1_small + +copy_1_loop: + MOVUPS (R11), X0 + MOVUPS X0, (R10) + ADDQ $0x10, R11 + ADDQ $0x10, R10 + SUBQ $0x10, R14 + JAE copy_1_loop + LEAQ 16(R11)(R14*1), R11 + LEAQ 16(R10)(R14*1), R10 + MOVUPS -16(R11), X0 + MOVUPS X0, -16(R10) + JMP copy_1_end + +copy_1_small: + CMPQ AX, $0x03 + JE copy_1_move_3 + JB copy_1_move_1or2 + CMPQ AX, $0x08 + JB copy_1_move_4through7 + JMP copy_1_move_8through16 + +copy_1_move_1or2: + MOVB (R11), R14 + MOVB -1(R11)(AX*1), R15 + MOVB R14, (R10) + MOVB R15, -1(R10)(AX*1) + ADDQ AX, R11 + ADDQ AX, R10 + JMP copy_1_end + +copy_1_move_3: + MOVW (R11), R14 + MOVB 2(R11), R15 + MOVW R14, (R10) + MOVB R15, 2(R10) + ADDQ AX, R11 + ADDQ AX, R10 + JMP copy_1_end + +copy_1_move_4through7: + MOVL (R11), R14 + MOVL -4(R11)(AX*1), R15 + MOVL R14, (R10) + MOVL R15, -4(R10)(AX*1) + ADDQ AX, R11 + ADDQ AX, R10 + JMP copy_1_end + +copy_1_move_8through16: + MOVQ (R11), R14 + MOVQ -8(R11)(AX*1), R15 + MOVQ R14, (R10) + MOVQ R15, -8(R10)(AX*1) + ADDQ AX, R11 + ADDQ AX, R10 + +copy_1_end: + ADDQ AX, R12 + + // Malformed input if seq.mo > t+len(hist) || seq.mo > s.windowSize) +check_offset: + MOVQ R12, AX + ADDQ 40(SP), AX + CMPQ CX, AX + JG error_match_off_too_big + CMPQ CX, 56(SP) + JG error_match_off_too_big + + // Copy match from history + MOVQ CX, AX + SUBQ R12, AX + JLS copy_match + MOVQ 48(SP), R14 + SUBQ AX, R14 + CMPQ R13, AX + JG copy_all_from_history + MOVQ R13, AX + SUBQ $0x10, AX + JB copy_4_small + +copy_4_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R10) + ADDQ $0x10, R14 + ADDQ $0x10, R10 + SUBQ $0x10, AX + JAE copy_4_loop + LEAQ 16(R14)(AX*1), R14 + LEAQ 16(R10)(AX*1), R10 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R10) + JMP copy_4_end + +copy_4_small: + CMPQ R13, $0x03 + JE copy_4_move_3 + CMPQ R13, $0x08 + JB copy_4_move_4through7 + JMP copy_4_move_8through16 + +copy_4_move_3: + MOVW (R14), AX + MOVB 2(R14), CL + MOVW AX, (R10) + MOVB CL, 2(R10) + ADDQ R13, R14 + ADDQ R13, R10 + JMP copy_4_end + +copy_4_move_4through7: + MOVL (R14), AX + MOVL -4(R14)(R13*1), CX + MOVL AX, (R10) + MOVL CX, -4(R10)(R13*1) + ADDQ R13, R14 + ADDQ R13, R10 + JMP copy_4_end + +copy_4_move_8through16: + MOVQ (R14), AX + MOVQ -8(R14)(R13*1), CX + MOVQ AX, (R10) + MOVQ CX, -8(R10)(R13*1) + ADDQ R13, R14 + ADDQ R13, R10 + +copy_4_end: + ADDQ R13, R12 + JMP handle_loop + JMP loop_finished + +copy_all_from_history: + MOVQ AX, R15 + SUBQ $0x10, R15 + JB copy_5_small + +copy_5_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R10) + ADDQ $0x10, R14 + ADDQ $0x10, R10 + SUBQ $0x10, R15 + JAE copy_5_loop + LEAQ 16(R14)(R15*1), R14 + LEAQ 16(R10)(R15*1), R10 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R10) + JMP copy_5_end + +copy_5_small: + CMPQ AX, $0x03 + JE copy_5_move_3 + JB copy_5_move_1or2 + CMPQ AX, $0x08 + JB copy_5_move_4through7 + JMP copy_5_move_8through16 + +copy_5_move_1or2: + MOVB (R14), R15 + MOVB -1(R14)(AX*1), BP + MOVB R15, (R10) + MOVB BP, -1(R10)(AX*1) + ADDQ AX, R14 + ADDQ AX, R10 + JMP copy_5_end + +copy_5_move_3: + MOVW (R14), R15 + MOVB 2(R14), BP + MOVW R15, (R10) + MOVB BP, 2(R10) + ADDQ AX, R14 + ADDQ AX, R10 + JMP copy_5_end + +copy_5_move_4through7: + MOVL (R14), R15 + MOVL -4(R14)(AX*1), BP + MOVL R15, (R10) + MOVL BP, -4(R10)(AX*1) + ADDQ AX, R14 + ADDQ AX, R10 + JMP copy_5_end + +copy_5_move_8through16: + MOVQ (R14), R15 + MOVQ -8(R14)(AX*1), BP + MOVQ R15, (R10) + MOVQ BP, -8(R10)(AX*1) + ADDQ AX, R14 + ADDQ AX, R10 + +copy_5_end: + ADDQ AX, R12 + SUBQ AX, R13 + + // Copy match from the current buffer +copy_match: + MOVQ R10, AX + SUBQ CX, AX + + // ml <= mo + CMPQ R13, CX + JA copy_overlapping_match + + // Copy non-overlapping match + ADDQ R13, R12 + MOVQ R13, CX + SUBQ $0x10, CX + JB copy_2_small + +copy_2_loop: + MOVUPS (AX), X0 + MOVUPS X0, (R10) + ADDQ $0x10, AX + ADDQ $0x10, R10 + SUBQ $0x10, CX + JAE copy_2_loop + LEAQ 16(AX)(CX*1), AX + LEAQ 16(R10)(CX*1), R10 + MOVUPS -16(AX), X0 + MOVUPS X0, -16(R10) + JMP copy_2_end + +copy_2_small: + CMPQ R13, $0x03 + JE copy_2_move_3 + JB copy_2_move_1or2 + CMPQ R13, $0x08 + JB copy_2_move_4through7 + JMP copy_2_move_8through16 + +copy_2_move_1or2: + MOVB (AX), CL + MOVB -1(AX)(R13*1), R14 + MOVB CL, (R10) + MOVB R14, -1(R10)(R13*1) + ADDQ R13, AX + ADDQ R13, R10 + JMP copy_2_end + +copy_2_move_3: + MOVW (AX), CX + MOVB 2(AX), R14 + MOVW CX, (R10) + MOVB R14, 2(R10) + ADDQ R13, AX + ADDQ R13, R10 + JMP copy_2_end + +copy_2_move_4through7: + MOVL (AX), CX + MOVL -4(AX)(R13*1), R14 + MOVL CX, (R10) + MOVL R14, -4(R10)(R13*1) + ADDQ R13, AX + ADDQ R13, R10 + JMP copy_2_end + +copy_2_move_8through16: + MOVQ (AX), CX + MOVQ -8(AX)(R13*1), R14 + MOVQ CX, (R10) + MOVQ R14, -8(R10)(R13*1) + ADDQ R13, AX + ADDQ R13, R10 + +copy_2_end: + JMP handle_loop + + // Copy overlapping match +copy_overlapping_match: + ADDQ R13, R12 + +copy_slow_3: + MOVB (AX), CL + MOVB CL, (R10) + INCQ AX + INCQ R10 + DECQ R13 + JNZ copy_slow_3 + +handle_loop: + MOVQ ctx+16(FP), AX + DECQ 96(AX) + JNS sequenceDecs_decodeSync_safe_amd64_main_loop + +loop_finished: + MOVQ br+8(FP), AX + MOVQ DX, 32(AX) + MOVB BL, 40(AX) + MOVQ SI, 24(AX) + + // Update the context + MOVQ ctx+16(FP), AX + MOVQ R12, 136(AX) + MOVQ 144(AX), CX + SUBQ CX, R11 + MOVQ R11, 168(AX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decodeSync_safe_amd64_error_match_len_ofs_mismatch: + MOVQ 16(SP), AX + MOVQ ctx+16(FP), CX + MOVQ AX, 216(CX) + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decodeSync_safe_amd64_error_match_len_too_big: + MOVQ ctx+16(FP), AX + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error +error_match_off_too_big: + MOVQ ctx+16(FP), AX + MOVQ 8(SP), CX + MOVQ CX, 224(AX) + MOVQ R12, 136(AX) + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + + // Return with not enough output space error +error_not_enough_space: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ R12, 136(AX) + MOVQ $0x00000005, ret+24(FP) + RET + +// func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int +// Requires: BMI, BMI2, CMOV, SSE +TEXT ·sequenceDecs_decodeSync_safe_bmi2(SB), $64-32 + MOVQ br+8(FP), CX + MOVQ 32(CX), AX + MOVBQZX 40(CX), DX + MOVQ 24(CX), BX + MOVQ (CX), CX + ADDQ BX, CX + MOVQ CX, (SP) + MOVQ ctx+16(FP), CX + MOVQ 72(CX), SI + MOVQ 80(CX), DI + MOVQ 88(CX), R8 + XORQ R9, R9 + MOVQ R9, 8(SP) + MOVQ R9, 16(SP) + MOVQ R9, 24(SP) + MOVQ 112(CX), R9 + MOVQ 128(CX), R10 + MOVQ R10, 32(SP) + MOVQ 144(CX), R10 + MOVQ 136(CX), R11 + MOVQ 200(CX), R12 + MOVQ R12, 56(SP) + MOVQ 176(CX), R12 + MOVQ R12, 48(SP) + MOVQ 184(CX), CX + MOVQ CX, 40(SP) + MOVQ 40(SP), CX + ADDQ CX, 48(SP) + + // Calculate poiter to s.out[cap(s.out)] (a past-end pointer) + ADDQ R9, 32(SP) + + // outBase += outPosition + ADDQ R11, R9 + +sequenceDecs_decodeSync_safe_bmi2_main_loop: + MOVQ (SP), R12 + + // Fill bitreader to have enough for the offset and match length. + CMPQ BX, $0x08 + JL sequenceDecs_decodeSync_safe_bmi2_fill_byte_by_byte + MOVQ DX, CX + SHRQ $0x03, CX + SUBQ CX, R12 + MOVQ (R12), AX + SUBQ CX, BX + ANDQ $0x07, DX + JMP sequenceDecs_decodeSync_safe_bmi2_fill_end + +sequenceDecs_decodeSync_safe_bmi2_fill_byte_by_byte: + CMPQ BX, $0x00 + JLE sequenceDecs_decodeSync_safe_bmi2_fill_check_overread + CMPQ DX, $0x07 + JLE sequenceDecs_decodeSync_safe_bmi2_fill_end + SHLQ $0x08, AX + SUBQ $0x01, R12 + SUBQ $0x01, BX + SUBQ $0x08, DX + MOVBQZX (R12), CX + ORQ CX, AX + JMP sequenceDecs_decodeSync_safe_bmi2_fill_byte_by_byte + +sequenceDecs_decodeSync_safe_bmi2_fill_check_overread: + CMPQ DX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_safe_bmi2_fill_end: + // Update offset + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R13 + MOVQ AX, R14 + LEAQ (DX)(R13*1), CX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + MOVQ CX, DX + MOVQ R8, CX + SHRQ $0x20, CX + ADDQ R14, CX + MOVQ CX, 8(SP) + + // Update match length + MOVQ $0x00000808, CX + BEXTRQ CX, DI, R13 + MOVQ AX, R14 + LEAQ (DX)(R13*1), CX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + MOVQ CX, DX + MOVQ DI, CX + SHRQ $0x20, CX + ADDQ R14, CX + MOVQ CX, 16(SP) + + // Fill bitreader to have enough for the remaining + CMPQ BX, $0x08 + JL sequenceDecs_decodeSync_safe_bmi2_fill_2_byte_by_byte + MOVQ DX, CX + SHRQ $0x03, CX + SUBQ CX, R12 + MOVQ (R12), AX + SUBQ CX, BX + ANDQ $0x07, DX + JMP sequenceDecs_decodeSync_safe_bmi2_fill_2_end + +sequenceDecs_decodeSync_safe_bmi2_fill_2_byte_by_byte: + CMPQ BX, $0x00 + JLE sequenceDecs_decodeSync_safe_bmi2_fill_2_check_overread + CMPQ DX, $0x07 + JLE sequenceDecs_decodeSync_safe_bmi2_fill_2_end + SHLQ $0x08, AX + SUBQ $0x01, R12 + SUBQ $0x01, BX + SUBQ $0x08, DX + MOVBQZX (R12), CX + ORQ CX, AX + JMP sequenceDecs_decodeSync_safe_bmi2_fill_2_byte_by_byte + +sequenceDecs_decodeSync_safe_bmi2_fill_2_check_overread: + CMPQ DX, $0x40 + JA error_overread + +sequenceDecs_decodeSync_safe_bmi2_fill_2_end: + // Update literal length + MOVQ $0x00000808, CX + BEXTRQ CX, SI, R13 + MOVQ AX, R14 + LEAQ (DX)(R13*1), CX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + MOVQ CX, DX + MOVQ SI, CX + SHRQ $0x20, CX + ADDQ R14, CX + MOVQ CX, 24(SP) + + // Fill bitreader for state updates + MOVQ R12, (SP) + MOVQ $0x00000808, CX + BEXTRQ CX, R8, R12 + MOVQ ctx+16(FP), CX + CMPQ 96(CX), $0x00 + JZ sequenceDecs_decodeSync_safe_bmi2_skip_update + LEAQ (SI)(DI*1), R13 + ADDQ R8, R13 + MOVBQZX R13, R13 + LEAQ (DX)(R13*1), CX + MOVQ AX, R14 + MOVQ CX, DX + ROLQ CL, R14 + BZHIQ R13, R14, R14 + + // Update Offset State + BZHIQ R8, R14, CX + SHRXQ R8, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, R8, R8 + ADDQ CX, R8 + + // Load ctx.ofTable + MOVQ ctx+16(FP), CX + MOVQ 48(CX), CX + MOVQ (CX)(R8*8), R8 + + // Update Match Length State + BZHIQ DI, R14, CX + SHRXQ DI, R14, R14 + MOVQ $0x00001010, R13 + BEXTRQ R13, DI, DI + ADDQ CX, DI + + // Load ctx.mlTable + MOVQ ctx+16(FP), CX + MOVQ 24(CX), CX + MOVQ (CX)(DI*8), DI + + // Update Literal Length State + BZHIQ SI, R14, CX + MOVQ $0x00001010, R13 + BEXTRQ R13, SI, SI + ADDQ CX, SI + + // Load ctx.llTable + MOVQ ctx+16(FP), CX + MOVQ (CX), CX + MOVQ (CX)(SI*8), SI + +sequenceDecs_decodeSync_safe_bmi2_skip_update: + // Adjust offset + MOVQ s+0(FP), CX + MOVQ 8(SP), R13 + CMPQ R12, $0x01 + JBE sequenceDecs_decodeSync_safe_bmi2_adjust_offsetB_1_or_0 + MOVUPS 144(CX), X0 + MOVQ R13, 144(CX) + MOVUPS X0, 152(CX) + JMP sequenceDecs_decodeSync_safe_bmi2_after_adjust + +sequenceDecs_decodeSync_safe_bmi2_adjust_offsetB_1_or_0: + CMPQ 24(SP), $0x00000000 + JNE sequenceDecs_decodeSync_safe_bmi2_adjust_offset_maybezero + INCQ R13 + JMP sequenceDecs_decodeSync_safe_bmi2_adjust_offset_nonzero + +sequenceDecs_decodeSync_safe_bmi2_adjust_offset_maybezero: + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_safe_bmi2_adjust_offset_nonzero + MOVQ 144(CX), R13 + JMP sequenceDecs_decodeSync_safe_bmi2_after_adjust + +sequenceDecs_decodeSync_safe_bmi2_adjust_offset_nonzero: + MOVQ R13, R12 + XORQ R14, R14 + MOVQ $-1, R15 + CMPQ R13, $0x03 + CMOVQEQ R14, R12 + CMOVQEQ R15, R14 + ADDQ 144(CX)(R12*8), R14 + JNZ sequenceDecs_decodeSync_safe_bmi2_adjust_temp_valid + MOVQ $0x00000001, R14 + +sequenceDecs_decodeSync_safe_bmi2_adjust_temp_valid: + CMPQ R13, $0x01 + JZ sequenceDecs_decodeSync_safe_bmi2_adjust_skip + MOVQ 152(CX), R12 + MOVQ R12, 160(CX) + +sequenceDecs_decodeSync_safe_bmi2_adjust_skip: + MOVQ 144(CX), R12 + MOVQ R12, 152(CX) + MOVQ R14, 144(CX) + MOVQ R14, R13 + +sequenceDecs_decodeSync_safe_bmi2_after_adjust: + MOVQ R13, 8(SP) + + // Check values + MOVQ 16(SP), CX + MOVQ 24(SP), R12 + LEAQ (CX)(R12*1), R14 + MOVQ s+0(FP), R15 + ADDQ R14, 256(R15) + MOVQ ctx+16(FP), R14 + SUBQ R12, 104(R14) + JS error_not_enough_literals + CMPQ CX, $0x00020002 + JA sequenceDecs_decodeSync_safe_bmi2_error_match_len_too_big + TESTQ R13, R13 + JNZ sequenceDecs_decodeSync_safe_bmi2_match_len_ofs_ok + TESTQ CX, CX + JNZ sequenceDecs_decodeSync_safe_bmi2_error_match_len_ofs_mismatch + +sequenceDecs_decodeSync_safe_bmi2_match_len_ofs_ok: + MOVQ 24(SP), CX + MOVQ 8(SP), R12 + MOVQ 16(SP), R13 + + // Check if we have enough space in s.out + LEAQ (CX)(R13*1), R14 + ADDQ R9, R14 + CMPQ R14, 32(SP) + JA error_not_enough_space + + // Copy literals + TESTQ CX, CX + JZ check_offset + MOVQ CX, R14 + SUBQ $0x10, R14 + JB copy_1_small + +copy_1_loop: + MOVUPS (R10), X0 + MOVUPS X0, (R9) + ADDQ $0x10, R10 + ADDQ $0x10, R9 + SUBQ $0x10, R14 + JAE copy_1_loop + LEAQ 16(R10)(R14*1), R10 + LEAQ 16(R9)(R14*1), R9 + MOVUPS -16(R10), X0 + MOVUPS X0, -16(R9) + JMP copy_1_end + +copy_1_small: + CMPQ CX, $0x03 + JE copy_1_move_3 + JB copy_1_move_1or2 + CMPQ CX, $0x08 + JB copy_1_move_4through7 + JMP copy_1_move_8through16 + +copy_1_move_1or2: + MOVB (R10), R14 + MOVB -1(R10)(CX*1), R15 + MOVB R14, (R9) + MOVB R15, -1(R9)(CX*1) + ADDQ CX, R10 + ADDQ CX, R9 + JMP copy_1_end + +copy_1_move_3: + MOVW (R10), R14 + MOVB 2(R10), R15 + MOVW R14, (R9) + MOVB R15, 2(R9) + ADDQ CX, R10 + ADDQ CX, R9 + JMP copy_1_end + +copy_1_move_4through7: + MOVL (R10), R14 + MOVL -4(R10)(CX*1), R15 + MOVL R14, (R9) + MOVL R15, -4(R9)(CX*1) + ADDQ CX, R10 + ADDQ CX, R9 + JMP copy_1_end + +copy_1_move_8through16: + MOVQ (R10), R14 + MOVQ -8(R10)(CX*1), R15 + MOVQ R14, (R9) + MOVQ R15, -8(R9)(CX*1) + ADDQ CX, R10 + ADDQ CX, R9 + +copy_1_end: + ADDQ CX, R11 + + // Malformed input if seq.mo > t+len(hist) || seq.mo > s.windowSize) +check_offset: + MOVQ R11, CX + ADDQ 40(SP), CX + CMPQ R12, CX + JG error_match_off_too_big + CMPQ R12, 56(SP) + JG error_match_off_too_big + + // Copy match from history + MOVQ R12, CX + SUBQ R11, CX + JLS copy_match + MOVQ 48(SP), R14 + SUBQ CX, R14 + CMPQ R13, CX + JG copy_all_from_history + MOVQ R13, CX + SUBQ $0x10, CX + JB copy_4_small + +copy_4_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R9) + ADDQ $0x10, R14 + ADDQ $0x10, R9 + SUBQ $0x10, CX + JAE copy_4_loop + LEAQ 16(R14)(CX*1), R14 + LEAQ 16(R9)(CX*1), R9 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R9) + JMP copy_4_end + +copy_4_small: + CMPQ R13, $0x03 + JE copy_4_move_3 + CMPQ R13, $0x08 + JB copy_4_move_4through7 + JMP copy_4_move_8through16 + +copy_4_move_3: + MOVW (R14), CX + MOVB 2(R14), R12 + MOVW CX, (R9) + MOVB R12, 2(R9) + ADDQ R13, R14 + ADDQ R13, R9 + JMP copy_4_end + +copy_4_move_4through7: + MOVL (R14), CX + MOVL -4(R14)(R13*1), R12 + MOVL CX, (R9) + MOVL R12, -4(R9)(R13*1) + ADDQ R13, R14 + ADDQ R13, R9 + JMP copy_4_end + +copy_4_move_8through16: + MOVQ (R14), CX + MOVQ -8(R14)(R13*1), R12 + MOVQ CX, (R9) + MOVQ R12, -8(R9)(R13*1) + ADDQ R13, R14 + ADDQ R13, R9 + +copy_4_end: + ADDQ R13, R11 + JMP handle_loop + JMP loop_finished + +copy_all_from_history: + MOVQ CX, R15 + SUBQ $0x10, R15 + JB copy_5_small + +copy_5_loop: + MOVUPS (R14), X0 + MOVUPS X0, (R9) + ADDQ $0x10, R14 + ADDQ $0x10, R9 + SUBQ $0x10, R15 + JAE copy_5_loop + LEAQ 16(R14)(R15*1), R14 + LEAQ 16(R9)(R15*1), R9 + MOVUPS -16(R14), X0 + MOVUPS X0, -16(R9) + JMP copy_5_end + +copy_5_small: + CMPQ CX, $0x03 + JE copy_5_move_3 + JB copy_5_move_1or2 + CMPQ CX, $0x08 + JB copy_5_move_4through7 + JMP copy_5_move_8through16 + +copy_5_move_1or2: + MOVB (R14), R15 + MOVB -1(R14)(CX*1), BP + MOVB R15, (R9) + MOVB BP, -1(R9)(CX*1) + ADDQ CX, R14 + ADDQ CX, R9 + JMP copy_5_end + +copy_5_move_3: + MOVW (R14), R15 + MOVB 2(R14), BP + MOVW R15, (R9) + MOVB BP, 2(R9) + ADDQ CX, R14 + ADDQ CX, R9 + JMP copy_5_end + +copy_5_move_4through7: + MOVL (R14), R15 + MOVL -4(R14)(CX*1), BP + MOVL R15, (R9) + MOVL BP, -4(R9)(CX*1) + ADDQ CX, R14 + ADDQ CX, R9 + JMP copy_5_end + +copy_5_move_8through16: + MOVQ (R14), R15 + MOVQ -8(R14)(CX*1), BP + MOVQ R15, (R9) + MOVQ BP, -8(R9)(CX*1) + ADDQ CX, R14 + ADDQ CX, R9 + +copy_5_end: + ADDQ CX, R11 + SUBQ CX, R13 + + // Copy match from the current buffer +copy_match: + MOVQ R9, CX + SUBQ R12, CX + + // ml <= mo + CMPQ R13, R12 + JA copy_overlapping_match + + // Copy non-overlapping match + ADDQ R13, R11 + MOVQ R13, R12 + SUBQ $0x10, R12 + JB copy_2_small + +copy_2_loop: + MOVUPS (CX), X0 + MOVUPS X0, (R9) + ADDQ $0x10, CX + ADDQ $0x10, R9 + SUBQ $0x10, R12 + JAE copy_2_loop + LEAQ 16(CX)(R12*1), CX + LEAQ 16(R9)(R12*1), R9 + MOVUPS -16(CX), X0 + MOVUPS X0, -16(R9) + JMP copy_2_end + +copy_2_small: + CMPQ R13, $0x03 + JE copy_2_move_3 + JB copy_2_move_1or2 + CMPQ R13, $0x08 + JB copy_2_move_4through7 + JMP copy_2_move_8through16 + +copy_2_move_1or2: + MOVB (CX), R12 + MOVB -1(CX)(R13*1), R14 + MOVB R12, (R9) + MOVB R14, -1(R9)(R13*1) + ADDQ R13, CX + ADDQ R13, R9 + JMP copy_2_end + +copy_2_move_3: + MOVW (CX), R12 + MOVB 2(CX), R14 + MOVW R12, (R9) + MOVB R14, 2(R9) + ADDQ R13, CX + ADDQ R13, R9 + JMP copy_2_end + +copy_2_move_4through7: + MOVL (CX), R12 + MOVL -4(CX)(R13*1), R14 + MOVL R12, (R9) + MOVL R14, -4(R9)(R13*1) + ADDQ R13, CX + ADDQ R13, R9 + JMP copy_2_end + +copy_2_move_8through16: + MOVQ (CX), R12 + MOVQ -8(CX)(R13*1), R14 + MOVQ R12, (R9) + MOVQ R14, -8(R9)(R13*1) + ADDQ R13, CX + ADDQ R13, R9 + +copy_2_end: + JMP handle_loop + + // Copy overlapping match +copy_overlapping_match: + ADDQ R13, R11 + +copy_slow_3: + MOVB (CX), R12 + MOVB R12, (R9) + INCQ CX + INCQ R9 + DECQ R13 + JNZ copy_slow_3 + +handle_loop: + MOVQ ctx+16(FP), CX + DECQ 96(CX) + JNS sequenceDecs_decodeSync_safe_bmi2_main_loop + +loop_finished: + MOVQ br+8(FP), CX + MOVQ AX, 32(CX) + MOVB DL, 40(CX) + MOVQ BX, 24(CX) + + // Update the context + MOVQ ctx+16(FP), AX + MOVQ R11, 136(AX) + MOVQ 144(AX), CX + SUBQ CX, R10 + MOVQ R10, 168(AX) + + // Return success + MOVQ $0x00000000, ret+24(FP) + RET + + // Return with match length error +sequenceDecs_decodeSync_safe_bmi2_error_match_len_ofs_mismatch: + MOVQ 16(SP), AX + MOVQ ctx+16(FP), CX + MOVQ AX, 216(CX) + MOVQ $0x00000001, ret+24(FP) + RET + + // Return with match too long error +sequenceDecs_decodeSync_safe_bmi2_error_match_len_too_big: + MOVQ ctx+16(FP), AX + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ $0x00000002, ret+24(FP) + RET + + // Return with match offset too long error +error_match_off_too_big: + MOVQ ctx+16(FP), AX + MOVQ 8(SP), CX + MOVQ CX, 224(AX) + MOVQ R11, 136(AX) + MOVQ $0x00000003, ret+24(FP) + RET + + // Return with not enough literals error +error_not_enough_literals: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ $0x00000004, ret+24(FP) + RET + + // Return with overread error +error_overread: + MOVQ $0x00000006, ret+24(FP) + RET + + // Return with not enough output space error +error_not_enough_space: + MOVQ ctx+16(FP), AX + MOVQ 24(SP), CX + MOVQ CX, 208(AX) + MOVQ 16(SP), CX + MOVQ CX, 216(AX) + MOVQ R11, 136(AX) + MOVQ $0x00000005, ret+24(FP) + RET diff --git a/vendor/github.com/klauspost/compress/zstd/seqdec_generic.go b/vendor/github.com/klauspost/compress/zstd/seqdec_generic.go new file mode 100644 index 0000000..ac2a80d --- /dev/null +++ b/vendor/github.com/klauspost/compress/zstd/seqdec_generic.go @@ -0,0 +1,237 @@ +//go:build !amd64 || appengine || !gc || noasm +// +build !amd64 appengine !gc noasm + +package zstd + +import ( + "fmt" + "io" +) + +// decode sequences from the stream with the provided history but without dictionary. +func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) { + return false, nil +} + +// decode sequences from the stream without the provided history. +func (s *sequenceDecs) decode(seqs []seqVals) error { + br := s.br + + // Grab full sizes tables, to avoid bounds checks. + llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize] + llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state + s.seqSize = 0 + litRemain := len(s.literals) + + maxBlockSize := maxCompressedBlockSize + if s.windowSize < maxBlockSize { + maxBlockSize = s.windowSize + } + for i := range seqs { + var ll, mo, ml int + if br.off > 4+((maxOffsetBits+16+16)>>3) { + // inlined function: + // ll, mo, ml = s.nextFast(br, llState, mlState, ofState) + + // Final will not read from stream. + var llB, mlB, moB uint8 + ll, llB = llState.final() + ml, mlB = mlState.final() + mo, moB = ofState.final() + + // extra bits are stored in reverse order. + br.fillFast() + mo += br.getBits(moB) + if s.maxBits > 32 { + br.fillFast() + } + ml += br.getBits(mlB) + ll += br.getBits(llB) + + if moB > 1 { + s.prevOffset[2] = s.prevOffset[1] + s.prevOffset[1] = s.prevOffset[0] + s.prevOffset[0] = mo + } else { + // mo = s.adjustOffset(mo, ll, moB) + // Inlined for rather big speedup + if ll == 0 { + // There is an exception though, when current sequence's literals_length = 0. + // In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2, + // an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte. + mo++ + } + + if mo == 0 { + mo = s.prevOffset[0] + } else { + var temp int + if mo == 3 { + temp = s.prevOffset[0] - 1 + } else { + temp = s.prevOffset[mo] + } + + if temp == 0 { + // 0 is not valid; input is corrupted; force offset to 1 + println("WARNING: temp was 0") + temp = 1 + } + + if mo != 1 { + s.prevOffset[2] = s.prevOffset[1] + } + s.prevOffset[1] = s.prevOffset[0] + s.prevOffset[0] = temp + mo = temp + } + } + br.fillFast() + } else { + if br.overread() { + if debugDecoder { + printf("reading sequence %d, exceeded available data\n", i) + } + return io.ErrUnexpectedEOF + } + ll, mo, ml = s.next(br, llState, mlState, ofState) + br.fill() + } + + if debugSequences { + println("Seq", i, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml) + } + // Evaluate. + // We might be doing this async, so do it early. + if mo == 0 && ml > 0 { + return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml) + } + if ml > maxMatchLen { + return fmt.Errorf("match len (%d) bigger than max allowed length", ml) + } + s.seqSize += ll + ml + if s.seqSize > maxBlockSize { + return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize) + } + litRemain -= ll + if litRemain < 0 { + return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, litRemain+ll) + } + seqs[i] = seqVals{ + ll: ll, + ml: ml, + mo: mo, + } + if i == len(seqs)-1 { + // This is the last sequence, so we shouldn't update state. + break + } + + // Manually inlined, ~ 5-20% faster + // Update all 3 states at once. Approx 20% faster. + nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits() + if nBits == 0 { + llState = llTable[llState.newState()&maxTableMask] + mlState = mlTable[mlState.newState()&maxTableMask] + ofState = ofTable[ofState.newState()&maxTableMask] + } else { + bits := br.get32BitsFast(nBits) + lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31)) + llState = llTable[(llState.newState()+lowBits)&maxTableMask] + + lowBits = uint16(bits >> (ofState.nbBits() & 31)) + lowBits &= bitMask[mlState.nbBits()&15] + mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask] + + lowBits = uint16(bits) & bitMask[ofState.nbBits()&15] + ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask] + } + } + s.seqSize += litRemain + if s.seqSize > maxBlockSize { + return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize) + } + err := br.close() + if err != nil { + printf("Closing sequences: %v, %+v\n", err, *br) + } + return err +} + +// executeSimple handles cases when a dictionary is not used. +func (s *sequenceDecs) executeSimple(seqs []seqVals, hist []byte) error { + // Ensure we have enough output size... + if len(s.out)+s.seqSize > cap(s.out) { + addBytes := s.seqSize + len(s.out) + s.out = append(s.out, make([]byte, addBytes)...) + s.out = s.out[:len(s.out)-addBytes] + } + + if debugDecoder { + printf("Execute %d seqs with literals: %d into %d bytes\n", len(seqs), len(s.literals), s.seqSize) + } + + var t = len(s.out) + out := s.out[:t+s.seqSize] + + for _, seq := range seqs { + // Add literals + copy(out[t:], s.literals[:seq.ll]) + t += seq.ll + s.literals = s.literals[seq.ll:] + + // Malformed input + if seq.mo > t+len(hist) || seq.mo > s.windowSize { + return fmt.Errorf("match offset (%d) bigger than current history (%d)", seq.mo, t+len(hist)) + } + + // Copy from history. + if v := seq.mo - t; v > 0 { + // v is the start position in history from end. + start := len(hist) - v + if seq.ml > v { + // Some goes into the current block. + // Copy remainder of history + copy(out[t:], hist[start:]) + t += v + seq.ml -= v + } else { + copy(out[t:], hist[start:start+seq.ml]) + t += seq.ml + continue + } + } + + // We must be in the current buffer now + if seq.ml > 0 { + start := t - seq.mo + if seq.ml <= t-start { + // No overlap + copy(out[t:], out[start:start+seq.ml]) + t += seq.ml + } else { + // Overlapping copy + // Extend destination slice and copy one byte at the time. + src := out[start : start+seq.ml] + dst := out[t:] + dst = dst[:len(src)] + t += len(src) + // Destination is the space we just added. + for i := range src { + dst[i] = src[i] + } + } + } + } + // Add final literals + copy(out[t:], s.literals) + if debugDecoder { + t += len(s.literals) + if t != len(out) { + panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize)) + } + } + s.out = out + + return nil +} diff --git a/vendor/github.com/klauspost/compress/zstd/zip.go b/vendor/github.com/klauspost/compress/zstd/zip.go index 967f29b..29c15c8 100644 --- a/vendor/github.com/klauspost/compress/zstd/zip.go +++ b/vendor/github.com/klauspost/compress/zstd/zip.go @@ -18,36 +18,58 @@ const ZipMethodWinZip = 93 // See https://pkware.cachefly.net/webdocs/APPNOTE/APPNOTE-6.3.9.TXT const ZipMethodPKWare = 20 -var zipReaderPool sync.Pool +// zipReaderPool is the default reader pool. +var zipReaderPool = sync.Pool{New: func() interface{} { + z, err := NewReader(nil, WithDecoderLowmem(true), WithDecoderMaxWindow(128<<20), WithDecoderConcurrency(1)) + if err != nil { + panic(err) + } + return z +}} -// newZipReader cannot be used since we would leak goroutines... -func newZipReader(r io.Reader) io.ReadCloser { - dec, ok := zipReaderPool.Get().(*Decoder) - if ok { - dec.Reset(r) - } else { - d, err := NewReader(r, WithDecoderConcurrency(1), WithDecoderLowmem(true)) - if err != nil { - panic(err) +// newZipReader creates a pooled zip decompressor. +func newZipReader(opts ...DOption) func(r io.Reader) io.ReadCloser { + pool := &zipReaderPool + if len(opts) > 0 { + opts = append([]DOption{WithDecoderLowmem(true), WithDecoderMaxWindow(128 << 20)}, opts...) + // Force concurrency 1 + opts = append(opts, WithDecoderConcurrency(1)) + // Create our own pool + pool = &sync.Pool{} + } + return func(r io.Reader) io.ReadCloser { + dec, ok := pool.Get().(*Decoder) + if ok { + dec.Reset(r) + } else { + d, err := NewReader(r, opts...) + if err != nil { + panic(err) + } + dec = d } - dec = d + return &pooledZipReader{dec: dec, pool: pool} } - return &pooledZipReader{dec: dec} } type pooledZipReader struct { - mu sync.Mutex // guards Close and Read - dec *Decoder + mu sync.Mutex // guards Close and Read + pool *sync.Pool + dec *Decoder } func (r *pooledZipReader) Read(p []byte) (n int, err error) { r.mu.Lock() defer r.mu.Unlock() if r.dec == nil { - return 0, errors.New("Read after Close") + return 0, errors.New("read after close or EOF") } dec, err := r.dec.Read(p) - + if err == io.EOF { + r.dec.Reset(nil) + r.pool.Put(r.dec) + r.dec = nil + } return dec, err } @@ -57,7 +79,7 @@ func (r *pooledZipReader) Close() error { var err error if r.dec != nil { err = r.dec.Reset(nil) - zipReaderPool.Put(r.dec) + r.pool.Put(r.dec) r.dec = nil } return err @@ -111,12 +133,9 @@ func ZipCompressor(opts ...EOption) func(w io.Writer) (io.WriteCloser, error) { // ZipDecompressor returns a decompressor that can be registered with zip libraries. // See ZipCompressor for example. -func ZipDecompressor() func(r io.Reader) io.ReadCloser { - return func(r io.Reader) io.ReadCloser { - d, err := NewReader(r, WithDecoderConcurrency(1), WithDecoderLowmem(true)) - if err != nil { - panic(err) - } - return d.IOReadCloser() - } +// Options can be specified. WithDecoderConcurrency(1) is forced, +// and by default a 128MB maximum decompression window is specified. +// The window size can be overridden if required. +func ZipDecompressor(opts ...DOption) func(r io.Reader) io.ReadCloser { + return newZipReader(opts...) } diff --git a/vendor/github.com/klauspost/compress/zstd/zstd.go b/vendor/github.com/klauspost/compress/zstd/zstd.go index ef1d49a..4be7cc7 100644 --- a/vendor/github.com/klauspost/compress/zstd/zstd.go +++ b/vendor/github.com/klauspost/compress/zstd/zstd.go @@ -9,7 +9,6 @@ import ( "errors" "log" "math" - "math/bits" ) // enable debug printing @@ -36,8 +35,8 @@ const forcePreDef = false // zstdMinMatch is the minimum zstd match length. const zstdMinMatch = 3 -// Reset the buffer offset when reaching this. -const bufferReset = math.MaxInt32 - MaxWindowSize +// fcsUnknown is used for unknown frame content size. +const fcsUnknown = math.MaxUint64 var ( // ErrReservedBlockType is returned when a reserved block type is found. @@ -52,6 +51,10 @@ var ( // Typically returned on invalid input. ErrBlockTooSmall = errors.New("block too small") + // ErrUnexpectedBlockSize is returned when a block has unexpected size. + // Typically returned on invalid input. + ErrUnexpectedBlockSize = errors.New("unexpected block size") + // ErrMagicMismatch is returned when a "magic" number isn't what is expected. // Typically this indicates wrong or corrupted input. ErrMagicMismatch = errors.New("invalid input: magic number mismatch") @@ -68,13 +71,16 @@ var ( ErrDecoderSizeExceeded = errors.New("decompressed size exceeds configured limit") // ErrUnknownDictionary is returned if the dictionary ID is unknown. - // For the time being dictionaries are not supported. ErrUnknownDictionary = errors.New("unknown dictionary") // ErrFrameSizeExceeded is returned if the stated frame size is exceeded. // This is only returned if SingleSegment is specified on the frame. ErrFrameSizeExceeded = errors.New("frame size exceeded") + // ErrFrameSizeMismatch is returned if the stated frame size does not match the expected size. + // This is only returned if SingleSegment is specified on the frame. + ErrFrameSizeMismatch = errors.New("frame size does not match size on stream") + // ErrCRCMismatch is returned if CRC mismatches. ErrCRCMismatch = errors.New("CRC check failed") @@ -99,49 +105,12 @@ func printf(format string, a ...interface{}) { } } -// matchLenFast does matching, but will not match the last up to 7 bytes. -func matchLenFast(a, b []byte) int { - endI := len(a) & (math.MaxInt32 - 7) - for i := 0; i < endI; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - return i + bits.TrailingZeros64(diff)>>3 - } - } - return endI -} - -// matchLen returns the maximum length. -// a must be the shortest of the two. -// The function also returns whether all bytes matched. -func matchLen(a, b []byte) int { - b = b[:len(a)] - for i := 0; i < len(a)-7; i += 8 { - if diff := load64(a, i) ^ load64(b, i); diff != 0 { - return i + (bits.TrailingZeros64(diff) >> 3) - } - } - - checked := (len(a) >> 3) << 3 - a = a[checked:] - b = b[checked:] - for i := range a { - if a[i] != b[i] { - return i + checked - } - } - return len(a) + checked -} - func load3232(b []byte, i int32) uint32 { - return binary.LittleEndian.Uint32(b[i:]) + return binary.LittleEndian.Uint32(b[:len(b):len(b)][i:]) } func load6432(b []byte, i int32) uint64 { - return binary.LittleEndian.Uint64(b[i:]) -} - -func load64(b []byte, i int) uint64 { - return binary.LittleEndian.Uint64(b[i:]) + return binary.LittleEndian.Uint64(b[:len(b):len(b)][i:]) } type byter interface { diff --git a/vendor/github.com/montanaflynn/stats/.gitignore b/vendor/github.com/montanaflynn/stats/.gitignore new file mode 100644 index 0000000..75a2a3a --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/.gitignore @@ -0,0 +1,7 @@ +coverage.out +coverage.txt +release-notes.txt +.directory +.chglog +.vscode +.DS_Store \ No newline at end of file diff --git a/vendor/github.com/montanaflynn/stats/CHANGELOG.md b/vendor/github.com/montanaflynn/stats/CHANGELOG.md new file mode 100644 index 0000000..73c3b78 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/CHANGELOG.md @@ -0,0 +1,534 @@ + +## [Unreleased] + + + +## [v0.7.1] - 2023-05-11 +### Add +- Add describe functions ([#77](https://github.com/montanaflynn/stats/issues/77)) + +### Update +- Update .gitignore +- Update README.md, LICENSE and DOCUMENTATION.md files +- Update github action go workflow to run on push + + + +## [v0.7.0] - 2023-01-08 +### Add +- Add geometric distribution functions ([#75](https://github.com/montanaflynn/stats/issues/75)) +- Add GitHub action go workflow + +### Remove +- Remove travis CI config + +### Update +- Update changelog with v0.7.0 changes +- Update changelog with v0.7.0 changes +- Update github action go workflow +- Update geometric distribution tests + + + +## [v0.6.6] - 2021-04-26 +### Add +- Add support for string and io.Reader in LoadRawData (pr [#68](https://github.com/montanaflynn/stats/issues/68)) +- Add latest versions of Go to test against + +### Update +- Update changelog with v0.6.6 changes + +### Use +- Use math.Sqrt in StandardDeviation (PR [#64](https://github.com/montanaflynn/stats/issues/64)) + + + +## [v0.6.5] - 2021-02-21 +### Add +- Add Float64Data.Quartiles documentation +- Add Quartiles method to Float64Data type (issue [#60](https://github.com/montanaflynn/stats/issues/60)) + +### Fix +- Fix make release changelog command and add changelog history + +### Update +- Update changelog with v0.6.5 changes +- Update changelog with v0.6.4 changes +- Update README.md links to CHANGELOG.md and DOCUMENTATION.md +- Update README.md and Makefile with new release commands + + + +## [v0.6.4] - 2021-01-13 +### Fix +- Fix failing tests due to precision errors on arm64 ([#58](https://github.com/montanaflynn/stats/issues/58)) + +### Update +- Update changelog with v0.6.4 changes +- Update examples directory to include a README.md used for synopsis +- Update go.mod to include go version where modules are enabled by default +- Update changelog with v0.6.3 changes + + + +## [v0.6.3] - 2020-02-18 +### Add +- Add creating and committing changelog to Makefile release directive +- Add release-notes.txt and .chglog directory to .gitignore + +### Update +- Update exported tests to use import for better example documentation +- Update documentation using godoc2md +- Update changelog with v0.6.2 release + + + +## [v0.6.2] - 2020-02-18 +### Fix +- Fix linting errcheck warnings in go benchmarks + +### Update +- Update Makefile release directive to use correct release name + + + +## [v0.6.1] - 2020-02-18 +### Add +- Add StableSample function signature to readme + +### Fix +- Fix linting warnings for normal distribution functions formatting and tests + +### Update +- Update documentation links and rename DOC.md to DOCUMENTATION.md +- Update README with link to pkg.go.dev reference and release section +- Update Makefile with new changelog, docs, and release directives +- Update DOC.md links to GitHub source code +- Update doc.go comment and add DOC.md package reference file +- Update changelog using git-chglog + + + +## [v0.6.0] - 2020-02-17 +### Add +- Add Normal Distribution Functions ([#56](https://github.com/montanaflynn/stats/issues/56)) +- Add previous versions of Go to travis CI config +- Add check for distinct values in Mode function ([#51](https://github.com/montanaflynn/stats/issues/51)) +- Add StableSample function ([#48](https://github.com/montanaflynn/stats/issues/48)) +- Add doc.go file to show description and usage on godoc.org +- Add comments to new error and legacy error variables +- Add ExampleRound function to tests +- Add go.mod file for module support +- Add Sigmoid, SoftMax and Entropy methods and tests +- Add Entropy documentation, example and benchmarks +- Add Entropy function ([#44](https://github.com/montanaflynn/stats/issues/44)) + +### Fix +- Fix percentile when only one element ([#47](https://github.com/montanaflynn/stats/issues/47)) +- Fix AutoCorrelation name in comments and remove unneeded Sprintf + +### Improve +- Improve documentation section with command comments + +### Remove +- Remove very old versions of Go in travis CI config +- Remove boolean comparison to get rid of gometalinter warning + +### Update +- Update license dates +- Update Distance functions signatures to use Float64Data +- Update Sigmoid examples +- Update error names with backward compatibility + +### Use +- Use relative link to examples/main.go +- Use a single var block for exported errors + + + +## [v0.5.0] - 2019-01-16 +### Add +- Add Sigmoid and Softmax functions + +### Fix +- Fix syntax highlighting and add CumulativeSum func + + + +## [v0.4.0] - 2019-01-14 +### Add +- Add goreport badge and documentation section to README.md +- Add Examples to test files +- Add AutoCorrelation and nist tests +- Add String method to statsErr type +- Add Y coordinate error for ExponentialRegression +- Add syntax highlighting ([#43](https://github.com/montanaflynn/stats/issues/43)) +- Add CumulativeSum ([#40](https://github.com/montanaflynn/stats/issues/40)) +- Add more tests and rename distance files +- Add coverage and benchmarks to azure pipeline +- Add go tests to azure pipeline + +### Change +- Change travis tip alias to master +- Change codecov to coveralls for code coverage + +### Fix +- Fix a few lint warnings +- Fix example error + +### Improve +- Improve test coverage of distance functions + +### Only +- Only run travis on stable and tip versions +- Only check code coverage on tip + +### Remove +- Remove azure CI pipeline +- Remove unnecessary type conversions + +### Return +- Return EmptyInputErr instead of EmptyInput + +### Set +- Set up CI with Azure Pipelines + + + +## [0.3.0] - 2017-12-02 +### Add +- Add Chebyshev, Manhattan, Euclidean and Minkowski distance functions ([#35](https://github.com/montanaflynn/stats/issues/35)) +- Add function for computing chebyshev distance. ([#34](https://github.com/montanaflynn/stats/issues/34)) +- Add support for time.Duration +- Add LoadRawData to docs and examples +- Add unit test for edge case that wasn't covered +- Add unit tests for edge cases that weren't covered +- Add pearson alias delegating to correlation +- Add CovariancePopulation to Float64Data +- Add pearson product-moment correlation coefficient +- Add population covariance +- Add random slice benchmarks +- Add all applicable functions as methods to Float64Data type +- Add MIT license badge +- Add link to examples/methods.go +- Add Protips for usage and documentation sections +- Add tests for rounding up +- Add webdoc target and remove linting from test target +- Add example usage and consolidate contributing information + +### Added +- Added MedianAbsoluteDeviation + +### Annotation +- Annotation spelling error + +### Auto +- auto commit +- auto commit + +### Calculate +- Calculate correlation with sdev and covp + +### Clean +- Clean up README.md and add info for offline docs + +### Consolidated +- Consolidated all error values. + +### Fix +- Fix Percentile logic +- Fix InterQuartileRange method test +- Fix zero percent bug and add test +- Fix usage example output typos + +### Improve +- Improve bounds checking in Percentile +- Improve error log messaging + +### Imput +- Imput -> Input + +### Include +- Include alternative way to set Float64Data in example + +### Make +- Make various changes to README.md + +### Merge +- Merge branch 'master' of github.com:montanaflynn/stats +- Merge master + +### Mode +- Mode calculation fix and tests + +### Realized +- Realized the obvious efficiency gains of ignoring the unique numbers at the beginning of the slice. Benchmark joy ensued. + +### Refactor +- Refactor testing of Round() +- Refactor setting Coordinate y field using Exp in place of Pow +- Refactor Makefile and add docs target + +### Remove +- Remove deep links to types and functions + +### Rename +- Rename file from types to data + +### Retrieve +- Retrieve InterQuartileRange for the Float64Data. + +### Split +- Split up stats.go into separate files + +### Support +- Support more types on LoadRawData() ([#36](https://github.com/montanaflynn/stats/issues/36)) + +### Switch +- Switch default and check targets + +### Update +- Update Readme +- Update example methods and some text +- Update README and include Float64Data type method examples + +### Pull Requests +- Merge pull request [#32](https://github.com/montanaflynn/stats/issues/32) from a-robinson/percentile +- Merge pull request [#30](https://github.com/montanaflynn/stats/issues/30) from montanaflynn/fix-test +- Merge pull request [#29](https://github.com/montanaflynn/stats/issues/29) from edupsousa/master +- Merge pull request [#27](https://github.com/montanaflynn/stats/issues/27) from andrey-yantsen/fix-percentile-out-of-bounds +- Merge pull request [#25](https://github.com/montanaflynn/stats/issues/25) from kazhuravlev/patch-1 +- Merge pull request [#22](https://github.com/montanaflynn/stats/issues/22) from JanBerktold/time-duration +- Merge pull request [#24](https://github.com/montanaflynn/stats/issues/24) from alouche/master +- Merge pull request [#21](https://github.com/montanaflynn/stats/issues/21) from brydavis/master +- Merge pull request [#19](https://github.com/montanaflynn/stats/issues/19) from ginodeis/mode-bug +- Merge pull request [#17](https://github.com/montanaflynn/stats/issues/17) from Kunde21/master +- Merge pull request [#3](https://github.com/montanaflynn/stats/issues/3) from montanaflynn/master +- Merge pull request [#2](https://github.com/montanaflynn/stats/issues/2) from montanaflynn/master +- Merge pull request [#13](https://github.com/montanaflynn/stats/issues/13) from toashd/pearson +- Merge pull request [#12](https://github.com/montanaflynn/stats/issues/12) from alixaxel/MAD +- Merge pull request [#1](https://github.com/montanaflynn/stats/issues/1) from montanaflynn/master +- Merge pull request [#11](https://github.com/montanaflynn/stats/issues/11) from Kunde21/modeMemReduce +- Merge pull request [#10](https://github.com/montanaflynn/stats/issues/10) from Kunde21/ModeRewrite + + + +## [0.2.0] - 2015-10-14 +### Add +- Add Makefile with gometalinter, testing, benchmarking and coverage report targets +- Add comments describing functions and structs +- Add Correlation func +- Add Covariance func +- Add tests for new function shortcuts +- Add StandardDeviation function as a shortcut to StandardDeviationPopulation +- Add Float64Data and Series types + +### Change +- Change Sample to return a standard []float64 type + +### Fix +- Fix broken link to Makefile +- Fix broken link and simplify code coverage reporting command +- Fix go vet warning about printf type placeholder +- Fix failing codecov test coverage reporting +- Fix link to CHANGELOG.md + +### Fixed +- Fixed typographical error, changed accomdate to accommodate in README. + +### Include +- Include Variance and StandardDeviation shortcuts + +### Pass +- Pass gometalinter + +### Refactor +- Refactor Variance function to be the same as population variance + +### Release +- Release version 0.2.0 + +### Remove +- Remove unneeded do packages and update cover URL +- Remove sudo from pip install + +### Reorder +- Reorder functions and sections + +### Revert +- Revert to legacy containers to preserve go1.1 testing + +### Switch +- Switch from legacy to container-based CI infrastructure + +### Update +- Update contributing instructions and mention Makefile + +### Pull Requests +- Merge pull request [#5](https://github.com/montanaflynn/stats/issues/5) from orthographic-pedant/spell_check/accommodate + + + +## [0.1.0] - 2015-08-19 +### Add +- Add CONTRIBUTING.md + +### Rename +- Rename functions while preserving backwards compatibility + + + +## 0.0.9 - 2015-08-18 +### Add +- Add HarmonicMean func +- Add GeometricMean func +- Add .gitignore to avoid commiting test coverage report +- Add Outliers stuct and QuantileOutliers func +- Add Interquartile Range, Midhinge and Trimean examples +- Add Trimean +- Add Midhinge +- Add Inter Quartile Range +- Add a unit test to check for an empty slice error +- Add Quantiles struct and Quantile func +- Add more tests and fix a typo +- Add Golang 1.5 to build tests +- Add a standard MIT license file +- Add basic benchmarking +- Add regression models +- Add codecov token +- Add codecov +- Add check for slices with a single item +- Add coverage tests +- Add back previous Go versions to Travis CI +- Add Travis CI +- Add GoDoc badge +- Add Percentile and Float64ToInt functions +- Add another rounding test for whole numbers +- Add build status badge +- Add code coverage badge +- Add test for NaN, achieving 100% code coverage +- Add round function +- Add standard deviation function +- Add sum function + +### Add +- add tests for sample +- add sample + +### Added +- Added sample and population variance and deviation functions +- Added README + +### Adjust +- Adjust API ordering + +### Avoid +- Avoid unintended consequence of using sort + +### Better +- Better performing min/max +- Better description + +### Change +- Change package path to potentially fix a bug in earlier versions of Go + +### Clean +- Clean up README and add some more information +- Clean up test error + +### Consistent +- Consistent empty slice error messages +- Consistent var naming +- Consistent func declaration + +### Convert +- Convert ints to floats + +### Duplicate +- Duplicate packages for all versions + +### Export +- Export Coordinate struct fields + +### First +- First commit + +### Fix +- Fix copy pasta mistake testing the wrong function +- Fix error message +- Fix usage output and edit API doc section +- Fix testing edgecase where map was in wrong order +- Fix usage example +- Fix usage examples + +### Include +- Include the Nearest Rank method of calculating percentiles + +### More +- More commenting + +### Move +- Move GoDoc link to top + +### Redirect +- Redirect kills newer versions of Go + +### Refactor +- Refactor code and error checking + +### Remove +- Remove unnecassary typecasting in sum func +- Remove cover since it doesn't work for later versions of go +- Remove golint and gocoveralls + +### Rename +- Rename StandardDev to StdDev +- Rename StandardDev to StdDev + +### Return +- Return errors for all functions + +### Run +- Run go fmt to clean up formatting + +### Simplify +- Simplify min/max function + +### Start +- Start with minimal tests + +### Switch +- Switch wercker to travis and update todos + +### Table +- table testing style + +### Update +- Update README and move the example main.go into it's own file +- Update TODO list +- Update README +- Update usage examples and todos + +### Use +- Use codecov the recommended way +- Use correct string formatting types + +### Pull Requests +- Merge pull request [#4](https://github.com/montanaflynn/stats/issues/4) from saromanov/sample + + +[Unreleased]: https://github.com/montanaflynn/stats/compare/v0.7.1...HEAD +[v0.7.1]: https://github.com/montanaflynn/stats/compare/v0.7.0...v0.7.1 +[v0.7.0]: https://github.com/montanaflynn/stats/compare/v0.6.6...v0.7.0 +[v0.6.6]: https://github.com/montanaflynn/stats/compare/v0.6.5...v0.6.6 +[v0.6.5]: https://github.com/montanaflynn/stats/compare/v0.6.4...v0.6.5 +[v0.6.4]: https://github.com/montanaflynn/stats/compare/v0.6.3...v0.6.4 +[v0.6.3]: https://github.com/montanaflynn/stats/compare/v0.6.2...v0.6.3 +[v0.6.2]: https://github.com/montanaflynn/stats/compare/v0.6.1...v0.6.2 +[v0.6.1]: https://github.com/montanaflynn/stats/compare/v0.6.0...v0.6.1 +[v0.6.0]: https://github.com/montanaflynn/stats/compare/v0.5.0...v0.6.0 +[v0.5.0]: https://github.com/montanaflynn/stats/compare/v0.4.0...v0.5.0 +[v0.4.0]: https://github.com/montanaflynn/stats/compare/0.3.0...v0.4.0 +[0.3.0]: https://github.com/montanaflynn/stats/compare/0.2.0...0.3.0 +[0.2.0]: https://github.com/montanaflynn/stats/compare/0.1.0...0.2.0 +[0.1.0]: https://github.com/montanaflynn/stats/compare/0.0.9...0.1.0 diff --git a/vendor/github.com/montanaflynn/stats/DOCUMENTATION.md b/vendor/github.com/montanaflynn/stats/DOCUMENTATION.md new file mode 100644 index 0000000..978df2f --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/DOCUMENTATION.md @@ -0,0 +1,1271 @@ + + +# stats +`import "github.com/montanaflynn/stats"` + +* [Overview](#pkg-overview) +* [Index](#pkg-index) +* [Examples](#pkg-examples) +* [Subdirectories](#pkg-subdirectories) + +## Overview +Package stats is a well tested and comprehensive +statistics library package with no dependencies. + +Example Usage: + + + // start with some source data to use + data := []float64{1.0, 2.1, 3.2, 4.823, 4.1, 5.8} + + // you could also use different types like this + // data := stats.LoadRawData([]int{1, 2, 3, 4, 5}) + // data := stats.LoadRawData([]interface{}{1.1, "2", 3}) + // etc... + + median, _ := stats.Median(data) + fmt.Println(median) // 3.65 + + roundedMedian, _ := stats.Round(median, 0) + fmt.Println(roundedMedian) // 4 + +MIT License Copyright (c) 2014-2020 Montana Flynn (https://montanaflynn.com) + + + + +## Index +* [Variables](#pkg-variables) +* [func AutoCorrelation(data Float64Data, lags int) (float64, error)](#AutoCorrelation) +* [func ChebyshevDistance(dataPointX, dataPointY Float64Data) (distance float64, err error)](#ChebyshevDistance) +* [func Correlation(data1, data2 Float64Data) (float64, error)](#Correlation) +* [func Covariance(data1, data2 Float64Data) (float64, error)](#Covariance) +* [func CovariancePopulation(data1, data2 Float64Data) (float64, error)](#CovariancePopulation) +* [func CumulativeSum(input Float64Data) ([]float64, error)](#CumulativeSum) +* [func Entropy(input Float64Data) (float64, error)](#Entropy) +* [func EuclideanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error)](#EuclideanDistance) +* [func ExpGeom(p float64) (exp float64, err error)](#ExpGeom) +* [func GeometricMean(input Float64Data) (float64, error)](#GeometricMean) +* [func HarmonicMean(input Float64Data) (float64, error)](#HarmonicMean) +* [func InterQuartileRange(input Float64Data) (float64, error)](#InterQuartileRange) +* [func ManhattanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error)](#ManhattanDistance) +* [func Max(input Float64Data) (max float64, err error)](#Max) +* [func Mean(input Float64Data) (float64, error)](#Mean) +* [func Median(input Float64Data) (median float64, err error)](#Median) +* [func MedianAbsoluteDeviation(input Float64Data) (mad float64, err error)](#MedianAbsoluteDeviation) +* [func MedianAbsoluteDeviationPopulation(input Float64Data) (mad float64, err error)](#MedianAbsoluteDeviationPopulation) +* [func Midhinge(input Float64Data) (float64, error)](#Midhinge) +* [func Min(input Float64Data) (min float64, err error)](#Min) +* [func MinkowskiDistance(dataPointX, dataPointY Float64Data, lambda float64) (distance float64, err error)](#MinkowskiDistance) +* [func Mode(input Float64Data) (mode []float64, err error)](#Mode) +* [func Ncr(n, r int) int](#Ncr) +* [func NormBoxMullerRvs(loc float64, scale float64, size int) []float64](#NormBoxMullerRvs) +* [func NormCdf(x float64, loc float64, scale float64) float64](#NormCdf) +* [func NormEntropy(loc float64, scale float64) float64](#NormEntropy) +* [func NormFit(data []float64) [2]float64](#NormFit) +* [func NormInterval(alpha float64, loc float64, scale float64) [2]float64](#NormInterval) +* [func NormIsf(p float64, loc float64, scale float64) (x float64)](#NormIsf) +* [func NormLogCdf(x float64, loc float64, scale float64) float64](#NormLogCdf) +* [func NormLogPdf(x float64, loc float64, scale float64) float64](#NormLogPdf) +* [func NormLogSf(x float64, loc float64, scale float64) float64](#NormLogSf) +* [func NormMean(loc float64, scale float64) float64](#NormMean) +* [func NormMedian(loc float64, scale float64) float64](#NormMedian) +* [func NormMoment(n int, loc float64, scale float64) float64](#NormMoment) +* [func NormPdf(x float64, loc float64, scale float64) float64](#NormPdf) +* [func NormPpf(p float64, loc float64, scale float64) (x float64)](#NormPpf) +* [func NormPpfRvs(loc float64, scale float64, size int) []float64](#NormPpfRvs) +* [func NormSf(x float64, loc float64, scale float64) float64](#NormSf) +* [func NormStats(loc float64, scale float64, moments string) []float64](#NormStats) +* [func NormStd(loc float64, scale float64) float64](#NormStd) +* [func NormVar(loc float64, scale float64) float64](#NormVar) +* [func Pearson(data1, data2 Float64Data) (float64, error)](#Pearson) +* [func Percentile(input Float64Data, percent float64) (percentile float64, err error)](#Percentile) +* [func PercentileNearestRank(input Float64Data, percent float64) (percentile float64, err error)](#PercentileNearestRank) +* [func PopulationVariance(input Float64Data) (pvar float64, err error)](#PopulationVariance) +* [func ProbGeom(a int, b int, p float64) (prob float64, err error)](#ProbGeom) +* [func Round(input float64, places int) (rounded float64, err error)](#Round) +* [func Sample(input Float64Data, takenum int, replacement bool) ([]float64, error)](#Sample) +* [func SampleVariance(input Float64Data) (svar float64, err error)](#SampleVariance) +* [func Sigmoid(input Float64Data) ([]float64, error)](#Sigmoid) +* [func SoftMax(input Float64Data) ([]float64, error)](#SoftMax) +* [func StableSample(input Float64Data, takenum int) ([]float64, error)](#StableSample) +* [func StandardDeviation(input Float64Data) (sdev float64, err error)](#StandardDeviation) +* [func StandardDeviationPopulation(input Float64Data) (sdev float64, err error)](#StandardDeviationPopulation) +* [func StandardDeviationSample(input Float64Data) (sdev float64, err error)](#StandardDeviationSample) +* [func StdDevP(input Float64Data) (sdev float64, err error)](#StdDevP) +* [func StdDevS(input Float64Data) (sdev float64, err error)](#StdDevS) +* [func Sum(input Float64Data) (sum float64, err error)](#Sum) +* [func Trimean(input Float64Data) (float64, error)](#Trimean) +* [func VarGeom(p float64) (exp float64, err error)](#VarGeom) +* [func VarP(input Float64Data) (sdev float64, err error)](#VarP) +* [func VarS(input Float64Data) (sdev float64, err error)](#VarS) +* [func Variance(input Float64Data) (sdev float64, err error)](#Variance) +* [type Coordinate](#Coordinate) + * [func ExpReg(s []Coordinate) (regressions []Coordinate, err error)](#ExpReg) + * [func LinReg(s []Coordinate) (regressions []Coordinate, err error)](#LinReg) + * [func LogReg(s []Coordinate) (regressions []Coordinate, err error)](#LogReg) +* [type Float64Data](#Float64Data) + * [func LoadRawData(raw interface{}) (f Float64Data)](#LoadRawData) + * [func (f Float64Data) AutoCorrelation(lags int) (float64, error)](#Float64Data.AutoCorrelation) + * [func (f Float64Data) Correlation(d Float64Data) (float64, error)](#Float64Data.Correlation) + * [func (f Float64Data) Covariance(d Float64Data) (float64, error)](#Float64Data.Covariance) + * [func (f Float64Data) CovariancePopulation(d Float64Data) (float64, error)](#Float64Data.CovariancePopulation) + * [func (f Float64Data) CumulativeSum() ([]float64, error)](#Float64Data.CumulativeSum) + * [func (f Float64Data) Entropy() (float64, error)](#Float64Data.Entropy) + * [func (f Float64Data) GeometricMean() (float64, error)](#Float64Data.GeometricMean) + * [func (f Float64Data) Get(i int) float64](#Float64Data.Get) + * [func (f Float64Data) HarmonicMean() (float64, error)](#Float64Data.HarmonicMean) + * [func (f Float64Data) InterQuartileRange() (float64, error)](#Float64Data.InterQuartileRange) + * [func (f Float64Data) Len() int](#Float64Data.Len) + * [func (f Float64Data) Less(i, j int) bool](#Float64Data.Less) + * [func (f Float64Data) Max() (float64, error)](#Float64Data.Max) + * [func (f Float64Data) Mean() (float64, error)](#Float64Data.Mean) + * [func (f Float64Data) Median() (float64, error)](#Float64Data.Median) + * [func (f Float64Data) MedianAbsoluteDeviation() (float64, error)](#Float64Data.MedianAbsoluteDeviation) + * [func (f Float64Data) MedianAbsoluteDeviationPopulation() (float64, error)](#Float64Data.MedianAbsoluteDeviationPopulation) + * [func (f Float64Data) Midhinge(d Float64Data) (float64, error)](#Float64Data.Midhinge) + * [func (f Float64Data) Min() (float64, error)](#Float64Data.Min) + * [func (f Float64Data) Mode() ([]float64, error)](#Float64Data.Mode) + * [func (f Float64Data) Pearson(d Float64Data) (float64, error)](#Float64Data.Pearson) + * [func (f Float64Data) Percentile(p float64) (float64, error)](#Float64Data.Percentile) + * [func (f Float64Data) PercentileNearestRank(p float64) (float64, error)](#Float64Data.PercentileNearestRank) + * [func (f Float64Data) PopulationVariance() (float64, error)](#Float64Data.PopulationVariance) + * [func (f Float64Data) Quartile(d Float64Data) (Quartiles, error)](#Float64Data.Quartile) + * [func (f Float64Data) QuartileOutliers() (Outliers, error)](#Float64Data.QuartileOutliers) + * [func (f Float64Data) Quartiles() (Quartiles, error)](#Float64Data.Quartiles) + * [func (f Float64Data) Sample(n int, r bool) ([]float64, error)](#Float64Data.Sample) + * [func (f Float64Data) SampleVariance() (float64, error)](#Float64Data.SampleVariance) + * [func (f Float64Data) Sigmoid() ([]float64, error)](#Float64Data.Sigmoid) + * [func (f Float64Data) SoftMax() ([]float64, error)](#Float64Data.SoftMax) + * [func (f Float64Data) StandardDeviation() (float64, error)](#Float64Data.StandardDeviation) + * [func (f Float64Data) StandardDeviationPopulation() (float64, error)](#Float64Data.StandardDeviationPopulation) + * [func (f Float64Data) StandardDeviationSample() (float64, error)](#Float64Data.StandardDeviationSample) + * [func (f Float64Data) Sum() (float64, error)](#Float64Data.Sum) + * [func (f Float64Data) Swap(i, j int)](#Float64Data.Swap) + * [func (f Float64Data) Trimean(d Float64Data) (float64, error)](#Float64Data.Trimean) + * [func (f Float64Data) Variance() (float64, error)](#Float64Data.Variance) +* [type Outliers](#Outliers) + * [func QuartileOutliers(input Float64Data) (Outliers, error)](#QuartileOutliers) +* [type Quartiles](#Quartiles) + * [func Quartile(input Float64Data) (Quartiles, error)](#Quartile) +* [type Series](#Series) + * [func ExponentialRegression(s Series) (regressions Series, err error)](#ExponentialRegression) + * [func LinearRegression(s Series) (regressions Series, err error)](#LinearRegression) + * [func LogarithmicRegression(s Series) (regressions Series, err error)](#LogarithmicRegression) + +#### Examples +* [AutoCorrelation](#example_AutoCorrelation) +* [ChebyshevDistance](#example_ChebyshevDistance) +* [Correlation](#example_Correlation) +* [CumulativeSum](#example_CumulativeSum) +* [Entropy](#example_Entropy) +* [ExpGeom](#example_ExpGeom) +* [LinearRegression](#example_LinearRegression) +* [LoadRawData](#example_LoadRawData) +* [Max](#example_Max) +* [Median](#example_Median) +* [Min](#example_Min) +* [ProbGeom](#example_ProbGeom) +* [Round](#example_Round) +* [Sigmoid](#example_Sigmoid) +* [SoftMax](#example_SoftMax) +* [Sum](#example_Sum) +* [VarGeom](#example_VarGeom) + +#### Package files +[correlation.go](/src/github.com/montanaflynn/stats/correlation.go) [cumulative_sum.go](/src/github.com/montanaflynn/stats/cumulative_sum.go) [data.go](/src/github.com/montanaflynn/stats/data.go) [deviation.go](/src/github.com/montanaflynn/stats/deviation.go) [distances.go](/src/github.com/montanaflynn/stats/distances.go) [doc.go](/src/github.com/montanaflynn/stats/doc.go) [entropy.go](/src/github.com/montanaflynn/stats/entropy.go) [errors.go](/src/github.com/montanaflynn/stats/errors.go) [geometric_distribution.go](/src/github.com/montanaflynn/stats/geometric_distribution.go) [legacy.go](/src/github.com/montanaflynn/stats/legacy.go) [load.go](/src/github.com/montanaflynn/stats/load.go) [max.go](/src/github.com/montanaflynn/stats/max.go) [mean.go](/src/github.com/montanaflynn/stats/mean.go) [median.go](/src/github.com/montanaflynn/stats/median.go) [min.go](/src/github.com/montanaflynn/stats/min.go) [mode.go](/src/github.com/montanaflynn/stats/mode.go) [norm.go](/src/github.com/montanaflynn/stats/norm.go) [outlier.go](/src/github.com/montanaflynn/stats/outlier.go) [percentile.go](/src/github.com/montanaflynn/stats/percentile.go) [quartile.go](/src/github.com/montanaflynn/stats/quartile.go) [ranksum.go](/src/github.com/montanaflynn/stats/ranksum.go) [regression.go](/src/github.com/montanaflynn/stats/regression.go) [round.go](/src/github.com/montanaflynn/stats/round.go) [sample.go](/src/github.com/montanaflynn/stats/sample.go) [sigmoid.go](/src/github.com/montanaflynn/stats/sigmoid.go) [softmax.go](/src/github.com/montanaflynn/stats/softmax.go) [sum.go](/src/github.com/montanaflynn/stats/sum.go) [util.go](/src/github.com/montanaflynn/stats/util.go) [variance.go](/src/github.com/montanaflynn/stats/variance.go) + + + +## Variables +``` go +var ( + // ErrEmptyInput Input must not be empty + ErrEmptyInput = statsError{"Input must not be empty."} + // ErrNaN Not a number + ErrNaN = statsError{"Not a number."} + // ErrNegative Must not contain negative values + ErrNegative = statsError{"Must not contain negative values."} + // ErrZero Must not contain zero values + ErrZero = statsError{"Must not contain zero values."} + // ErrBounds Input is outside of range + ErrBounds = statsError{"Input is outside of range."} + // ErrSize Must be the same length + ErrSize = statsError{"Must be the same length."} + // ErrInfValue Value is infinite + ErrInfValue = statsError{"Value is infinite."} + // ErrYCoord Y Value must be greater than zero + ErrYCoord = statsError{"Y Value must be greater than zero."} +) +``` +These are the package-wide error values. +All error identification should use these values. +https://github.com/golang/go/wiki/Errors#naming + +``` go +var ( + EmptyInputErr = ErrEmptyInput + NaNErr = ErrNaN + NegativeErr = ErrNegative + ZeroErr = ErrZero + BoundsErr = ErrBounds + SizeErr = ErrSize + InfValue = ErrInfValue + YCoordErr = ErrYCoord + EmptyInput = ErrEmptyInput +) +``` +Legacy error names that didn't start with Err + + + +## func [AutoCorrelation](/correlation.go?s=853:918#L38) +``` go +func AutoCorrelation(data Float64Data, lags int) (float64, error) +``` +AutoCorrelation is the correlation of a signal with a delayed copy of itself as a function of delay + + + +## func [ChebyshevDistance](/distances.go?s=368:456#L20) +``` go +func ChebyshevDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) +``` +ChebyshevDistance computes the Chebyshev distance between two data sets + + + +## func [Correlation](/correlation.go?s=112:171#L8) +``` go +func Correlation(data1, data2 Float64Data) (float64, error) +``` +Correlation describes the degree of relationship between two sets of data + + + +## func [Covariance](/variance.go?s=1284:1342#L53) +``` go +func Covariance(data1, data2 Float64Data) (float64, error) +``` +Covariance is a measure of how much two sets of data change + + + +## func [CovariancePopulation](/variance.go?s=1864:1932#L81) +``` go +func CovariancePopulation(data1, data2 Float64Data) (float64, error) +``` +CovariancePopulation computes covariance for entire population between two variables. + + + +## func [CumulativeSum](/cumulative_sum.go?s=81:137#L4) +``` go +func CumulativeSum(input Float64Data) ([]float64, error) +``` +CumulativeSum calculates the cumulative sum of the input slice + + + +## func [Entropy](/entropy.go?s=77:125#L6) +``` go +func Entropy(input Float64Data) (float64, error) +``` +Entropy provides calculation of the entropy + + + +## func [EuclideanDistance](/distances.go?s=836:924#L36) +``` go +func EuclideanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) +``` +EuclideanDistance computes the Euclidean distance between two data sets + + + +## func [ExpGeom](/geometric_distribution.go?s=652:700#L27) +``` go +func ExpGeom(p float64) (exp float64, err error) +``` +ProbGeom generates the expectation or average number of trials +for a geometric random variable with parameter p + + + +## func [GeometricMean](/mean.go?s=319:373#L18) +``` go +func GeometricMean(input Float64Data) (float64, error) +``` +GeometricMean gets the geometric mean for a slice of numbers + + + +## func [HarmonicMean](/mean.go?s=717:770#L40) +``` go +func HarmonicMean(input Float64Data) (float64, error) +``` +HarmonicMean gets the harmonic mean for a slice of numbers + + + +## func [InterQuartileRange](/quartile.go?s=821:880#L45) +``` go +func InterQuartileRange(input Float64Data) (float64, error) +``` +InterQuartileRange finds the range between Q1 and Q3 + + + +## func [ManhattanDistance](/distances.go?s=1277:1365#L50) +``` go +func ManhattanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) +``` +ManhattanDistance computes the Manhattan distance between two data sets + + + +## func [Max](/max.go?s=78:130#L8) +``` go +func Max(input Float64Data) (max float64, err error) +``` +Max finds the highest number in a slice + + + +## func [Mean](/mean.go?s=77:122#L6) +``` go +func Mean(input Float64Data) (float64, error) +``` +Mean gets the average of a slice of numbers + + + +## func [Median](/median.go?s=85:143#L6) +``` go +func Median(input Float64Data) (median float64, err error) +``` +Median gets the median number in a slice of numbers + + + +## func [MedianAbsoluteDeviation](/deviation.go?s=125:197#L6) +``` go +func MedianAbsoluteDeviation(input Float64Data) (mad float64, err error) +``` +MedianAbsoluteDeviation finds the median of the absolute deviations from the dataset median + + + +## func [MedianAbsoluteDeviationPopulation](/deviation.go?s=360:442#L11) +``` go +func MedianAbsoluteDeviationPopulation(input Float64Data) (mad float64, err error) +``` +MedianAbsoluteDeviationPopulation finds the median of the absolute deviations from the population median + + + +## func [Midhinge](/quartile.go?s=1075:1124#L55) +``` go +func Midhinge(input Float64Data) (float64, error) +``` +Midhinge finds the average of the first and third quartiles + + + +## func [Min](/min.go?s=78:130#L6) +``` go +func Min(input Float64Data) (min float64, err error) +``` +Min finds the lowest number in a set of data + + + +## func [MinkowskiDistance](/distances.go?s=2152:2256#L75) +``` go +func MinkowskiDistance(dataPointX, dataPointY Float64Data, lambda float64) (distance float64, err error) +``` +MinkowskiDistance computes the Minkowski distance between two data sets + +Arguments: + + + dataPointX: First set of data points + dataPointY: Second set of data points. Length of both data + sets must be equal. + lambda: aka p or city blocks; With lambda = 1 + returned distance is manhattan distance and + lambda = 2; it is euclidean distance. Lambda + reaching to infinite - distance would be chebysev + distance. + +Return: + + + Distance or error + + + +## func [Mode](/mode.go?s=85:141#L4) +``` go +func Mode(input Float64Data) (mode []float64, err error) +``` +Mode gets the mode [most frequent value(s)] of a slice of float64s + + + +## func [Ncr](/norm.go?s=7384:7406#L239) +``` go +func Ncr(n, r int) int +``` +Ncr is an N choose R algorithm. +Aaron Cannon's algorithm. + + + +## func [NormBoxMullerRvs](/norm.go?s=667:736#L23) +``` go +func NormBoxMullerRvs(loc float64, scale float64, size int) []float64 +``` +NormBoxMullerRvs generates random variates using the Box–Muller transform. +For more information please visit: http://mathworld.wolfram.com/Box-MullerTransformation.html + + + +## func [NormCdf](/norm.go?s=1826:1885#L52) +``` go +func NormCdf(x float64, loc float64, scale float64) float64 +``` +NormCdf is the cumulative distribution function. + + + +## func [NormEntropy](/norm.go?s=5773:5825#L180) +``` go +func NormEntropy(loc float64, scale float64) float64 +``` +NormEntropy is the differential entropy of the RV. + + + +## func [NormFit](/norm.go?s=6058:6097#L187) +``` go +func NormFit(data []float64) [2]float64 +``` +NormFit returns the maximum likelihood estimators for the Normal Distribution. +Takes array of float64 values. +Returns array of Mean followed by Standard Deviation. + + + +## func [NormInterval](/norm.go?s=6976:7047#L221) +``` go +func NormInterval(alpha float64, loc float64, scale float64) [2]float64 +``` +NormInterval finds endpoints of the range that contains alpha percent of the distribution. + + + +## func [NormIsf](/norm.go?s=4330:4393#L137) +``` go +func NormIsf(p float64, loc float64, scale float64) (x float64) +``` +NormIsf is the inverse survival function (inverse of sf). + + + +## func [NormLogCdf](/norm.go?s=2016:2078#L57) +``` go +func NormLogCdf(x float64, loc float64, scale float64) float64 +``` +NormLogCdf is the log of the cumulative distribution function. + + + +## func [NormLogPdf](/norm.go?s=1590:1652#L47) +``` go +func NormLogPdf(x float64, loc float64, scale float64) float64 +``` +NormLogPdf is the log of the probability density function. + + + +## func [NormLogSf](/norm.go?s=2423:2484#L67) +``` go +func NormLogSf(x float64, loc float64, scale float64) float64 +``` +NormLogSf is the log of the survival function. + + + +## func [NormMean](/norm.go?s=6560:6609#L206) +``` go +func NormMean(loc float64, scale float64) float64 +``` +NormMean is the mean/expected value of the distribution. + + + +## func [NormMedian](/norm.go?s=6431:6482#L201) +``` go +func NormMedian(loc float64, scale float64) float64 +``` +NormMedian is the median of the distribution. + + + +## func [NormMoment](/norm.go?s=4694:4752#L146) +``` go +func NormMoment(n int, loc float64, scale float64) float64 +``` +NormMoment approximates the non-central (raw) moment of order n. +For more information please visit: https://math.stackexchange.com/questions/1945448/methods-for-finding-raw-moments-of-the-normal-distribution + + + +## func [NormPdf](/norm.go?s=1357:1416#L42) +``` go +func NormPdf(x float64, loc float64, scale float64) float64 +``` +NormPdf is the probability density function. + + + +## func [NormPpf](/norm.go?s=2854:2917#L75) +``` go +func NormPpf(p float64, loc float64, scale float64) (x float64) +``` +NormPpf is the point percentile function. +This is based on Peter John Acklam's inverse normal CDF. +algorithm: http://home.online.no/~pjacklam/notes/invnorm/ (no longer visible). +For more information please visit: https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/ + + + +## func [NormPpfRvs](/norm.go?s=247:310#L12) +``` go +func NormPpfRvs(loc float64, scale float64, size int) []float64 +``` +NormPpfRvs generates random variates using the Point Percentile Function. +For more information please visit: https://demonstrations.wolfram.com/TheMethodOfInverseTransforms/ + + + +## func [NormSf](/norm.go?s=2250:2308#L62) +``` go +func NormSf(x float64, loc float64, scale float64) float64 +``` +NormSf is the survival function (also defined as 1 - cdf, but sf is sometimes more accurate). + + + +## func [NormStats](/norm.go?s=5277:5345#L162) +``` go +func NormStats(loc float64, scale float64, moments string) []float64 +``` +NormStats returns the mean, variance, skew, and/or kurtosis. +Mean(‘m’), variance(‘v’), skew(‘s’), and/or kurtosis(‘k’). +Takes string containing any of 'mvsk'. +Returns array of m v s k in that order. + + + +## func [NormStd](/norm.go?s=6814:6862#L216) +``` go +func NormStd(loc float64, scale float64) float64 +``` +NormStd is the standard deviation of the distribution. + + + +## func [NormVar](/norm.go?s=6675:6723#L211) +``` go +func NormVar(loc float64, scale float64) float64 +``` +NormVar is the variance of the distribution. + + + +## func [Pearson](/correlation.go?s=655:710#L33) +``` go +func Pearson(data1, data2 Float64Data) (float64, error) +``` +Pearson calculates the Pearson product-moment correlation coefficient between two variables + + + +## func [Percentile](/percentile.go?s=98:181#L8) +``` go +func Percentile(input Float64Data, percent float64) (percentile float64, err error) +``` +Percentile finds the relative standing in a slice of floats + + + +## func [PercentileNearestRank](/percentile.go?s=1079:1173#L54) +``` go +func PercentileNearestRank(input Float64Data, percent float64) (percentile float64, err error) +``` +PercentileNearestRank finds the relative standing in a slice of floats using the Nearest Rank method + + + +## func [PopulationVariance](/variance.go?s=828:896#L31) +``` go +func PopulationVariance(input Float64Data) (pvar float64, err error) +``` +PopulationVariance finds the amount of variance within a population + + + +## func [ProbGeom](/geometric_distribution.go?s=258:322#L10) +``` go +func ProbGeom(a int, b int, p float64) (prob float64, err error) +``` +ProbGeom generates the probability for a geometric random variable +with parameter p to achieve success in the interval of [a, b] trials +See https://en.wikipedia.org/wiki/Geometric_distribution for more information + + + +## func [Round](/round.go?s=88:154#L6) +``` go +func Round(input float64, places int) (rounded float64, err error) +``` +Round a float to a specific decimal place or precision + + + +## func [Sample](/sample.go?s=112:192#L9) +``` go +func Sample(input Float64Data, takenum int, replacement bool) ([]float64, error) +``` +Sample returns sample from input with replacement or without + + + +## func [SampleVariance](/variance.go?s=1058:1122#L42) +``` go +func SampleVariance(input Float64Data) (svar float64, err error) +``` +SampleVariance finds the amount of variance within a sample + + + +## func [Sigmoid](/sigmoid.go?s=228:278#L9) +``` go +func Sigmoid(input Float64Data) ([]float64, error) +``` +Sigmoid returns the input values in the range of -1 to 1 +along the sigmoid or s-shaped curve, commonly used in +machine learning while training neural networks as an +activation function. + + + +## func [SoftMax](/softmax.go?s=206:256#L8) +``` go +func SoftMax(input Float64Data) ([]float64, error) +``` +SoftMax returns the input values in the range of 0 to 1 +with sum of all the probabilities being equal to one. It +is commonly used in machine learning neural networks. + + + +## func [StableSample](/sample.go?s=974:1042#L50) +``` go +func StableSample(input Float64Data, takenum int) ([]float64, error) +``` +StableSample like stable sort, it returns samples from input while keeps the order of original data. + + + +## func [StandardDeviation](/deviation.go?s=695:762#L27) +``` go +func StandardDeviation(input Float64Data) (sdev float64, err error) +``` +StandardDeviation the amount of variation in the dataset + + + +## func [StandardDeviationPopulation](/deviation.go?s=892:969#L32) +``` go +func StandardDeviationPopulation(input Float64Data) (sdev float64, err error) +``` +StandardDeviationPopulation finds the amount of variation from the population + + + +## func [StandardDeviationSample](/deviation.go?s=1250:1323#L46) +``` go +func StandardDeviationSample(input Float64Data) (sdev float64, err error) +``` +StandardDeviationSample finds the amount of variation from a sample + + + +## func [StdDevP](/legacy.go?s=339:396#L14) +``` go +func StdDevP(input Float64Data) (sdev float64, err error) +``` +StdDevP is a shortcut to StandardDeviationPopulation + + + +## func [StdDevS](/legacy.go?s=497:554#L19) +``` go +func StdDevS(input Float64Data) (sdev float64, err error) +``` +StdDevS is a shortcut to StandardDeviationSample + + + +## func [Sum](/sum.go?s=78:130#L6) +``` go +func Sum(input Float64Data) (sum float64, err error) +``` +Sum adds all the numbers of a slice together + + + +## func [Trimean](/quartile.go?s=1320:1368#L65) +``` go +func Trimean(input Float64Data) (float64, error) +``` +Trimean finds the average of the median and the midhinge + + + +## func [VarGeom](/geometric_distribution.go?s=885:933#L37) +``` go +func VarGeom(p float64) (exp float64, err error) +``` +ProbGeom generates the variance for number for a +geometric random variable with parameter p + + + +## func [VarP](/legacy.go?s=59:113#L4) +``` go +func VarP(input Float64Data) (sdev float64, err error) +``` +VarP is a shortcut to PopulationVariance + + + +## func [VarS](/legacy.go?s=193:247#L9) +``` go +func VarS(input Float64Data) (sdev float64, err error) +``` +VarS is a shortcut to SampleVariance + + + +## func [Variance](/variance.go?s=659:717#L26) +``` go +func Variance(input Float64Data) (sdev float64, err error) +``` +Variance the amount of variation in the dataset + + + + +## type [Coordinate](/regression.go?s=143:183#L9) +``` go +type Coordinate struct { + X, Y float64 +} + +``` +Coordinate holds the data in a series + + + + + + + +### func [ExpReg](/legacy.go?s=791:856#L29) +``` go +func ExpReg(s []Coordinate) (regressions []Coordinate, err error) +``` +ExpReg is a shortcut to ExponentialRegression + + +### func [LinReg](/legacy.go?s=643:708#L24) +``` go +func LinReg(s []Coordinate) (regressions []Coordinate, err error) +``` +LinReg is a shortcut to LinearRegression + + +### func [LogReg](/legacy.go?s=944:1009#L34) +``` go +func LogReg(s []Coordinate) (regressions []Coordinate, err error) +``` +LogReg is a shortcut to LogarithmicRegression + + + + + +## type [Float64Data](/data.go?s=80:106#L4) +``` go +type Float64Data []float64 +``` +Float64Data is a named type for []float64 with helper methods + + + + + + + +### func [LoadRawData](/load.go?s=145:194#L12) +``` go +func LoadRawData(raw interface{}) (f Float64Data) +``` +LoadRawData parses and converts a slice of mixed data types to floats + + + + + +### func (Float64Data) [AutoCorrelation](/data.go?s=3257:3320#L91) +``` go +func (f Float64Data) AutoCorrelation(lags int) (float64, error) +``` +AutoCorrelation is the correlation of a signal with a delayed copy of itself as a function of delay + + + + +### func (Float64Data) [Correlation](/data.go?s=3058:3122#L86) +``` go +func (f Float64Data) Correlation(d Float64Data) (float64, error) +``` +Correlation describes the degree of relationship between two sets of data + + + + +### func (Float64Data) [Covariance](/data.go?s=4801:4864#L141) +``` go +func (f Float64Data) Covariance(d Float64Data) (float64, error) +``` +Covariance is a measure of how much two sets of data change + + + + +### func (Float64Data) [CovariancePopulation](/data.go?s=4983:5056#L146) +``` go +func (f Float64Data) CovariancePopulation(d Float64Data) (float64, error) +``` +CovariancePopulation computes covariance for entire population between two variables + + + + +### func (Float64Data) [CumulativeSum](/data.go?s=883:938#L28) +``` go +func (f Float64Data) CumulativeSum() ([]float64, error) +``` +CumulativeSum returns the cumulative sum of the data + + + + +### func (Float64Data) [Entropy](/data.go?s=5480:5527#L162) +``` go +func (f Float64Data) Entropy() (float64, error) +``` +Entropy provides calculation of the entropy + + + + +### func (Float64Data) [GeometricMean](/data.go?s=1332:1385#L40) +``` go +func (f Float64Data) GeometricMean() (float64, error) +``` +GeometricMean returns the median of the data + + + + +### func (Float64Data) [Get](/data.go?s=129:168#L7) +``` go +func (f Float64Data) Get(i int) float64 +``` +Get item in slice + + + + +### func (Float64Data) [HarmonicMean](/data.go?s=1460:1512#L43) +``` go +func (f Float64Data) HarmonicMean() (float64, error) +``` +HarmonicMean returns the mode of the data + + + + +### func (Float64Data) [InterQuartileRange](/data.go?s=3755:3813#L106) +``` go +func (f Float64Data) InterQuartileRange() (float64, error) +``` +InterQuartileRange finds the range between Q1 and Q3 + + + + +### func (Float64Data) [Len](/data.go?s=217:247#L10) +``` go +func (f Float64Data) Len() int +``` +Len returns length of slice + + + + +### func (Float64Data) [Less](/data.go?s=318:358#L13) +``` go +func (f Float64Data) Less(i, j int) bool +``` +Less returns if one number is less than another + + + + +### func (Float64Data) [Max](/data.go?s=645:688#L22) +``` go +func (f Float64Data) Max() (float64, error) +``` +Max returns the maximum number in the data + + + + +### func (Float64Data) [Mean](/data.go?s=1005:1049#L31) +``` go +func (f Float64Data) Mean() (float64, error) +``` +Mean returns the mean of the data + + + + +### func (Float64Data) [Median](/data.go?s=1111:1157#L34) +``` go +func (f Float64Data) Median() (float64, error) +``` +Median returns the median of the data + + + + +### func (Float64Data) [MedianAbsoluteDeviation](/data.go?s=1630:1693#L46) +``` go +func (f Float64Data) MedianAbsoluteDeviation() (float64, error) +``` +MedianAbsoluteDeviation the median of the absolute deviations from the dataset median + + + + +### func (Float64Data) [MedianAbsoluteDeviationPopulation](/data.go?s=1842:1915#L51) +``` go +func (f Float64Data) MedianAbsoluteDeviationPopulation() (float64, error) +``` +MedianAbsoluteDeviationPopulation finds the median of the absolute deviations from the population median + + + + +### func (Float64Data) [Midhinge](/data.go?s=3912:3973#L111) +``` go +func (f Float64Data) Midhinge(d Float64Data) (float64, error) +``` +Midhinge finds the average of the first and third quartiles + + + + +### func (Float64Data) [Min](/data.go?s=536:579#L19) +``` go +func (f Float64Data) Min() (float64, error) +``` +Min returns the minimum number in the data + + + + +### func (Float64Data) [Mode](/data.go?s=1217:1263#L37) +``` go +func (f Float64Data) Mode() ([]float64, error) +``` +Mode returns the mode of the data + + + + +### func (Float64Data) [Pearson](/data.go?s=3455:3515#L96) +``` go +func (f Float64Data) Pearson(d Float64Data) (float64, error) +``` +Pearson calculates the Pearson product-moment correlation coefficient between two variables. + + + + +### func (Float64Data) [Percentile](/data.go?s=2696:2755#L76) +``` go +func (f Float64Data) Percentile(p float64) (float64, error) +``` +Percentile finds the relative standing in a slice of floats + + + + +### func (Float64Data) [PercentileNearestRank](/data.go?s=2869:2939#L81) +``` go +func (f Float64Data) PercentileNearestRank(p float64) (float64, error) +``` +PercentileNearestRank finds the relative standing using the Nearest Rank method + + + + +### func (Float64Data) [PopulationVariance](/data.go?s=4495:4553#L131) +``` go +func (f Float64Data) PopulationVariance() (float64, error) +``` +PopulationVariance finds the amount of variance within a population + + + + +### func (Float64Data) [Quartile](/data.go?s=3610:3673#L101) +``` go +func (f Float64Data) Quartile(d Float64Data) (Quartiles, error) +``` +Quartile returns the three quartile points from a slice of data + + + + +### func (Float64Data) [QuartileOutliers](/data.go?s=2542:2599#L71) +``` go +func (f Float64Data) QuartileOutliers() (Outliers, error) +``` +QuartileOutliers finds the mild and extreme outliers + + + + +### func (Float64Data) [Quartiles](/data.go?s=5628:5679#L167) +``` go +func (f Float64Data) Quartiles() (Quartiles, error) +``` +Quartiles returns the three quartile points from instance of Float64Data + + + + +### func (Float64Data) [Sample](/data.go?s=4208:4269#L121) +``` go +func (f Float64Data) Sample(n int, r bool) ([]float64, error) +``` +Sample returns sample from input with replacement or without + + + + +### func (Float64Data) [SampleVariance](/data.go?s=4652:4706#L136) +``` go +func (f Float64Data) SampleVariance() (float64, error) +``` +SampleVariance finds the amount of variance within a sample + + + + +### func (Float64Data) [Sigmoid](/data.go?s=5169:5218#L151) +``` go +func (f Float64Data) Sigmoid() ([]float64, error) +``` +Sigmoid returns the input values along the sigmoid or s-shaped curve + + + + +### func (Float64Data) [SoftMax](/data.go?s=5359:5408#L157) +``` go +func (f Float64Data) SoftMax() ([]float64, error) +``` +SoftMax returns the input values in the range of 0 to 1 +with sum of all the probabilities being equal to one. + + + + +### func (Float64Data) [StandardDeviation](/data.go?s=2026:2083#L56) +``` go +func (f Float64Data) StandardDeviation() (float64, error) +``` +StandardDeviation the amount of variation in the dataset + + + + +### func (Float64Data) [StandardDeviationPopulation](/data.go?s=2199:2266#L61) +``` go +func (f Float64Data) StandardDeviationPopulation() (float64, error) +``` +StandardDeviationPopulation finds the amount of variation from the population + + + + +### func (Float64Data) [StandardDeviationSample](/data.go?s=2382:2445#L66) +``` go +func (f Float64Data) StandardDeviationSample() (float64, error) +``` +StandardDeviationSample finds the amount of variation from a sample + + + + +### func (Float64Data) [Sum](/data.go?s=764:807#L25) +``` go +func (f Float64Data) Sum() (float64, error) +``` +Sum returns the total of all the numbers in the data + + + + +### func (Float64Data) [Swap](/data.go?s=425:460#L16) +``` go +func (f Float64Data) Swap(i, j int) +``` +Swap switches out two numbers in slice + + + + +### func (Float64Data) [Trimean](/data.go?s=4059:4119#L116) +``` go +func (f Float64Data) Trimean(d Float64Data) (float64, error) +``` +Trimean finds the average of the median and the midhinge + + + + +### func (Float64Data) [Variance](/data.go?s=4350:4398#L126) +``` go +func (f Float64Data) Variance() (float64, error) +``` +Variance the amount of variation in the dataset + + + + +## type [Outliers](/outlier.go?s=73:139#L4) +``` go +type Outliers struct { + Mild Float64Data + Extreme Float64Data +} + +``` +Outliers holds mild and extreme outliers found in data + + + + + + + +### func [QuartileOutliers](/outlier.go?s=197:255#L10) +``` go +func QuartileOutliers(input Float64Data) (Outliers, error) +``` +QuartileOutliers finds the mild and extreme outliers + + + + + +## type [Quartiles](/quartile.go?s=75:136#L6) +``` go +type Quartiles struct { + Q1 float64 + Q2 float64 + Q3 float64 +} + +``` +Quartiles holds the three quartile points + + + + + + + +### func [Quartile](/quartile.go?s=205:256#L13) +``` go +func Quartile(input Float64Data) (Quartiles, error) +``` +Quartile returns the three quartile points from a slice of data + + + + + +## type [Series](/regression.go?s=76:100#L6) +``` go +type Series []Coordinate +``` +Series is a container for a series of data + + + + + + + +### func [ExponentialRegression](/regression.go?s=1089:1157#L50) +``` go +func ExponentialRegression(s Series) (regressions Series, err error) +``` +ExponentialRegression returns an exponential regression on data series + + +### func [LinearRegression](/regression.go?s=262:325#L14) +``` go +func LinearRegression(s Series) (regressions Series, err error) +``` +LinearRegression finds the least squares linear regression on data series + + +### func [LogarithmicRegression](/regression.go?s=1903:1971#L85) +``` go +func LogarithmicRegression(s Series) (regressions Series, err error) +``` +LogarithmicRegression returns an logarithmic regression on data series + + + + + + + + + +- - - +Generated by [godoc2md](http://godoc.org/github.com/davecheney/godoc2md) diff --git a/vendor/github.com/go-stack/stack/LICENSE.md b/vendor/github.com/montanaflynn/stats/LICENSE similarity index 94% rename from vendor/github.com/go-stack/stack/LICENSE.md rename to vendor/github.com/montanaflynn/stats/LICENSE index 2abf98e..3162cb1 100644 --- a/vendor/github.com/go-stack/stack/LICENSE.md +++ b/vendor/github.com/montanaflynn/stats/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2014 Chris Hines +Copyright (c) 2014-2023 Montana Flynn (https://montanaflynn.com) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/vendor/github.com/montanaflynn/stats/Makefile b/vendor/github.com/montanaflynn/stats/Makefile new file mode 100644 index 0000000..969df12 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/Makefile @@ -0,0 +1,34 @@ +.PHONY: all + +default: test lint + +format: + go fmt . + +test: + go test -race + +check: format test + +benchmark: + go test -bench=. -benchmem + +coverage: + go test -coverprofile=coverage.out + go tool cover -html="coverage.out" + +lint: format + golangci-lint run . + +docs: + godoc2md github.com/montanaflynn/stats | sed -e s#src/target/##g > DOCUMENTATION.md + +release: + git-chglog --output CHANGELOG.md --next-tag ${TAG} + git add CHANGELOG.md + git commit -m "Update changelog with ${TAG} changes" + git tag ${TAG} + git-chglog $(TAG) | tail -n +4 | gsed '1s/^/$(TAG)\n/gm' > release-notes.txt + git push origin master ${TAG} + hub release create --copy -F release-notes.txt ${TAG} + diff --git a/vendor/github.com/montanaflynn/stats/README.md b/vendor/github.com/montanaflynn/stats/README.md new file mode 100644 index 0000000..9c18890 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/README.md @@ -0,0 +1,237 @@ +# Stats - Golang Statistics Package + +[![][action-svg]][action-url] [![][codecov-svg]][codecov-url] [![][goreport-svg]][goreport-url] [![][godoc-svg]][godoc-url] [![][pkggodev-svg]][pkggodev-url] [![][license-svg]][license-url] + +A well tested and comprehensive Golang statistics library / package / module with no dependencies. + +If you have any suggestions, problems or bug reports please [create an issue](https://github.com/montanaflynn/stats/issues) and I'll do my best to accommodate you. In addition simply starring the repo would show your support for the project and be very much appreciated! + +## Installation + +``` +go get github.com/montanaflynn/stats +``` + +## Example Usage + +All the functions can be seen in [examples/main.go](examples/main.go) but here's a little taste: + +```go +// start with some source data to use +data := []float64{1.0, 2.1, 3.2, 4.823, 4.1, 5.8} + +// you could also use different types like this +// data := stats.LoadRawData([]int{1, 2, 3, 4, 5}) +// data := stats.LoadRawData([]interface{}{1.1, "2", 3}) +// etc... + +median, _ := stats.Median(data) +fmt.Println(median) // 3.65 + +roundedMedian, _ := stats.Round(median, 0) +fmt.Println(roundedMedian) // 4 +``` + +## Documentation + +The entire API documentation is available on [GoDoc.org](http://godoc.org/github.com/montanaflynn/stats) or [pkg.go.dev](https://pkg.go.dev/github.com/montanaflynn/stats). + +You can also view docs offline with the following commands: + +``` +# Command line +godoc . # show all exported apis +godoc . Median # show a single function +godoc -ex . Round # show function with example +godoc . Float64Data # show the type and methods + +# Local website +godoc -http=:4444 # start the godoc server on port 4444 +open http://localhost:4444/pkg/github.com/montanaflynn/stats/ +``` + +The exported API is as follows: + +```go +var ( + ErrEmptyInput = statsError{"Input must not be empty."} + ErrNaN = statsError{"Not a number."} + ErrNegative = statsError{"Must not contain negative values."} + ErrZero = statsError{"Must not contain zero values."} + ErrBounds = statsError{"Input is outside of range."} + ErrSize = statsError{"Must be the same length."} + ErrInfValue = statsError{"Value is infinite."} + ErrYCoord = statsError{"Y Value must be greater than zero."} +) + +func Round(input float64, places int) (rounded float64, err error) {} + +type Float64Data []float64 + +func LoadRawData(raw interface{}) (f Float64Data) {} + +func AutoCorrelation(data Float64Data, lags int) (float64, error) {} +func ChebyshevDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) {} +func Correlation(data1, data2 Float64Data) (float64, error) {} +func Covariance(data1, data2 Float64Data) (float64, error) {} +func CovariancePopulation(data1, data2 Float64Data) (float64, error) {} +func CumulativeSum(input Float64Data) ([]float64, error) {} +func Describe(input Float64Data, allowNaN bool, percentiles *[]float64) (*Description, error) {} +func DescribePercentileFunc(input Float64Data, allowNaN bool, percentiles *[]float64, percentileFunc func(Float64Data, float64) (float64, error)) (*Description, error) {} +func Entropy(input Float64Data) (float64, error) {} +func EuclideanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) {} +func GeometricMean(input Float64Data) (float64, error) {} +func HarmonicMean(input Float64Data) (float64, error) {} +func InterQuartileRange(input Float64Data) (float64, error) {} +func ManhattanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) {} +func Max(input Float64Data) (max float64, err error) {} +func Mean(input Float64Data) (float64, error) {} +func Median(input Float64Data) (median float64, err error) {} +func MedianAbsoluteDeviation(input Float64Data) (mad float64, err error) {} +func MedianAbsoluteDeviationPopulation(input Float64Data) (mad float64, err error) {} +func Midhinge(input Float64Data) (float64, error) {} +func Min(input Float64Data) (min float64, err error) {} +func MinkowskiDistance(dataPointX, dataPointY Float64Data, lambda float64) (distance float64, err error) {} +func Mode(input Float64Data) (mode []float64, err error) {} +func NormBoxMullerRvs(loc float64, scale float64, size int) []float64 {} +func NormCdf(x float64, loc float64, scale float64) float64 {} +func NormEntropy(loc float64, scale float64) float64 {} +func NormFit(data []float64) [2]float64{} +func NormInterval(alpha float64, loc float64, scale float64 ) [2]float64 {} +func NormIsf(p float64, loc float64, scale float64) (x float64) {} +func NormLogCdf(x float64, loc float64, scale float64) float64 {} +func NormLogPdf(x float64, loc float64, scale float64) float64 {} +func NormLogSf(x float64, loc float64, scale float64) float64 {} +func NormMean(loc float64, scale float64) float64 {} +func NormMedian(loc float64, scale float64) float64 {} +func NormMoment(n int, loc float64, scale float64) float64 {} +func NormPdf(x float64, loc float64, scale float64) float64 {} +func NormPpf(p float64, loc float64, scale float64) (x float64) {} +func NormPpfRvs(loc float64, scale float64, size int) []float64 {} +func NormSf(x float64, loc float64, scale float64) float64 {} +func NormStats(loc float64, scale float64, moments string) []float64 {} +func NormStd(loc float64, scale float64) float64 {} +func NormVar(loc float64, scale float64) float64 {} +func Pearson(data1, data2 Float64Data) (float64, error) {} +func Percentile(input Float64Data, percent float64) (percentile float64, err error) {} +func PercentileNearestRank(input Float64Data, percent float64) (percentile float64, err error) {} +func PopulationVariance(input Float64Data) (pvar float64, err error) {} +func Sample(input Float64Data, takenum int, replacement bool) ([]float64, error) {} +func SampleVariance(input Float64Data) (svar float64, err error) {} +func Sigmoid(input Float64Data) ([]float64, error) {} +func SoftMax(input Float64Data) ([]float64, error) {} +func StableSample(input Float64Data, takenum int) ([]float64, error) {} +func StandardDeviation(input Float64Data) (sdev float64, err error) {} +func StandardDeviationPopulation(input Float64Data) (sdev float64, err error) {} +func StandardDeviationSample(input Float64Data) (sdev float64, err error) {} +func StdDevP(input Float64Data) (sdev float64, err error) {} +func StdDevS(input Float64Data) (sdev float64, err error) {} +func Sum(input Float64Data) (sum float64, err error) {} +func Trimean(input Float64Data) (float64, error) {} +func VarP(input Float64Data) (sdev float64, err error) {} +func VarS(input Float64Data) (sdev float64, err error) {} +func Variance(input Float64Data) (sdev float64, err error) {} +func ProbGeom(a int, b int, p float64) (prob float64, err error) {} +func ExpGeom(p float64) (exp float64, err error) {} +func VarGeom(p float64) (exp float64, err error) {} + +type Coordinate struct { + X, Y float64 +} + +type Series []Coordinate + +func ExponentialRegression(s Series) (regressions Series, err error) {} +func LinearRegression(s Series) (regressions Series, err error) {} +func LogarithmicRegression(s Series) (regressions Series, err error) {} + +type Outliers struct { + Mild Float64Data + Extreme Float64Data +} + +type Quartiles struct { + Q1 float64 + Q2 float64 + Q3 float64 +} + +func Quartile(input Float64Data) (Quartiles, error) {} +func QuartileOutliers(input Float64Data) (Outliers, error) {} +``` + +## Contributing + +Pull request are always welcome no matter how big or small. I've included a [Makefile](https://github.com/montanaflynn/stats/blob/master/Makefile) that has a lot of helper targets for common actions such as linting, testing, code coverage reporting and more. + +1. Fork the repo and clone your fork +2. Create new branch (`git checkout -b some-thing`) +3. Make the desired changes +4. Ensure tests pass (`go test -cover` or `make test`) +5. Run lint and fix problems (`go vet .` or `make lint`) +6. Commit changes (`git commit -am 'Did something'`) +7. Push branch (`git push origin some-thing`) +8. Submit pull request + +To make things as seamless as possible please also consider the following steps: + +- Update `examples/main.go` with a simple example of the new feature +- Update `README.md` documentation section with any new exported API +- Keep 100% code coverage (you can check with `make coverage`) +- Squash commits into single units of work with `git rebase -i new-feature` + +## Releasing + +This is not required by contributors and mostly here as a reminder to myself as the maintainer of this repo. To release a new version we should update the [CHANGELOG.md](/CHANGELOG.md) and [DOCUMENTATION.md](/DOCUMENTATION.md). + +First install the tools used to generate the markdown files and release: + +``` +go install github.com/davecheney/godoc2md@latest +go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest +brew tap git-chglog/git-chglog +brew install gnu-sed hub git-chglog +``` + +Then you can run these `make` directives: + +``` +# Generate DOCUMENTATION.md +make docs +``` + +Then we can create a [CHANGELOG.md](/CHANGELOG.md) a new git tag and a github release: + +``` +make release TAG=v0.x.x +``` + +To authenticate `hub` for the release you will need to create a personal access token and use it as the password when it's requested. + +## MIT License + +Copyright (c) 2014-2023 Montana Flynn (https://montanaflynn.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORpublicS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +[action-url]: https://github.com/montanaflynn/stats/actions +[action-svg]: https://img.shields.io/github/actions/workflow/status/montanaflynn/stats/go.yml + +[codecov-url]: https://app.codecov.io/gh/montanaflynn/stats +[codecov-svg]: https://img.shields.io/codecov/c/github/montanaflynn/stats?token=wnw8dActnH + +[goreport-url]: https://goreportcard.com/report/github.com/montanaflynn/stats +[goreport-svg]: https://goreportcard.com/badge/github.com/montanaflynn/stats + +[godoc-url]: https://godoc.org/github.com/montanaflynn/stats +[godoc-svg]: https://godoc.org/github.com/montanaflynn/stats?status.svg + +[pkggodev-url]: https://pkg.go.dev/github.com/montanaflynn/stats +[pkggodev-svg]: https://gistcdn.githack.com/montanaflynn/b02f1d78d8c0de8435895d7e7cd0d473/raw/17f2a5a69f1323ecd42c00e0683655da96d9ecc8/badge.svg + +[license-url]: https://github.com/montanaflynn/stats/blob/master/LICENSE +[license-svg]: https://img.shields.io/badge/license-MIT-blue.svg diff --git a/vendor/github.com/montanaflynn/stats/correlation.go b/vendor/github.com/montanaflynn/stats/correlation.go new file mode 100644 index 0000000..4acab94 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/correlation.go @@ -0,0 +1,60 @@ +package stats + +import ( + "math" +) + +// Correlation describes the degree of relationship between two sets of data +func Correlation(data1, data2 Float64Data) (float64, error) { + + l1 := data1.Len() + l2 := data2.Len() + + if l1 == 0 || l2 == 0 { + return math.NaN(), EmptyInputErr + } + + if l1 != l2 { + return math.NaN(), SizeErr + } + + sdev1, _ := StandardDeviationPopulation(data1) + sdev2, _ := StandardDeviationPopulation(data2) + + if sdev1 == 0 || sdev2 == 0 { + return 0, nil + } + + covp, _ := CovariancePopulation(data1, data2) + return covp / (sdev1 * sdev2), nil +} + +// Pearson calculates the Pearson product-moment correlation coefficient between two variables +func Pearson(data1, data2 Float64Data) (float64, error) { + return Correlation(data1, data2) +} + +// AutoCorrelation is the correlation of a signal with a delayed copy of itself as a function of delay +func AutoCorrelation(data Float64Data, lags int) (float64, error) { + if len(data) < 1 { + return 0, EmptyInputErr + } + + mean, _ := Mean(data) + + var result, q float64 + + for i := 0; i < lags; i++ { + v := (data[0] - mean) * (data[0] - mean) + for i := 1; i < len(data); i++ { + delta0 := data[i-1] - mean + delta1 := data[i] - mean + q += (delta0*delta1 - q) / float64(i+1) + v += (delta1*delta1 - v) / float64(i+1) + } + + result = q / v + } + + return result, nil +} diff --git a/vendor/github.com/montanaflynn/stats/cumulative_sum.go b/vendor/github.com/montanaflynn/stats/cumulative_sum.go new file mode 100644 index 0000000..e5305da --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/cumulative_sum.go @@ -0,0 +1,21 @@ +package stats + +// CumulativeSum calculates the cumulative sum of the input slice +func CumulativeSum(input Float64Data) ([]float64, error) { + + if input.Len() == 0 { + return Float64Data{}, EmptyInput + } + + cumSum := make([]float64, input.Len()) + + for i, val := range input { + if i == 0 { + cumSum[i] = val + } else { + cumSum[i] = cumSum[i-1] + val + } + } + + return cumSum, nil +} diff --git a/vendor/github.com/montanaflynn/stats/data.go b/vendor/github.com/montanaflynn/stats/data.go new file mode 100644 index 0000000..b86f0d8 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/data.go @@ -0,0 +1,169 @@ +package stats + +// Float64Data is a named type for []float64 with helper methods +type Float64Data []float64 + +// Get item in slice +func (f Float64Data) Get(i int) float64 { return f[i] } + +// Len returns length of slice +func (f Float64Data) Len() int { return len(f) } + +// Less returns if one number is less than another +func (f Float64Data) Less(i, j int) bool { return f[i] < f[j] } + +// Swap switches out two numbers in slice +func (f Float64Data) Swap(i, j int) { f[i], f[j] = f[j], f[i] } + +// Min returns the minimum number in the data +func (f Float64Data) Min() (float64, error) { return Min(f) } + +// Max returns the maximum number in the data +func (f Float64Data) Max() (float64, error) { return Max(f) } + +// Sum returns the total of all the numbers in the data +func (f Float64Data) Sum() (float64, error) { return Sum(f) } + +// CumulativeSum returns the cumulative sum of the data +func (f Float64Data) CumulativeSum() ([]float64, error) { return CumulativeSum(f) } + +// Mean returns the mean of the data +func (f Float64Data) Mean() (float64, error) { return Mean(f) } + +// Median returns the median of the data +func (f Float64Data) Median() (float64, error) { return Median(f) } + +// Mode returns the mode of the data +func (f Float64Data) Mode() ([]float64, error) { return Mode(f) } + +// GeometricMean returns the median of the data +func (f Float64Data) GeometricMean() (float64, error) { return GeometricMean(f) } + +// HarmonicMean returns the mode of the data +func (f Float64Data) HarmonicMean() (float64, error) { return HarmonicMean(f) } + +// MedianAbsoluteDeviation the median of the absolute deviations from the dataset median +func (f Float64Data) MedianAbsoluteDeviation() (float64, error) { + return MedianAbsoluteDeviation(f) +} + +// MedianAbsoluteDeviationPopulation finds the median of the absolute deviations from the population median +func (f Float64Data) MedianAbsoluteDeviationPopulation() (float64, error) { + return MedianAbsoluteDeviationPopulation(f) +} + +// StandardDeviation the amount of variation in the dataset +func (f Float64Data) StandardDeviation() (float64, error) { + return StandardDeviation(f) +} + +// StandardDeviationPopulation finds the amount of variation from the population +func (f Float64Data) StandardDeviationPopulation() (float64, error) { + return StandardDeviationPopulation(f) +} + +// StandardDeviationSample finds the amount of variation from a sample +func (f Float64Data) StandardDeviationSample() (float64, error) { + return StandardDeviationSample(f) +} + +// QuartileOutliers finds the mild and extreme outliers +func (f Float64Data) QuartileOutliers() (Outliers, error) { + return QuartileOutliers(f) +} + +// Percentile finds the relative standing in a slice of floats +func (f Float64Data) Percentile(p float64) (float64, error) { + return Percentile(f, p) +} + +// PercentileNearestRank finds the relative standing using the Nearest Rank method +func (f Float64Data) PercentileNearestRank(p float64) (float64, error) { + return PercentileNearestRank(f, p) +} + +// Correlation describes the degree of relationship between two sets of data +func (f Float64Data) Correlation(d Float64Data) (float64, error) { + return Correlation(f, d) +} + +// AutoCorrelation is the correlation of a signal with a delayed copy of itself as a function of delay +func (f Float64Data) AutoCorrelation(lags int) (float64, error) { + return AutoCorrelation(f, lags) +} + +// Pearson calculates the Pearson product-moment correlation coefficient between two variables. +func (f Float64Data) Pearson(d Float64Data) (float64, error) { + return Pearson(f, d) +} + +// Quartile returns the three quartile points from a slice of data +func (f Float64Data) Quartile(d Float64Data) (Quartiles, error) { + return Quartile(d) +} + +// InterQuartileRange finds the range between Q1 and Q3 +func (f Float64Data) InterQuartileRange() (float64, error) { + return InterQuartileRange(f) +} + +// Midhinge finds the average of the first and third quartiles +func (f Float64Data) Midhinge(d Float64Data) (float64, error) { + return Midhinge(d) +} + +// Trimean finds the average of the median and the midhinge +func (f Float64Data) Trimean(d Float64Data) (float64, error) { + return Trimean(d) +} + +// Sample returns sample from input with replacement or without +func (f Float64Data) Sample(n int, r bool) ([]float64, error) { + return Sample(f, n, r) +} + +// Variance the amount of variation in the dataset +func (f Float64Data) Variance() (float64, error) { + return Variance(f) +} + +// PopulationVariance finds the amount of variance within a population +func (f Float64Data) PopulationVariance() (float64, error) { + return PopulationVariance(f) +} + +// SampleVariance finds the amount of variance within a sample +func (f Float64Data) SampleVariance() (float64, error) { + return SampleVariance(f) +} + +// Covariance is a measure of how much two sets of data change +func (f Float64Data) Covariance(d Float64Data) (float64, error) { + return Covariance(f, d) +} + +// CovariancePopulation computes covariance for entire population between two variables +func (f Float64Data) CovariancePopulation(d Float64Data) (float64, error) { + return CovariancePopulation(f, d) +} + +// Sigmoid returns the input values along the sigmoid or s-shaped curve +func (f Float64Data) Sigmoid() ([]float64, error) { + return Sigmoid(f) +} + +// SoftMax returns the input values in the range of 0 to 1 +// with sum of all the probabilities being equal to one. +func (f Float64Data) SoftMax() ([]float64, error) { + return SoftMax(f) +} + +// Entropy provides calculation of the entropy +func (f Float64Data) Entropy() (float64, error) { + return Entropy(f) +} + +// Quartiles returns the three quartile points from instance of Float64Data +func (f Float64Data) Quartiles() (Quartiles, error) { + return Quartile(f) +} diff --git a/vendor/github.com/montanaflynn/stats/describe.go b/vendor/github.com/montanaflynn/stats/describe.go new file mode 100644 index 0000000..86b7242 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/describe.go @@ -0,0 +1,81 @@ +package stats + +import "fmt" + +// Holds information about the dataset provided to Describe +type Description struct { + Count int + Mean float64 + Std float64 + Max float64 + Min float64 + DescriptionPercentiles []descriptionPercentile + AllowedNaN bool +} + +// Specifies percentiles to be computed +type descriptionPercentile struct { + Percentile float64 + Value float64 +} + +// Describe generates descriptive statistics about a provided dataset, similar to python's pandas.describe() +func Describe(input Float64Data, allowNaN bool, percentiles *[]float64) (*Description, error) { + return DescribePercentileFunc(input, allowNaN, percentiles, Percentile) +} + +// Describe generates descriptive statistics about a provided dataset, similar to python's pandas.describe() +// Takes in a function to use for percentile calculation +func DescribePercentileFunc(input Float64Data, allowNaN bool, percentiles *[]float64, percentileFunc func(Float64Data, float64) (float64, error)) (*Description, error) { + var description Description + description.AllowedNaN = allowNaN + description.Count = input.Len() + + if description.Count == 0 && !allowNaN { + return &description, ErrEmptyInput + } + + // Disregard error, since it cannot be thrown if Count is > 0 and allowNaN is false, else NaN is accepted + description.Std, _ = StandardDeviation(input) + description.Max, _ = Max(input) + description.Min, _ = Min(input) + description.Mean, _ = Mean(input) + + if percentiles != nil { + for _, percentile := range *percentiles { + if value, err := percentileFunc(input, percentile); err == nil || allowNaN { + description.DescriptionPercentiles = append(description.DescriptionPercentiles, descriptionPercentile{Percentile: percentile, Value: value}) + } + } + } + + return &description, nil +} + +/* +Represents the Description instance in a string format with specified number of decimals + + count 3 + mean 2.00 + std 0.82 + max 3.00 + min 1.00 + 25.00% NaN + 50.00% 1.50 + 75.00% 2.50 + NaN OK true +*/ +func (d *Description) String(decimals int) string { + var str string + + str += fmt.Sprintf("count\t%d\n", d.Count) + str += fmt.Sprintf("mean\t%.*f\n", decimals, d.Mean) + str += fmt.Sprintf("std\t%.*f\n", decimals, d.Std) + str += fmt.Sprintf("max\t%.*f\n", decimals, d.Max) + str += fmt.Sprintf("min\t%.*f\n", decimals, d.Min) + for _, percentile := range d.DescriptionPercentiles { + str += fmt.Sprintf("%.2f%%\t%.*f\n", percentile.Percentile, decimals, percentile.Value) + } + str += fmt.Sprintf("NaN OK\t%t", d.AllowedNaN) + return str +} diff --git a/vendor/github.com/montanaflynn/stats/deviation.go b/vendor/github.com/montanaflynn/stats/deviation.go new file mode 100644 index 0000000..e69a19f --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/deviation.go @@ -0,0 +1,57 @@ +package stats + +import "math" + +// MedianAbsoluteDeviation finds the median of the absolute deviations from the dataset median +func MedianAbsoluteDeviation(input Float64Data) (mad float64, err error) { + return MedianAbsoluteDeviationPopulation(input) +} + +// MedianAbsoluteDeviationPopulation finds the median of the absolute deviations from the population median +func MedianAbsoluteDeviationPopulation(input Float64Data) (mad float64, err error) { + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + i := copyslice(input) + m, _ := Median(i) + + for key, value := range i { + i[key] = math.Abs(value - m) + } + + return Median(i) +} + +// StandardDeviation the amount of variation in the dataset +func StandardDeviation(input Float64Data) (sdev float64, err error) { + return StandardDeviationPopulation(input) +} + +// StandardDeviationPopulation finds the amount of variation from the population +func StandardDeviationPopulation(input Float64Data) (sdev float64, err error) { + + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + // Get the population variance + vp, _ := PopulationVariance(input) + + // Return the population standard deviation + return math.Sqrt(vp), nil +} + +// StandardDeviationSample finds the amount of variation from a sample +func StandardDeviationSample(input Float64Data) (sdev float64, err error) { + + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + // Get the sample variance + vs, _ := SampleVariance(input) + + // Return the sample standard deviation + return math.Sqrt(vs), nil +} diff --git a/vendor/github.com/montanaflynn/stats/distances.go b/vendor/github.com/montanaflynn/stats/distances.go new file mode 100644 index 0000000..8a6330e --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/distances.go @@ -0,0 +1,91 @@ +package stats + +import ( + "math" +) + +// Validate data for distance calculation +func validateData(dataPointX, dataPointY Float64Data) error { + if len(dataPointX) == 0 || len(dataPointY) == 0 { + return EmptyInputErr + } + + if len(dataPointX) != len(dataPointY) { + return SizeErr + } + return nil +} + +// ChebyshevDistance computes the Chebyshev distance between two data sets +func ChebyshevDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) { + err = validateData(dataPointX, dataPointY) + if err != nil { + return math.NaN(), err + } + var tempDistance float64 + for i := 0; i < len(dataPointY); i++ { + tempDistance = math.Abs(dataPointX[i] - dataPointY[i]) + if distance < tempDistance { + distance = tempDistance + } + } + return distance, nil +} + +// EuclideanDistance computes the Euclidean distance between two data sets +func EuclideanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) { + + err = validateData(dataPointX, dataPointY) + if err != nil { + return math.NaN(), err + } + distance = 0 + for i := 0; i < len(dataPointX); i++ { + distance = distance + ((dataPointX[i] - dataPointY[i]) * (dataPointX[i] - dataPointY[i])) + } + return math.Sqrt(distance), nil +} + +// ManhattanDistance computes the Manhattan distance between two data sets +func ManhattanDistance(dataPointX, dataPointY Float64Data) (distance float64, err error) { + err = validateData(dataPointX, dataPointY) + if err != nil { + return math.NaN(), err + } + distance = 0 + for i := 0; i < len(dataPointX); i++ { + distance = distance + math.Abs(dataPointX[i]-dataPointY[i]) + } + return distance, nil +} + +// MinkowskiDistance computes the Minkowski distance between two data sets +// +// Arguments: +// +// dataPointX: First set of data points +// dataPointY: Second set of data points. Length of both data +// sets must be equal. +// lambda: aka p or city blocks; With lambda = 1 +// returned distance is manhattan distance and +// lambda = 2; it is euclidean distance. Lambda +// reaching to infinite - distance would be chebysev +// distance. +// +// Return: +// +// Distance or error +func MinkowskiDistance(dataPointX, dataPointY Float64Data, lambda float64) (distance float64, err error) { + err = validateData(dataPointX, dataPointY) + if err != nil { + return math.NaN(), err + } + for i := 0; i < len(dataPointY); i++ { + distance = distance + math.Pow(math.Abs(dataPointX[i]-dataPointY[i]), lambda) + } + distance = math.Pow(distance, 1/lambda) + if math.IsInf(distance, 1) { + return math.NaN(), InfValue + } + return distance, nil +} diff --git a/vendor/github.com/montanaflynn/stats/doc.go b/vendor/github.com/montanaflynn/stats/doc.go new file mode 100644 index 0000000..facb8d5 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/doc.go @@ -0,0 +1,23 @@ +/* +Package stats is a well tested and comprehensive +statistics library package with no dependencies. + +Example Usage: + + // start with some source data to use + data := []float64{1.0, 2.1, 3.2, 4.823, 4.1, 5.8} + + // you could also use different types like this + // data := stats.LoadRawData([]int{1, 2, 3, 4, 5}) + // data := stats.LoadRawData([]interface{}{1.1, "2", 3}) + // etc... + + median, _ := stats.Median(data) + fmt.Println(median) // 3.65 + + roundedMedian, _ := stats.Round(median, 0) + fmt.Println(roundedMedian) // 4 + +MIT License Copyright (c) 2014-2020 Montana Flynn (https://montanaflynn.com) +*/ +package stats diff --git a/vendor/github.com/montanaflynn/stats/entropy.go b/vendor/github.com/montanaflynn/stats/entropy.go new file mode 100644 index 0000000..95263b0 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/entropy.go @@ -0,0 +1,31 @@ +package stats + +import "math" + +// Entropy provides calculation of the entropy +func Entropy(input Float64Data) (float64, error) { + input, err := normalize(input) + if err != nil { + return math.NaN(), err + } + var result float64 + for i := 0; i < input.Len(); i++ { + v := input.Get(i) + if v == 0 { + continue + } + result += (v * math.Log(v)) + } + return -result, nil +} + +func normalize(input Float64Data) (Float64Data, error) { + sum, err := input.Sum() + if err != nil { + return Float64Data{}, err + } + for i := 0; i < input.Len(); i++ { + input[i] = input[i] / sum + } + return input, nil +} diff --git a/vendor/github.com/montanaflynn/stats/errors.go b/vendor/github.com/montanaflynn/stats/errors.go new file mode 100644 index 0000000..95f82ff --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/errors.go @@ -0,0 +1,35 @@ +package stats + +type statsError struct { + err string +} + +func (s statsError) Error() string { + return s.err +} + +func (s statsError) String() string { + return s.err +} + +// These are the package-wide error values. +// All error identification should use these values. +// https://github.com/golang/go/wiki/Errors#naming +var ( + // ErrEmptyInput Input must not be empty + ErrEmptyInput = statsError{"Input must not be empty."} + // ErrNaN Not a number + ErrNaN = statsError{"Not a number."} + // ErrNegative Must not contain negative values + ErrNegative = statsError{"Must not contain negative values."} + // ErrZero Must not contain zero values + ErrZero = statsError{"Must not contain zero values."} + // ErrBounds Input is outside of range + ErrBounds = statsError{"Input is outside of range."} + // ErrSize Must be the same length + ErrSize = statsError{"Must be the same length."} + // ErrInfValue Value is infinite + ErrInfValue = statsError{"Value is infinite."} + // ErrYCoord Y Value must be greater than zero + ErrYCoord = statsError{"Y Value must be greater than zero."} +) diff --git a/vendor/github.com/montanaflynn/stats/geometric_distribution.go b/vendor/github.com/montanaflynn/stats/geometric_distribution.go new file mode 100644 index 0000000..db785dd --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/geometric_distribution.go @@ -0,0 +1,42 @@ +package stats + +import ( + "math" +) + +// ProbGeom generates the probability for a geometric random variable +// with parameter p to achieve success in the interval of [a, b] trials +// See https://en.wikipedia.org/wiki/Geometric_distribution for more information +func ProbGeom(a int, b int, p float64) (prob float64, err error) { + if (a > b) || (a < 1) { + return math.NaN(), ErrBounds + } + + prob = 0 + q := 1 - p // probability of failure + + for k := a + 1; k <= b; k++ { + prob = prob + p*math.Pow(q, float64(k-1)) + } + + return prob, nil +} + +// ProbGeom generates the expectation or average number of trials +// for a geometric random variable with parameter p +func ExpGeom(p float64) (exp float64, err error) { + if (p > 1) || (p < 0) { + return math.NaN(), ErrNegative + } + + return 1 / p, nil +} + +// ProbGeom generates the variance for number for a +// geometric random variable with parameter p +func VarGeom(p float64) (exp float64, err error) { + if (p > 1) || (p < 0) { + return math.NaN(), ErrNegative + } + return (1 - p) / math.Pow(p, 2), nil +} diff --git a/vendor/github.com/montanaflynn/stats/go.mod b/vendor/github.com/montanaflynn/stats/go.mod new file mode 100644 index 0000000..7e3ca1a --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/go.mod @@ -0,0 +1,3 @@ +module github.com/montanaflynn/stats + +go 1.13 diff --git a/vendor/github.com/montanaflynn/stats/legacy.go b/vendor/github.com/montanaflynn/stats/legacy.go new file mode 100644 index 0000000..0f3d1e8 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/legacy.go @@ -0,0 +1,49 @@ +package stats + +// VarP is a shortcut to PopulationVariance +func VarP(input Float64Data) (sdev float64, err error) { + return PopulationVariance(input) +} + +// VarS is a shortcut to SampleVariance +func VarS(input Float64Data) (sdev float64, err error) { + return SampleVariance(input) +} + +// StdDevP is a shortcut to StandardDeviationPopulation +func StdDevP(input Float64Data) (sdev float64, err error) { + return StandardDeviationPopulation(input) +} + +// StdDevS is a shortcut to StandardDeviationSample +func StdDevS(input Float64Data) (sdev float64, err error) { + return StandardDeviationSample(input) +} + +// LinReg is a shortcut to LinearRegression +func LinReg(s []Coordinate) (regressions []Coordinate, err error) { + return LinearRegression(s) +} + +// ExpReg is a shortcut to ExponentialRegression +func ExpReg(s []Coordinate) (regressions []Coordinate, err error) { + return ExponentialRegression(s) +} + +// LogReg is a shortcut to LogarithmicRegression +func LogReg(s []Coordinate) (regressions []Coordinate, err error) { + return LogarithmicRegression(s) +} + +// Legacy error names that didn't start with Err +var ( + EmptyInputErr = ErrEmptyInput + NaNErr = ErrNaN + NegativeErr = ErrNegative + ZeroErr = ErrZero + BoundsErr = ErrBounds + SizeErr = ErrSize + InfValue = ErrInfValue + YCoordErr = ErrYCoord + EmptyInput = ErrEmptyInput +) diff --git a/vendor/github.com/montanaflynn/stats/load.go b/vendor/github.com/montanaflynn/stats/load.go new file mode 100644 index 0000000..0eb0e27 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/load.go @@ -0,0 +1,199 @@ +package stats + +import ( + "bufio" + "io" + "strconv" + "strings" + "time" +) + +// LoadRawData parses and converts a slice of mixed data types to floats +func LoadRawData(raw interface{}) (f Float64Data) { + var r []interface{} + var s Float64Data + + switch t := raw.(type) { + case []interface{}: + r = t + case []uint: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []uint8: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []uint16: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []uint32: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []uint64: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []bool: + for _, v := range t { + if v { + s = append(s, 1.0) + } else { + s = append(s, 0.0) + } + } + return s + case []float64: + return Float64Data(t) + case []int: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []int8: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []int16: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []int32: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []int64: + for _, v := range t { + s = append(s, float64(v)) + } + return s + case []string: + for _, v := range t { + r = append(r, v) + } + case []time.Duration: + for _, v := range t { + r = append(r, v) + } + case map[int]int: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]int8: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]int16: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]int32: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]int64: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]string: + for i := 0; i < len(t); i++ { + r = append(r, t[i]) + } + case map[int]uint: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]uint8: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]uint16: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]uint32: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]uint64: + for i := 0; i < len(t); i++ { + s = append(s, float64(t[i])) + } + return s + case map[int]bool: + for i := 0; i < len(t); i++ { + if t[i] { + s = append(s, 1.0) + } else { + s = append(s, 0.0) + } + } + return s + case map[int]float64: + for i := 0; i < len(t); i++ { + s = append(s, t[i]) + } + return s + case map[int]time.Duration: + for i := 0; i < len(t); i++ { + r = append(r, t[i]) + } + case string: + for _, v := range strings.Fields(t) { + r = append(r, v) + } + case io.Reader: + scanner := bufio.NewScanner(t) + for scanner.Scan() { + l := scanner.Text() + for _, v := range strings.Fields(l) { + r = append(r, v) + } + } + } + + for _, v := range r { + switch t := v.(type) { + case int: + a := float64(t) + f = append(f, a) + case uint: + f = append(f, float64(t)) + case float64: + f = append(f, t) + case string: + fl, err := strconv.ParseFloat(t, 64) + if err == nil { + f = append(f, fl) + } + case bool: + if t { + f = append(f, 1.0) + } else { + f = append(f, 0.0) + } + case time.Duration: + f = append(f, float64(t)) + } + } + return f +} diff --git a/vendor/github.com/montanaflynn/stats/max.go b/vendor/github.com/montanaflynn/stats/max.go new file mode 100644 index 0000000..bb8c83c --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/max.go @@ -0,0 +1,26 @@ +package stats + +import ( + "math" +) + +// Max finds the highest number in a slice +func Max(input Float64Data) (max float64, err error) { + + // Return an error if there are no numbers + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + // Get the first value as the starting point + max = input.Get(0) + + // Loop and replace higher values + for i := 1; i < input.Len(); i++ { + if input.Get(i) > max { + max = input.Get(i) + } + } + + return max, nil +} diff --git a/vendor/github.com/montanaflynn/stats/mean.go b/vendor/github.com/montanaflynn/stats/mean.go new file mode 100644 index 0000000..a78d299 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/mean.go @@ -0,0 +1,60 @@ +package stats + +import "math" + +// Mean gets the average of a slice of numbers +func Mean(input Float64Data) (float64, error) { + + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + sum, _ := input.Sum() + + return sum / float64(input.Len()), nil +} + +// GeometricMean gets the geometric mean for a slice of numbers +func GeometricMean(input Float64Data) (float64, error) { + + l := input.Len() + if l == 0 { + return math.NaN(), EmptyInputErr + } + + // Get the product of all the numbers + var p float64 + for _, n := range input { + if p == 0 { + p = n + } else { + p *= n + } + } + + // Calculate the geometric mean + return math.Pow(p, 1/float64(l)), nil +} + +// HarmonicMean gets the harmonic mean for a slice of numbers +func HarmonicMean(input Float64Data) (float64, error) { + + l := input.Len() + if l == 0 { + return math.NaN(), EmptyInputErr + } + + // Get the sum of all the numbers reciprocals and return an + // error for values that cannot be included in harmonic mean + var p float64 + for _, n := range input { + if n < 0 { + return math.NaN(), NegativeErr + } else if n == 0 { + return math.NaN(), ZeroErr + } + p += (1 / n) + } + + return float64(l) / p, nil +} diff --git a/vendor/github.com/montanaflynn/stats/median.go b/vendor/github.com/montanaflynn/stats/median.go new file mode 100644 index 0000000..a678c36 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/median.go @@ -0,0 +1,25 @@ +package stats + +import "math" + +// Median gets the median number in a slice of numbers +func Median(input Float64Data) (median float64, err error) { + + // Start by sorting a copy of the slice + c := sortedCopy(input) + + // No math is needed if there are no numbers + // For even numbers we add the two middle numbers + // and divide by two using the mean function above + // For odd numbers we just use the middle number + l := len(c) + if l == 0 { + return math.NaN(), EmptyInputErr + } else if l%2 == 0 { + median, _ = Mean(c[l/2-1 : l/2+1]) + } else { + median = c[l/2] + } + + return median, nil +} diff --git a/vendor/github.com/montanaflynn/stats/min.go b/vendor/github.com/montanaflynn/stats/min.go new file mode 100644 index 0000000..bf7e70a --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/min.go @@ -0,0 +1,26 @@ +package stats + +import "math" + +// Min finds the lowest number in a set of data +func Min(input Float64Data) (min float64, err error) { + + // Get the count of numbers in the slice + l := input.Len() + + // Return an error if there are no numbers + if l == 0 { + return math.NaN(), EmptyInputErr + } + + // Get the first value as the starting point + min = input.Get(0) + + // Iterate until done checking for a lower value + for i := 1; i < l; i++ { + if input.Get(i) < min { + min = input.Get(i) + } + } + return min, nil +} diff --git a/vendor/github.com/montanaflynn/stats/mode.go b/vendor/github.com/montanaflynn/stats/mode.go new file mode 100644 index 0000000..a7cf9f7 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/mode.go @@ -0,0 +1,47 @@ +package stats + +// Mode gets the mode [most frequent value(s)] of a slice of float64s +func Mode(input Float64Data) (mode []float64, err error) { + // Return the input if there's only one number + l := input.Len() + if l == 1 { + return input, nil + } else if l == 0 { + return nil, EmptyInputErr + } + + c := sortedCopyDif(input) + // Traverse sorted array, + // tracking the longest repeating sequence + mode = make([]float64, 5) + cnt, maxCnt := 1, 1 + for i := 1; i < l; i++ { + switch { + case c[i] == c[i-1]: + cnt++ + case cnt == maxCnt && maxCnt != 1: + mode = append(mode, c[i-1]) + cnt = 1 + case cnt > maxCnt: + mode = append(mode[:0], c[i-1]) + maxCnt, cnt = cnt, 1 + default: + cnt = 1 + } + } + switch { + case cnt == maxCnt: + mode = append(mode, c[l-1]) + case cnt > maxCnt: + mode = append(mode[:0], c[l-1]) + maxCnt = cnt + } + + // Since length must be greater than 1, + // check for slices of distinct values + if maxCnt == 1 || len(mode)*maxCnt == l && maxCnt != l { + return Float64Data{}, nil + } + + return mode, nil +} diff --git a/vendor/github.com/montanaflynn/stats/norm.go b/vendor/github.com/montanaflynn/stats/norm.go new file mode 100644 index 0000000..4eb8eb8 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/norm.go @@ -0,0 +1,254 @@ +package stats + +import ( + "math" + "math/rand" + "strings" + "time" +) + +// NormPpfRvs generates random variates using the Point Percentile Function. +// For more information please visit: https://demonstrations.wolfram.com/TheMethodOfInverseTransforms/ +func NormPpfRvs(loc float64, scale float64, size int) []float64 { + rand.Seed(time.Now().UnixNano()) + var toReturn []float64 + for i := 0; i < size; i++ { + toReturn = append(toReturn, NormPpf(rand.Float64(), loc, scale)) + } + return toReturn +} + +// NormBoxMullerRvs generates random variates using the Box–Muller transform. +// For more information please visit: http://mathworld.wolfram.com/Box-MullerTransformation.html +func NormBoxMullerRvs(loc float64, scale float64, size int) []float64 { + rand.Seed(time.Now().UnixNano()) + var toReturn []float64 + for i := 0; i < int(float64(size/2)+float64(size%2)); i++ { + // u1 and u2 are uniformly distributed random numbers between 0 and 1. + u1 := rand.Float64() + u2 := rand.Float64() + // x1 and x2 are normally distributed random numbers. + x1 := loc + (scale * (math.Sqrt(-2*math.Log(u1)) * math.Cos(2*math.Pi*u2))) + toReturn = append(toReturn, x1) + if (i+1)*2 <= size { + x2 := loc + (scale * (math.Sqrt(-2*math.Log(u1)) * math.Sin(2*math.Pi*u2))) + toReturn = append(toReturn, x2) + } + } + return toReturn +} + +// NormPdf is the probability density function. +func NormPdf(x float64, loc float64, scale float64) float64 { + return (math.Pow(math.E, -(math.Pow(x-loc, 2))/(2*math.Pow(scale, 2)))) / (scale * math.Sqrt(2*math.Pi)) +} + +// NormLogPdf is the log of the probability density function. +func NormLogPdf(x float64, loc float64, scale float64) float64 { + return math.Log((math.Pow(math.E, -(math.Pow(x-loc, 2))/(2*math.Pow(scale, 2)))) / (scale * math.Sqrt(2*math.Pi))) +} + +// NormCdf is the cumulative distribution function. +func NormCdf(x float64, loc float64, scale float64) float64 { + return 0.5 * (1 + math.Erf((x-loc)/(scale*math.Sqrt(2)))) +} + +// NormLogCdf is the log of the cumulative distribution function. +func NormLogCdf(x float64, loc float64, scale float64) float64 { + return math.Log(0.5 * (1 + math.Erf((x-loc)/(scale*math.Sqrt(2))))) +} + +// NormSf is the survival function (also defined as 1 - cdf, but sf is sometimes more accurate). +func NormSf(x float64, loc float64, scale float64) float64 { + return 1 - 0.5*(1+math.Erf((x-loc)/(scale*math.Sqrt(2)))) +} + +// NormLogSf is the log of the survival function. +func NormLogSf(x float64, loc float64, scale float64) float64 { + return math.Log(1 - 0.5*(1+math.Erf((x-loc)/(scale*math.Sqrt(2))))) +} + +// NormPpf is the point percentile function. +// This is based on Peter John Acklam's inverse normal CDF. +// algorithm: http://home.online.no/~pjacklam/notes/invnorm/ (no longer visible). +// For more information please visit: https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/ +func NormPpf(p float64, loc float64, scale float64) (x float64) { + const ( + a1 = -3.969683028665376e+01 + a2 = 2.209460984245205e+02 + a3 = -2.759285104469687e+02 + a4 = 1.383577518672690e+02 + a5 = -3.066479806614716e+01 + a6 = 2.506628277459239e+00 + + b1 = -5.447609879822406e+01 + b2 = 1.615858368580409e+02 + b3 = -1.556989798598866e+02 + b4 = 6.680131188771972e+01 + b5 = -1.328068155288572e+01 + + c1 = -7.784894002430293e-03 + c2 = -3.223964580411365e-01 + c3 = -2.400758277161838e+00 + c4 = -2.549732539343734e+00 + c5 = 4.374664141464968e+00 + c6 = 2.938163982698783e+00 + + d1 = 7.784695709041462e-03 + d2 = 3.224671290700398e-01 + d3 = 2.445134137142996e+00 + d4 = 3.754408661907416e+00 + + plow = 0.02425 + phigh = 1 - plow + ) + + if p < 0 || p > 1 { + return math.NaN() + } else if p == 0 { + return -math.Inf(0) + } else if p == 1 { + return math.Inf(0) + } + + if p < plow { + q := math.Sqrt(-2 * math.Log(p)) + x = (((((c1*q+c2)*q+c3)*q+c4)*q+c5)*q + c6) / + ((((d1*q+d2)*q+d3)*q+d4)*q + 1) + } else if phigh < p { + q := math.Sqrt(-2 * math.Log(1-p)) + x = -(((((c1*q+c2)*q+c3)*q+c4)*q+c5)*q + c6) / + ((((d1*q+d2)*q+d3)*q+d4)*q + 1) + } else { + q := p - 0.5 + r := q * q + x = (((((a1*r+a2)*r+a3)*r+a4)*r+a5)*r + a6) * q / + (((((b1*r+b2)*r+b3)*r+b4)*r+b5)*r + 1) + } + + e := 0.5*math.Erfc(-x/math.Sqrt2) - p + u := e * math.Sqrt(2*math.Pi) * math.Exp(x*x/2) + x = x - u/(1+x*u/2) + + return x*scale + loc +} + +// NormIsf is the inverse survival function (inverse of sf). +func NormIsf(p float64, loc float64, scale float64) (x float64) { + if -NormPpf(p, loc, scale) == 0 { + return 0 + } + return -NormPpf(p, loc, scale) +} + +// NormMoment approximates the non-central (raw) moment of order n. +// For more information please visit: https://math.stackexchange.com/questions/1945448/methods-for-finding-raw-moments-of-the-normal-distribution +func NormMoment(n int, loc float64, scale float64) float64 { + toReturn := 0.0 + for i := 0; i < n+1; i++ { + if (n-i)%2 == 0 { + toReturn += float64(Ncr(n, i)) * (math.Pow(loc, float64(i))) * (math.Pow(scale, float64(n-i))) * + (float64(factorial(n-i)) / ((math.Pow(2.0, float64((n-i)/2))) * + float64(factorial((n-i)/2)))) + } + } + return toReturn +} + +// NormStats returns the mean, variance, skew, and/or kurtosis. +// Mean(‘m’), variance(‘v’), skew(‘s’), and/or kurtosis(‘k’). +// Takes string containing any of 'mvsk'. +// Returns array of m v s k in that order. +func NormStats(loc float64, scale float64, moments string) []float64 { + var toReturn []float64 + if strings.ContainsAny(moments, "m") { + toReturn = append(toReturn, loc) + } + if strings.ContainsAny(moments, "v") { + toReturn = append(toReturn, math.Pow(scale, 2)) + } + if strings.ContainsAny(moments, "s") { + toReturn = append(toReturn, 0.0) + } + if strings.ContainsAny(moments, "k") { + toReturn = append(toReturn, 0.0) + } + return toReturn +} + +// NormEntropy is the differential entropy of the RV. +func NormEntropy(loc float64, scale float64) float64 { + return math.Log(scale * math.Sqrt(2*math.Pi*math.E)) +} + +// NormFit returns the maximum likelihood estimators for the Normal Distribution. +// Takes array of float64 values. +// Returns array of Mean followed by Standard Deviation. +func NormFit(data []float64) [2]float64 { + sum := 0.00 + for i := 0; i < len(data); i++ { + sum += data[i] + } + mean := sum / float64(len(data)) + stdNumerator := 0.00 + for i := 0; i < len(data); i++ { + stdNumerator += math.Pow(data[i]-mean, 2) + } + return [2]float64{mean, math.Sqrt((stdNumerator) / (float64(len(data))))} +} + +// NormMedian is the median of the distribution. +func NormMedian(loc float64, scale float64) float64 { + return loc +} + +// NormMean is the mean/expected value of the distribution. +func NormMean(loc float64, scale float64) float64 { + return loc +} + +// NormVar is the variance of the distribution. +func NormVar(loc float64, scale float64) float64 { + return math.Pow(scale, 2) +} + +// NormStd is the standard deviation of the distribution. +func NormStd(loc float64, scale float64) float64 { + return scale +} + +// NormInterval finds endpoints of the range that contains alpha percent of the distribution. +func NormInterval(alpha float64, loc float64, scale float64) [2]float64 { + q1 := (1.0 - alpha) / 2 + q2 := (1.0 + alpha) / 2 + a := NormPpf(q1, loc, scale) + b := NormPpf(q2, loc, scale) + return [2]float64{a, b} +} + +// factorial is the naive factorial algorithm. +func factorial(x int) int { + if x == 0 { + return 1 + } + return x * factorial(x-1) +} + +// Ncr is an N choose R algorithm. +// Aaron Cannon's algorithm. +func Ncr(n, r int) int { + if n <= 1 || r == 0 || n == r { + return 1 + } + if newR := n - r; newR < r { + r = newR + } + if r == 1 { + return n + } + ret := int(n - r + 1) + for i, j := ret+1, int(2); j <= r; i, j = i+1, j+1 { + ret = ret * i / j + } + return ret +} diff --git a/vendor/github.com/montanaflynn/stats/outlier.go b/vendor/github.com/montanaflynn/stats/outlier.go new file mode 100644 index 0000000..7c9795b --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/outlier.go @@ -0,0 +1,44 @@ +package stats + +// Outliers holds mild and extreme outliers found in data +type Outliers struct { + Mild Float64Data + Extreme Float64Data +} + +// QuartileOutliers finds the mild and extreme outliers +func QuartileOutliers(input Float64Data) (Outliers, error) { + if input.Len() == 0 { + return Outliers{}, EmptyInputErr + } + + // Start by sorting a copy of the slice + copy := sortedCopy(input) + + // Calculate the quartiles and interquartile range + qs, _ := Quartile(copy) + iqr, _ := InterQuartileRange(copy) + + // Calculate the lower and upper inner and outer fences + lif := qs.Q1 - (1.5 * iqr) + uif := qs.Q3 + (1.5 * iqr) + lof := qs.Q1 - (3 * iqr) + uof := qs.Q3 + (3 * iqr) + + // Find the data points that are outside of the + // inner and upper fences and add them to mild + // and extreme outlier slices + var mild Float64Data + var extreme Float64Data + for _, v := range copy { + + if v < lof || v > uof { + extreme = append(extreme, v) + } else if v < lif || v > uif { + mild = append(mild, v) + } + } + + // Wrap them into our struct + return Outliers{mild, extreme}, nil +} diff --git a/vendor/github.com/montanaflynn/stats/percentile.go b/vendor/github.com/montanaflynn/stats/percentile.go new file mode 100644 index 0000000..f564178 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/percentile.go @@ -0,0 +1,86 @@ +package stats + +import ( + "math" +) + +// Percentile finds the relative standing in a slice of floats +func Percentile(input Float64Data, percent float64) (percentile float64, err error) { + length := input.Len() + if length == 0 { + return math.NaN(), EmptyInputErr + } + + if length == 1 { + return input[0], nil + } + + if percent <= 0 || percent > 100 { + return math.NaN(), BoundsErr + } + + // Start by sorting a copy of the slice + c := sortedCopy(input) + + // Multiply percent by length of input + index := (percent / 100) * float64(len(c)) + + // Check if the index is a whole number + if index == float64(int64(index)) { + + // Convert float to int + i := int(index) + + // Find the value at the index + percentile = c[i-1] + + } else if index > 1 { + + // Convert float to int via truncation + i := int(index) + + // Find the average of the index and following values + percentile, _ = Mean(Float64Data{c[i-1], c[i]}) + + } else { + return math.NaN(), BoundsErr + } + + return percentile, nil + +} + +// PercentileNearestRank finds the relative standing in a slice of floats using the Nearest Rank method +func PercentileNearestRank(input Float64Data, percent float64) (percentile float64, err error) { + + // Find the length of items in the slice + il := input.Len() + + // Return an error for empty slices + if il == 0 { + return math.NaN(), EmptyInputErr + } + + // Return error for less than 0 or greater than 100 percentages + if percent < 0 || percent > 100 { + return math.NaN(), BoundsErr + } + + // Start by sorting a copy of the slice + c := sortedCopy(input) + + // Return the last item + if percent == 100.0 { + return c[il-1], nil + } + + // Find ordinal ranking + or := int(math.Ceil(float64(il) * percent / 100)) + + // Return the item that is in the place of the ordinal rank + if or == 0 { + return c[0], nil + } + return c[or-1], nil + +} diff --git a/vendor/github.com/montanaflynn/stats/quartile.go b/vendor/github.com/montanaflynn/stats/quartile.go new file mode 100644 index 0000000..40bbf6e --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/quartile.go @@ -0,0 +1,74 @@ +package stats + +import "math" + +// Quartiles holds the three quartile points +type Quartiles struct { + Q1 float64 + Q2 float64 + Q3 float64 +} + +// Quartile returns the three quartile points from a slice of data +func Quartile(input Float64Data) (Quartiles, error) { + + il := input.Len() + if il == 0 { + return Quartiles{}, EmptyInputErr + } + + // Start by sorting a copy of the slice + copy := sortedCopy(input) + + // Find the cutoff places depeding on if + // the input slice length is even or odd + var c1 int + var c2 int + if il%2 == 0 { + c1 = il / 2 + c2 = il / 2 + } else { + c1 = (il - 1) / 2 + c2 = c1 + 1 + } + + // Find the Medians with the cutoff points + Q1, _ := Median(copy[:c1]) + Q2, _ := Median(copy) + Q3, _ := Median(copy[c2:]) + + return Quartiles{Q1, Q2, Q3}, nil + +} + +// InterQuartileRange finds the range between Q1 and Q3 +func InterQuartileRange(input Float64Data) (float64, error) { + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + qs, _ := Quartile(input) + iqr := qs.Q3 - qs.Q1 + return iqr, nil +} + +// Midhinge finds the average of the first and third quartiles +func Midhinge(input Float64Data) (float64, error) { + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + qs, _ := Quartile(input) + mh := (qs.Q1 + qs.Q3) / 2 + return mh, nil +} + +// Trimean finds the average of the median and the midhinge +func Trimean(input Float64Data) (float64, error) { + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + c := sortedCopy(input) + q, _ := Quartile(c) + + return (q.Q1 + (q.Q2 * 2) + q.Q3) / 4, nil +} diff --git a/vendor/github.com/montanaflynn/stats/ranksum.go b/vendor/github.com/montanaflynn/stats/ranksum.go new file mode 100644 index 0000000..fc424ef --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/ranksum.go @@ -0,0 +1,183 @@ +package stats + +// import "math" +// +// // WilcoxonRankSum tests the null hypothesis that two sets +// // of data are drawn from the same distribution. It does +// // not handle ties between measurements in x and y. +// // +// // Parameters: +// // data1 Float64Data: First set of data points. +// // data2 Float64Data: Second set of data points. +// // Length of both data samples must be equal. +// // +// // Return: +// // statistic float64: The test statistic under the +// // large-sample approximation that the +// // rank sum statistic is normally distributed. +// // pvalue float64: The two-sided p-value of the test +// // err error: Any error from the input data parameters +// // +// // https://en.wikipedia.org/wiki/Wilcoxon_rank-sum_test +// func WilcoxonRankSum(data1, data2 Float64Data) (float64, float64, error) { +// +// l1 := data1.Len() +// l2 := data2.Len() +// +// if l1 == 0 || l2 == 0 { +// return math.NaN(), math.NaN(), EmptyInputErr +// } +// +// if l1 != l2 { +// return math.NaN(), math.NaN(), SizeErr +// } +// +// alldata := Float64Data{} +// alldata = append(alldata, data1...) +// alldata = append(alldata, data2...) +// +// // ranked := +// +// return 0.0, 0.0, nil +// } +// +// // x, y = map(np.asarray, (x, y)) +// // n1 = len(x) +// // n2 = len(y) +// // alldata = np.concatenate((x, y)) +// // ranked = rankdata(alldata) +// // x = ranked[:n1] +// // s = np.sum(x, axis=0) +// // expected = n1 * (n1+n2+1) / 2.0 +// // z = (s - expected) / np.sqrt(n1*n2*(n1+n2+1)/12.0) +// // prob = 2 * distributions.norm.sf(abs(z)) +// // +// // return RanksumsResult(z, prob) +// +// // def rankdata(a, method='average'): +// // """ +// // Assign ranks to data, dealing with ties appropriately. +// // Ranks begin at 1. The `method` argument controls how ranks are assigned +// // to equal values. See [1]_ for further discussion of ranking methods. +// // Parameters +// // ---------- +// // a : array_like +// // The array of values to be ranked. The array is first flattened. +// // method : str, optional +// // The method used to assign ranks to tied elements. +// // The options are 'average', 'min', 'max', 'dense' and 'ordinal'. +// // 'average': +// // The average of the ranks that would have been assigned to +// // all the tied values is assigned to each value. +// // 'min': +// // The minimum of the ranks that would have been assigned to all +// // the tied values is assigned to each value. (This is also +// // referred to as "competition" ranking.) +// // 'max': +// // The maximum of the ranks that would have been assigned to all +// // the tied values is assigned to each value. +// // 'dense': +// // Like 'min', but the rank of the next highest element is assigned +// // the rank immediately after those assigned to the tied elements. +// // 'ordinal': +// // All values are given a distinct rank, corresponding to the order +// // that the values occur in `a`. +// // The default is 'average'. +// // Returns +// // ------- +// // ranks : ndarray +// // An array of length equal to the size of `a`, containing rank +// // scores. +// // References +// // ---------- +// // .. [1] "Ranking", https://en.wikipedia.org/wiki/Ranking +// // Examples +// // -------- +// // >>> from scipy.stats import rankdata +// // >>> rankdata([0, 2, 3, 2]) +// // array([ 1. , 2.5, 4. , 2.5]) +// // """ +// // +// // arr = np.ravel(np.asarray(a)) +// // algo = 'quicksort' +// // sorter = np.argsort(arr, kind=algo) +// // +// // inv = np.empty(sorter.size, dtype=np.intp) +// // inv[sorter] = np.arange(sorter.size, dtype=np.intp) +// // +// // +// // arr = arr[sorter] +// // obs = np.r_[True, arr[1:] != arr[:-1]] +// // dense = obs.cumsum()[inv] +// // +// // +// // # cumulative counts of each unique value +// // count = np.r_[np.nonzero(obs)[0], len(obs)] +// // +// // # average method +// // return .5 * (count[dense] + count[dense - 1] + 1) +// +// type rankable interface { +// Len() int +// RankEqual(int, int) bool +// } +// +// func StandardRank(d rankable) []float64 { +// r := make([]float64, d.Len()) +// var k int +// for i := range r { +// if i == 0 || !d.RankEqual(i, i-1) { +// k = i + 1 +// } +// r[i] = float64(k) +// } +// return r +// } +// +// func ModifiedRank(d rankable) []float64 { +// r := make([]float64, d.Len()) +// for i := range r { +// k := i + 1 +// for j := i + 1; j < len(r) && d.RankEqual(i, j); j++ { +// k = j + 1 +// } +// r[i] = float64(k) +// } +// return r +// } +// +// func DenseRank(d rankable) []float64 { +// r := make([]float64, d.Len()) +// var k int +// for i := range r { +// if i == 0 || !d.RankEqual(i, i-1) { +// k++ +// } +// r[i] = float64(k) +// } +// return r +// } +// +// func OrdinalRank(d rankable) []float64 { +// r := make([]float64, d.Len()) +// for i := range r { +// r[i] = float64(i + 1) +// } +// return r +// } +// +// func FractionalRank(d rankable) []float64 { +// r := make([]float64, d.Len()) +// for i := 0; i < len(r); { +// var j int +// f := float64(i + 1) +// for j = i + 1; j < len(r) && d.RankEqual(i, j); j++ { +// f += float64(j + 1) +// } +// f /= float64(j - i) +// for ; i < j; i++ { +// r[i] = f +// } +// } +// return r +// } diff --git a/vendor/github.com/montanaflynn/stats/regression.go b/vendor/github.com/montanaflynn/stats/regression.go new file mode 100644 index 0000000..401d951 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/regression.go @@ -0,0 +1,113 @@ +package stats + +import "math" + +// Series is a container for a series of data +type Series []Coordinate + +// Coordinate holds the data in a series +type Coordinate struct { + X, Y float64 +} + +// LinearRegression finds the least squares linear regression on data series +func LinearRegression(s Series) (regressions Series, err error) { + + if len(s) == 0 { + return nil, EmptyInputErr + } + + // Placeholder for the math to be done + var sum [5]float64 + + // Loop over data keeping index in place + i := 0 + for ; i < len(s); i++ { + sum[0] += s[i].X + sum[1] += s[i].Y + sum[2] += s[i].X * s[i].X + sum[3] += s[i].X * s[i].Y + sum[4] += s[i].Y * s[i].Y + } + + // Find gradient and intercept + f := float64(i) + gradient := (f*sum[3] - sum[0]*sum[1]) / (f*sum[2] - sum[0]*sum[0]) + intercept := (sum[1] / f) - (gradient * sum[0] / f) + + // Create the new regression series + for j := 0; j < len(s); j++ { + regressions = append(regressions, Coordinate{ + X: s[j].X, + Y: s[j].X*gradient + intercept, + }) + } + + return regressions, nil +} + +// ExponentialRegression returns an exponential regression on data series +func ExponentialRegression(s Series) (regressions Series, err error) { + + if len(s) == 0 { + return nil, EmptyInputErr + } + + var sum [6]float64 + + for i := 0; i < len(s); i++ { + if s[i].Y < 0 { + return nil, YCoordErr + } + sum[0] += s[i].X + sum[1] += s[i].Y + sum[2] += s[i].X * s[i].X * s[i].Y + sum[3] += s[i].Y * math.Log(s[i].Y) + sum[4] += s[i].X * s[i].Y * math.Log(s[i].Y) + sum[5] += s[i].X * s[i].Y + } + + denominator := (sum[1]*sum[2] - sum[5]*sum[5]) + a := math.Pow(math.E, (sum[2]*sum[3]-sum[5]*sum[4])/denominator) + b := (sum[1]*sum[4] - sum[5]*sum[3]) / denominator + + for j := 0; j < len(s); j++ { + regressions = append(regressions, Coordinate{ + X: s[j].X, + Y: a * math.Exp(b*s[j].X), + }) + } + + return regressions, nil +} + +// LogarithmicRegression returns an logarithmic regression on data series +func LogarithmicRegression(s Series) (regressions Series, err error) { + + if len(s) == 0 { + return nil, EmptyInputErr + } + + var sum [4]float64 + + i := 0 + for ; i < len(s); i++ { + sum[0] += math.Log(s[i].X) + sum[1] += s[i].Y * math.Log(s[i].X) + sum[2] += s[i].Y + sum[3] += math.Pow(math.Log(s[i].X), 2) + } + + f := float64(i) + a := (f*sum[1] - sum[2]*sum[0]) / (f*sum[3] - sum[0]*sum[0]) + b := (sum[2] - a*sum[0]) / f + + for j := 0; j < len(s); j++ { + regressions = append(regressions, Coordinate{ + X: s[j].X, + Y: b + a*math.Log(s[j].X), + }) + } + + return regressions, nil +} diff --git a/vendor/github.com/montanaflynn/stats/round.go b/vendor/github.com/montanaflynn/stats/round.go new file mode 100644 index 0000000..b66779c --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/round.go @@ -0,0 +1,38 @@ +package stats + +import "math" + +// Round a float to a specific decimal place or precision +func Round(input float64, places int) (rounded float64, err error) { + + // If the float is not a number + if math.IsNaN(input) { + return math.NaN(), NaNErr + } + + // Find out the actual sign and correct the input for later + sign := 1.0 + if input < 0 { + sign = -1 + input *= -1 + } + + // Use the places arg to get the amount of precision wanted + precision := math.Pow(10, float64(places)) + + // Find the decimal place we are looking to round + digit := input * precision + + // Get the actual decimal number as a fraction to be compared + _, decimal := math.Modf(digit) + + // If the decimal is less than .5 we round down otherwise up + if decimal >= 0.5 { + rounded = math.Ceil(digit) + } else { + rounded = math.Floor(digit) + } + + // Finally we do the math to actually create a rounded number + return rounded / precision * sign, nil +} diff --git a/vendor/github.com/montanaflynn/stats/sample.go b/vendor/github.com/montanaflynn/stats/sample.go new file mode 100644 index 0000000..40166af --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/sample.go @@ -0,0 +1,76 @@ +package stats + +import ( + "math/rand" + "sort" +) + +// Sample returns sample from input with replacement or without +func Sample(input Float64Data, takenum int, replacement bool) ([]float64, error) { + + if input.Len() == 0 { + return nil, EmptyInputErr + } + + length := input.Len() + if replacement { + + result := Float64Data{} + rand.Seed(unixnano()) + + // In every step, randomly take the num for + for i := 0; i < takenum; i++ { + idx := rand.Intn(length) + result = append(result, input[idx]) + } + + return result, nil + + } else if !replacement && takenum <= length { + + rand.Seed(unixnano()) + + // Get permutation of number of indexies + perm := rand.Perm(length) + result := Float64Data{} + + // Get element of input by permutated index + for _, idx := range perm[0:takenum] { + result = append(result, input[idx]) + } + + return result, nil + + } + + return nil, BoundsErr +} + +// StableSample like stable sort, it returns samples from input while keeps the order of original data. +func StableSample(input Float64Data, takenum int) ([]float64, error) { + if input.Len() == 0 { + return nil, EmptyInputErr + } + + length := input.Len() + + if takenum <= length { + + rand.Seed(unixnano()) + + perm := rand.Perm(length) + perm = perm[0:takenum] + // Sort perm before applying + sort.Ints(perm) + result := Float64Data{} + + for _, idx := range perm { + result = append(result, input[idx]) + } + + return result, nil + + } + + return nil, BoundsErr +} diff --git a/vendor/github.com/montanaflynn/stats/sigmoid.go b/vendor/github.com/montanaflynn/stats/sigmoid.go new file mode 100644 index 0000000..5f2559d --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/sigmoid.go @@ -0,0 +1,18 @@ +package stats + +import "math" + +// Sigmoid returns the input values in the range of -1 to 1 +// along the sigmoid or s-shaped curve, commonly used in +// machine learning while training neural networks as an +// activation function. +func Sigmoid(input Float64Data) ([]float64, error) { + if input.Len() == 0 { + return Float64Data{}, EmptyInput + } + s := make([]float64, len(input)) + for i, v := range input { + s[i] = 1 / (1 + math.Exp(-v)) + } + return s, nil +} diff --git a/vendor/github.com/montanaflynn/stats/softmax.go b/vendor/github.com/montanaflynn/stats/softmax.go new file mode 100644 index 0000000..8507264 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/softmax.go @@ -0,0 +1,25 @@ +package stats + +import "math" + +// SoftMax returns the input values in the range of 0 to 1 +// with sum of all the probabilities being equal to one. It +// is commonly used in machine learning neural networks. +func SoftMax(input Float64Data) ([]float64, error) { + if input.Len() == 0 { + return Float64Data{}, EmptyInput + } + + s := 0.0 + c, _ := Max(input) + for _, e := range input { + s += math.Exp(e - c) + } + + sm := make([]float64, len(input)) + for i, v := range input { + sm[i] = math.Exp(v-c) / s + } + + return sm, nil +} diff --git a/vendor/github.com/montanaflynn/stats/sum.go b/vendor/github.com/montanaflynn/stats/sum.go new file mode 100644 index 0000000..15b611d --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/sum.go @@ -0,0 +1,18 @@ +package stats + +import "math" + +// Sum adds all the numbers of a slice together +func Sum(input Float64Data) (sum float64, err error) { + + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + // Add em up + for _, n := range input { + sum += n + } + + return sum, nil +} diff --git a/vendor/github.com/montanaflynn/stats/util.go b/vendor/github.com/montanaflynn/stats/util.go new file mode 100644 index 0000000..8819976 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/util.go @@ -0,0 +1,43 @@ +package stats + +import ( + "sort" + "time" +) + +// float64ToInt rounds a float64 to an int +func float64ToInt(input float64) (output int) { + r, _ := Round(input, 0) + return int(r) +} + +// unixnano returns nanoseconds from UTC epoch +func unixnano() int64 { + return time.Now().UTC().UnixNano() +} + +// copyslice copies a slice of float64s +func copyslice(input Float64Data) Float64Data { + s := make(Float64Data, input.Len()) + copy(s, input) + return s +} + +// sortedCopy returns a sorted copy of float64s +func sortedCopy(input Float64Data) (copy Float64Data) { + copy = copyslice(input) + sort.Float64s(copy) + return +} + +// sortedCopyDif returns a sorted copy of float64s +// only if the original data isn't sorted. +// Only use this if returned slice won't be manipulated! +func sortedCopyDif(input Float64Data) (copy Float64Data) { + if sort.Float64sAreSorted(input) { + return input + } + copy = copyslice(input) + sort.Float64s(copy) + return +} diff --git a/vendor/github.com/montanaflynn/stats/variance.go b/vendor/github.com/montanaflynn/stats/variance.go new file mode 100644 index 0000000..a644569 --- /dev/null +++ b/vendor/github.com/montanaflynn/stats/variance.go @@ -0,0 +1,105 @@ +package stats + +import "math" + +// _variance finds the variance for both population and sample data +func _variance(input Float64Data, sample int) (variance float64, err error) { + + if input.Len() == 0 { + return math.NaN(), EmptyInputErr + } + + // Sum the square of the mean subtracted from each number + m, _ := Mean(input) + + for _, n := range input { + variance += (n - m) * (n - m) + } + + // When getting the mean of the squared differences + // "sample" will allow us to know if it's a sample + // or population and wether to subtract by one or not + return variance / float64((input.Len() - (1 * sample))), nil +} + +// Variance the amount of variation in the dataset +func Variance(input Float64Data) (sdev float64, err error) { + return PopulationVariance(input) +} + +// PopulationVariance finds the amount of variance within a population +func PopulationVariance(input Float64Data) (pvar float64, err error) { + + v, err := _variance(input, 0) + if err != nil { + return math.NaN(), err + } + + return v, nil +} + +// SampleVariance finds the amount of variance within a sample +func SampleVariance(input Float64Data) (svar float64, err error) { + + v, err := _variance(input, 1) + if err != nil { + return math.NaN(), err + } + + return v, nil +} + +// Covariance is a measure of how much two sets of data change +func Covariance(data1, data2 Float64Data) (float64, error) { + + l1 := data1.Len() + l2 := data2.Len() + + if l1 == 0 || l2 == 0 { + return math.NaN(), EmptyInputErr + } + + if l1 != l2 { + return math.NaN(), SizeErr + } + + m1, _ := Mean(data1) + m2, _ := Mean(data2) + + // Calculate sum of squares + var ss float64 + for i := 0; i < l1; i++ { + delta1 := (data1.Get(i) - m1) + delta2 := (data2.Get(i) - m2) + ss += (delta1*delta2 - ss) / float64(i+1) + } + + return ss * float64(l1) / float64(l1-1), nil +} + +// CovariancePopulation computes covariance for entire population between two variables. +func CovariancePopulation(data1, data2 Float64Data) (float64, error) { + + l1 := data1.Len() + l2 := data2.Len() + + if l1 == 0 || l2 == 0 { + return math.NaN(), EmptyInputErr + } + + if l1 != l2 { + return math.NaN(), SizeErr + } + + m1, _ := Mean(data1) + m2, _ := Mean(data2) + + var s float64 + for i := 0; i < l1; i++ { + delta1 := (data1.Get(i) - m1) + delta2 := (data2.Get(i) - m2) + s += delta1 * delta2 + } + + return s / float64(l1), nil +} diff --git a/vendor/github.com/xdg-go/scram/CHANGELOG.md b/vendor/github.com/xdg-go/scram/CHANGELOG.md index 425c122..b833be5 100644 --- a/vendor/github.com/xdg-go/scram/CHANGELOG.md +++ b/vendor/github.com/xdg-go/scram/CHANGELOG.md @@ -1,5 +1,17 @@ # CHANGELOG +## v1.1.2 - 2022-12-07 + +- Bump stringprep dependency to v1.0.4 for upstream CVE fix. + +## v1.1.1 - 2022-03-03 + +- Bump stringprep dependency to v1.0.3 for upstream CVE fix. + +## v1.1.0 - 2022-01-16 + +- Add SHA-512 hash generator function for convenience. + ## v1.0.2 - 2021-03-28 - Switch PBKDF2 dependency to github.com/xdg-go/pbkdf2 to diff --git a/vendor/github.com/xdg-go/scram/doc.go b/vendor/github.com/xdg-go/scram/doc.go index d43bee6..82e8aee 100644 --- a/vendor/github.com/xdg-go/scram/doc.go +++ b/vendor/github.com/xdg-go/scram/doc.go @@ -10,14 +10,16 @@ // // Usage // -// The scram package provides two variables, `SHA1` and `SHA256`, that are -// used to construct Client or Server objects. +// The scram package provides variables, `SHA1`, `SHA256`, and `SHA512`, that +// are used to construct Client or Server objects. // // clientSHA1, err := scram.SHA1.NewClient(username, password, authID) // clientSHA256, err := scram.SHA256.NewClient(username, password, authID) +// clientSHA512, err := scram.SHA512.NewClient(username, password, authID) // // serverSHA1, err := scram.SHA1.NewServer(credentialLookupFcn) // serverSHA256, err := scram.SHA256.NewServer(credentialLookupFcn) +// serverSHA512, err := scram.SHA512.NewServer(credentialLookupFcn) // // These objects are used to construct ClientConversation or // ServerConversation objects that are used to carry out authentication. diff --git a/vendor/github.com/xdg-go/scram/go.mod b/vendor/github.com/xdg-go/scram/go.mod index ad37635..d641553 100644 --- a/vendor/github.com/xdg-go/scram/go.mod +++ b/vendor/github.com/xdg-go/scram/go.mod @@ -4,5 +4,5 @@ go 1.11 require ( github.com/xdg-go/pbkdf2 v1.0.0 - github.com/xdg-go/stringprep v1.0.2 + github.com/xdg-go/stringprep v1.0.4 ) diff --git a/vendor/github.com/xdg-go/scram/go.sum b/vendor/github.com/xdg-go/scram/go.sum index 0409882..5edb2d6 100644 --- a/vendor/github.com/xdg-go/scram/go.sum +++ b/vendor/github.com/xdg-go/scram/go.sum @@ -1,7 +1,29 @@ github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc= -github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= -golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/vendor/github.com/xdg-go/scram/scram.go b/vendor/github.com/xdg-go/scram/scram.go index 9276599..a7b3660 100644 --- a/vendor/github.com/xdg-go/scram/scram.go +++ b/vendor/github.com/xdg-go/scram/scram.go @@ -9,6 +9,7 @@ package scram import ( "crypto/sha1" "crypto/sha256" + "crypto/sha512" "fmt" "hash" @@ -29,6 +30,10 @@ var SHA1 HashGeneratorFcn = func() hash.Hash { return sha1.New() } // to create Client objects configured for SHA-256 hashing. var SHA256 HashGeneratorFcn = func() hash.Hash { return sha256.New() } +// SHA512 is a function that returns a crypto/sha512 hasher and should be used +// to create Client objects configured for SHA-512 hashing. +var SHA512 HashGeneratorFcn = func() hash.Hash { return sha512.New() } + // NewClient constructs a SCRAM client component based on a given hash.Hash // factory receiver. This constructor will normalize the username, password // and authzID via the SASLprep algorithm, as recommended by RFC-5802. If diff --git a/vendor/github.com/xdg-go/stringprep/CHANGELOG.md b/vendor/github.com/xdg-go/stringprep/CHANGELOG.md index 2849637..04b9753 100644 --- a/vendor/github.com/xdg-go/stringprep/CHANGELOG.md +++ b/vendor/github.com/xdg-go/stringprep/CHANGELOG.md @@ -1,5 +1,19 @@ # CHANGELOG + +## [v1.0.4] - 2022-12-07 + +### Maintenance + +- Bump golang.org/x/text to v0.3.8 due to CVE-2022-32149 + + +## [v1.0.3] - 2022-03-01 + +### Maintenance + +- Bump golang.org/x/text to v0.3.7 due to CVE-2021-38561 + ## [v1.0.2] - 2021-03-27 diff --git a/vendor/github.com/xdg-go/stringprep/go.mod b/vendor/github.com/xdg-go/stringprep/go.mod index f57123b..0af0f60 100644 --- a/vendor/github.com/xdg-go/stringprep/go.mod +++ b/vendor/github.com/xdg-go/stringprep/go.mod @@ -2,4 +2,4 @@ module github.com/xdg-go/stringprep go 1.11 -require golang.org/x/text v0.3.5 +require golang.org/x/text v0.3.8 diff --git a/vendor/github.com/xdg-go/stringprep/go.sum b/vendor/github.com/xdg-go/stringprep/go.sum index bbd33e8..b691fa8 100644 --- a/vendor/github.com/xdg-go/stringprep/go.sum +++ b/vendor/github.com/xdg-go/stringprep/go.sum @@ -1,3 +1,25 @@ -golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/vendor/github.com/youmark/pkcs8/.travis.yml b/vendor/github.com/youmark/pkcs8/.travis.yml deleted file mode 100644 index 0bceef6..0000000 --- a/vendor/github.com/youmark/pkcs8/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: go - -go: - - "1.9.x" - - "1.10.x" - - master - -script: - - go test -v ./... diff --git a/vendor/github.com/youmark/pkcs8/README.md b/vendor/github.com/youmark/pkcs8/README.md index f2167db..ef6c762 100644 --- a/vendor/github.com/youmark/pkcs8/README.md +++ b/vendor/github.com/youmark/pkcs8/README.md @@ -8,14 +8,15 @@ pkcs8 package fills the gap here. It implements functions to process private key [**Godoc**](http://godoc.org/github.com/youmark/pkcs8) ## Installation -Supports Go 1.9+ +Supports Go 1.10+. Release v1.1 is the last release supporting Go 1.9 ```text go get github.com/youmark/pkcs8 ``` ## dependency -This package depends on golang.org/x/crypto/pbkdf2 package. Use the following command to retrive pbkdf2 package +This package depends on golang.org/x/crypto/pbkdf2 and golang.org/x/crypto/scrypt packages. Use the following command to retrieve them ```text go get golang.org/x/crypto/pbkdf2 +go get golang.org/x/crypto/scrypt ``` diff --git a/vendor/github.com/youmark/pkcs8/cipher.go b/vendor/github.com/youmark/pkcs8/cipher.go new file mode 100644 index 0000000..2946c93 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/cipher.go @@ -0,0 +1,60 @@ +package pkcs8 + +import ( + "bytes" + "crypto/cipher" + "encoding/asn1" +) + +type cipherWithBlock struct { + oid asn1.ObjectIdentifier + ivSize int + keySize int + newBlock func(key []byte) (cipher.Block, error) +} + +func (c cipherWithBlock) IVSize() int { + return c.ivSize +} + +func (c cipherWithBlock) KeySize() int { + return c.keySize +} + +func (c cipherWithBlock) OID() asn1.ObjectIdentifier { + return c.oid +} + +func (c cipherWithBlock) Encrypt(key, iv, plaintext []byte) ([]byte, error) { + block, err := c.newBlock(key) + if err != nil { + return nil, err + } + return cbcEncrypt(block, key, iv, plaintext) +} + +func (c cipherWithBlock) Decrypt(key, iv, ciphertext []byte) ([]byte, error) { + block, err := c.newBlock(key) + if err != nil { + return nil, err + } + return cbcDecrypt(block, key, iv, ciphertext) +} + +func cbcEncrypt(block cipher.Block, key, iv, plaintext []byte) ([]byte, error) { + mode := cipher.NewCBCEncrypter(block, iv) + paddingLen := block.BlockSize() - (len(plaintext) % block.BlockSize()) + ciphertext := make([]byte, len(plaintext)+paddingLen) + copy(ciphertext, plaintext) + copy(ciphertext[len(plaintext):], bytes.Repeat([]byte{byte(paddingLen)}, paddingLen)) + mode.CryptBlocks(ciphertext, ciphertext) + return ciphertext, nil +} + +func cbcDecrypt(block cipher.Block, key, iv, ciphertext []byte) ([]byte, error) { + mode := cipher.NewCBCDecrypter(block, iv) + plaintext := make([]byte, len(ciphertext)) + mode.CryptBlocks(plaintext, ciphertext) + // TODO: remove padding + return plaintext, nil +} diff --git a/vendor/github.com/youmark/pkcs8/cipher_3des.go b/vendor/github.com/youmark/pkcs8/cipher_3des.go new file mode 100644 index 0000000..5629664 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/cipher_3des.go @@ -0,0 +1,24 @@ +package pkcs8 + +import ( + "crypto/des" + "encoding/asn1" +) + +var ( + oidDESEDE3CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 3, 7} +) + +func init() { + RegisterCipher(oidDESEDE3CBC, func() Cipher { + return TripleDESCBC + }) +} + +// TripleDESCBC is the 168-bit key 3DES cipher in CBC mode. +var TripleDESCBC = cipherWithBlock{ + ivSize: des.BlockSize, + keySize: 24, + newBlock: des.NewTripleDESCipher, + oid: oidDESEDE3CBC, +} diff --git a/vendor/github.com/youmark/pkcs8/cipher_aes.go b/vendor/github.com/youmark/pkcs8/cipher_aes.go new file mode 100644 index 0000000..c0372d1 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/cipher_aes.go @@ -0,0 +1,84 @@ +package pkcs8 + +import ( + "crypto/aes" + "encoding/asn1" +) + +var ( + oidAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2} + oidAES128GCM = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 6} + oidAES192CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 22} + oidAES192GCM = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 26} + oidAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42} + oidAES256GCM = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 46} +) + +func init() { + RegisterCipher(oidAES128CBC, func() Cipher { + return AES128CBC + }) + RegisterCipher(oidAES128GCM, func() Cipher { + return AES128GCM + }) + RegisterCipher(oidAES192CBC, func() Cipher { + return AES192CBC + }) + RegisterCipher(oidAES192GCM, func() Cipher { + return AES192GCM + }) + RegisterCipher(oidAES256CBC, func() Cipher { + return AES256CBC + }) + RegisterCipher(oidAES256GCM, func() Cipher { + return AES256GCM + }) +} + +// AES128CBC is the 128-bit key AES cipher in CBC mode. +var AES128CBC = cipherWithBlock{ + ivSize: aes.BlockSize, + keySize: 16, + newBlock: aes.NewCipher, + oid: oidAES128CBC, +} + +// AES128GCM is the 128-bit key AES cipher in GCM mode. +var AES128GCM = cipherWithBlock{ + ivSize: aes.BlockSize, + keySize: 16, + newBlock: aes.NewCipher, + oid: oidAES128GCM, +} + +// AES192CBC is the 192-bit key AES cipher in CBC mode. +var AES192CBC = cipherWithBlock{ + ivSize: aes.BlockSize, + keySize: 24, + newBlock: aes.NewCipher, + oid: oidAES192CBC, +} + +// AES192GCM is the 912-bit key AES cipher in GCM mode. +var AES192GCM = cipherWithBlock{ + ivSize: aes.BlockSize, + keySize: 24, + newBlock: aes.NewCipher, + oid: oidAES192GCM, +} + +// AES256CBC is the 256-bit key AES cipher in CBC mode. +var AES256CBC = cipherWithBlock{ + ivSize: aes.BlockSize, + keySize: 32, + newBlock: aes.NewCipher, + oid: oidAES256CBC, +} + +// AES256GCM is the 256-bit key AES cipher in GCM mode. +var AES256GCM = cipherWithBlock{ + ivSize: aes.BlockSize, + keySize: 32, + newBlock: aes.NewCipher, + oid: oidAES256GCM, +} diff --git a/vendor/github.com/youmark/pkcs8/go.mod b/vendor/github.com/youmark/pkcs8/go.mod new file mode 100644 index 0000000..c4e88d1 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/go.mod @@ -0,0 +1,5 @@ +module github.com/youmark/pkcs8 + +go 1.17 + +require golang.org/x/crypto v0.22.0 diff --git a/vendor/github.com/youmark/pkcs8/go.sum b/vendor/github.com/youmark/pkcs8/go.sum new file mode 100644 index 0000000..ce62e45 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/go.sum @@ -0,0 +1,2 @@ +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= diff --git a/vendor/github.com/youmark/pkcs8/kdf_pbkdf2.go b/vendor/github.com/youmark/pkcs8/kdf_pbkdf2.go new file mode 100644 index 0000000..79697dd --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/kdf_pbkdf2.go @@ -0,0 +1,91 @@ +package pkcs8 + +import ( + "crypto" + "crypto/sha1" + "crypto/sha256" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "hash" + + "golang.org/x/crypto/pbkdf2" +) + +var ( + oidPKCS5PBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} + oidHMACWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 7} + oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} +) + +func init() { + RegisterKDF(oidPKCS5PBKDF2, func() KDFParameters { + return new(pbkdf2Params) + }) +} + +func newHashFromPRF(ai pkix.AlgorithmIdentifier) (func() hash.Hash, error) { + switch { + case len(ai.Algorithm) == 0 || ai.Algorithm.Equal(oidHMACWithSHA1): + return sha1.New, nil + case ai.Algorithm.Equal(oidHMACWithSHA256): + return sha256.New, nil + default: + return nil, errors.New("pkcs8: unsupported hash function") + } +} + +func newPRFParamFromHash(h crypto.Hash) (pkix.AlgorithmIdentifier, error) { + switch h { + case crypto.SHA1: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA1, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + case crypto.SHA256: + return pkix.AlgorithmIdentifier{ + Algorithm: oidHMACWithSHA256, + Parameters: asn1.RawValue{Tag: asn1.TagNull}}, nil + } + return pkix.AlgorithmIdentifier{}, errors.New("pkcs8: unsupported hash function") +} + +type pbkdf2Params struct { + Salt []byte + IterationCount int + PRF pkix.AlgorithmIdentifier `asn1:"optional"` +} + +func (p pbkdf2Params) DeriveKey(password []byte, size int) (key []byte, err error) { + h, err := newHashFromPRF(p.PRF) + if err != nil { + return nil, err + } + return pbkdf2.Key(password, p.Salt, p.IterationCount, size, h), nil +} + +// PBKDF2Opts contains options for the PBKDF2 key derivation function. +type PBKDF2Opts struct { + SaltSize int + IterationCount int + HMACHash crypto.Hash +} + +func (p PBKDF2Opts) DeriveKey(password, salt []byte, size int) ( + key []byte, params KDFParameters, err error) { + + key = pbkdf2.Key(password, salt, p.IterationCount, size, p.HMACHash.New) + prfParam, err := newPRFParamFromHash(p.HMACHash) + if err != nil { + return nil, nil, err + } + params = pbkdf2Params{salt, p.IterationCount, prfParam} + return key, params, nil +} + +func (p PBKDF2Opts) GetSaltSize() int { + return p.SaltSize +} + +func (p PBKDF2Opts) OID() asn1.ObjectIdentifier { + return oidPKCS5PBKDF2 +} diff --git a/vendor/github.com/youmark/pkcs8/kdf_scrypt.go b/vendor/github.com/youmark/pkcs8/kdf_scrypt.go new file mode 100644 index 0000000..36c4f4f --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/kdf_scrypt.go @@ -0,0 +1,62 @@ +package pkcs8 + +import ( + "encoding/asn1" + + "golang.org/x/crypto/scrypt" +) + +var ( + oidScrypt = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11591, 4, 11} +) + +func init() { + RegisterKDF(oidScrypt, func() KDFParameters { + return new(scryptParams) + }) +} + +type scryptParams struct { + Salt []byte + CostParameter int + BlockSize int + ParallelizationParameter int +} + +func (p scryptParams) DeriveKey(password []byte, size int) (key []byte, err error) { + return scrypt.Key(password, p.Salt, p.CostParameter, p.BlockSize, + p.ParallelizationParameter, size) +} + +// ScryptOpts contains options for the scrypt key derivation function. +type ScryptOpts struct { + SaltSize int + CostParameter int + BlockSize int + ParallelizationParameter int +} + +func (p ScryptOpts) DeriveKey(password, salt []byte, size int) ( + key []byte, params KDFParameters, err error) { + + key, err = scrypt.Key(password, salt, p.CostParameter, p.BlockSize, + p.ParallelizationParameter, size) + if err != nil { + return nil, nil, err + } + params = scryptParams{ + BlockSize: p.BlockSize, + CostParameter: p.CostParameter, + ParallelizationParameter: p.ParallelizationParameter, + Salt: salt, + } + return key, params, nil +} + +func (p ScryptOpts) GetSaltSize() int { + return p.SaltSize +} + +func (p ScryptOpts) OID() asn1.ObjectIdentifier { + return oidScrypt +} diff --git a/vendor/github.com/youmark/pkcs8/pkcs8.go b/vendor/github.com/youmark/pkcs8/pkcs8.go index 9270a79..f27f627 100644 --- a/vendor/github.com/youmark/pkcs8/pkcs8.go +++ b/vendor/github.com/youmark/pkcs8/pkcs8.go @@ -2,304 +2,308 @@ package pkcs8 import ( - "crypto/aes" - "crypto/cipher" - "crypto/des" + "crypto" "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "crypto/rsa" - "crypto/sha1" - "crypto/sha256" "crypto/x509" + "crypto/x509/pkix" "encoding/asn1" "errors" - - "golang.org/x/crypto/pbkdf2" -) - -// Copy from crypto/x509 -var ( - oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} - oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} - oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + "fmt" ) -// Copy from crypto/x509 -var ( - oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33} - oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} - oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} - oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} -) +// DefaultOpts are the default options for encrypting a key if none are given. +// The defaults can be changed by the library user. +var DefaultOpts = &Opts{ + Cipher: AES256CBC, + KDFOpts: PBKDF2Opts{ + SaltSize: 8, + IterationCount: 10000, + HMACHash: crypto.SHA256, + }, +} -// Copy from crypto/x509 -func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) { - switch curve { - case elliptic.P224(): - return oidNamedCurveP224, true - case elliptic.P256(): - return oidNamedCurveP256, true - case elliptic.P384(): - return oidNamedCurveP384, true - case elliptic.P521(): - return oidNamedCurveP521, true - } +// KDFOpts contains options for a key derivation function. +// An implementation of this interface must be specified when encrypting a PKCS#8 key. +type KDFOpts interface { + // DeriveKey derives a key of size bytes from the given password and salt. + // It returns the key and the ASN.1-encodable parameters used. + DeriveKey(password, salt []byte, size int) (key []byte, params KDFParameters, err error) + // GetSaltSize returns the salt size specified. + GetSaltSize() int + // OID returns the OID of the KDF specified. + OID() asn1.ObjectIdentifier +} - return nil, false +// KDFParameters contains parameters (salt, etc.) for a key deriviation function. +// It must be a ASN.1-decodable structure. +// An implementation of this interface is created when decoding an encrypted PKCS#8 key. +type KDFParameters interface { + // DeriveKey derives a key of size bytes from the given password. + // It uses the salt from the decoded parameters. + DeriveKey(password []byte, size int) (key []byte, err error) } -// Unecrypted PKCS8 -var ( - oidPKCS5PBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} - oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} - oidAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42} - oidAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2} - oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} - oidDESEDE3CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 3, 7} -) +var kdfs = make(map[string]func() KDFParameters) -type ecPrivateKey struct { - Version int - PrivateKey []byte - NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"` - PublicKey asn1.BitString `asn1:"optional,explicit,tag:1"` +// RegisterKDF registers a function that returns a new instance of the given KDF +// parameters. This allows the library to support client-provided KDFs. +func RegisterKDF(oid asn1.ObjectIdentifier, params func() KDFParameters) { + kdfs[oid.String()] = params } -type privateKeyInfo struct { - Version int - PrivateKeyAlgorithm []asn1.ObjectIdentifier - PrivateKey []byte +// Cipher represents a cipher for encrypting the key material. +type Cipher interface { + // IVSize returns the IV size of the cipher, in bytes. + IVSize() int + // KeySize returns the key size of the cipher, in bytes. + KeySize() int + // Encrypt encrypts the key material. + Encrypt(key, iv, plaintext []byte) ([]byte, error) + // Decrypt decrypts the key material. + Decrypt(key, iv, ciphertext []byte) ([]byte, error) + // OID returns the OID of the cipher specified. + OID() asn1.ObjectIdentifier } -// Encrypted PKCS8 -type prfParam struct { - IdPRF asn1.ObjectIdentifier - NullParam asn1.RawValue -} +var ciphers = make(map[string]func() Cipher) -type pbkdf2Params struct { - Salt []byte - IterationCount int - PrfParam prfParam `asn1:"optional"` +// RegisterCipher registers a function that returns a new instance of the given +// cipher. This allows the library to support client-provided ciphers. +func RegisterCipher(oid asn1.ObjectIdentifier, cipher func() Cipher) { + ciphers[oid.String()] = cipher } -type pbkdf2Algorithms struct { - IdPBKDF2 asn1.ObjectIdentifier - PBKDF2Params pbkdf2Params +// Opts contains options for encrypting a PKCS#8 key. +type Opts struct { + Cipher Cipher + KDFOpts KDFOpts } -type pbkdf2Encs struct { - EncryAlgo asn1.ObjectIdentifier - IV []byte -} +// Unecrypted PKCS8 +var ( + oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} +) -type pbes2Params struct { - KeyDerivationFunc pbkdf2Algorithms - EncryptionScheme pbkdf2Encs +type encryptedPrivateKeyInfo struct { + EncryptionAlgorithm pkix.AlgorithmIdentifier + EncryptedData []byte } -type pbes2Algorithms struct { - IdPBES2 asn1.ObjectIdentifier - PBES2Params pbes2Params +type pbes2Params struct { + KeyDerivationFunc pkix.AlgorithmIdentifier + EncryptionScheme pkix.AlgorithmIdentifier } -type encryptedPrivateKeyInfo struct { - EncryptionAlgorithm pbes2Algorithms - EncryptedData []byte +type privateKeyInfo struct { + Version int + PrivateKeyAlgorithm pkix.AlgorithmIdentifier + PrivateKey []byte } -// ParsePKCS8PrivateKeyRSA parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. -// -// The function can decrypt the private key encrypted with AES-256-CBC mode, and stored in PKCS #5 v2.0 format. -func ParsePKCS8PrivateKeyRSA(der []byte, v ...[]byte) (*rsa.PrivateKey, error) { - key, err := ParsePKCS8PrivateKey(der, v...) - if err != nil { - return nil, err - } - typedKey, ok := key.(*rsa.PrivateKey) +func parseKeyDerivationFunc(keyDerivationFunc pkix.AlgorithmIdentifier) (KDFParameters, error) { + oid := keyDerivationFunc.Algorithm.String() + newParams, ok := kdfs[oid] if !ok { - return nil, errors.New("key block is not of type RSA") + return nil, fmt.Errorf("pkcs8: unsupported KDF (OID: %s)", oid) } - return typedKey, nil -} - -// ParsePKCS8PrivateKeyECDSA parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. -// -// The function can decrypt the private key encrypted with AES-256-CBC mode, and stored in PKCS #5 v2.0 format. -func ParsePKCS8PrivateKeyECDSA(der []byte, v ...[]byte) (*ecdsa.PrivateKey, error) { - key, err := ParsePKCS8PrivateKey(der, v...) + params := newParams() + _, err := asn1.Unmarshal(keyDerivationFunc.Parameters.FullBytes, params) if err != nil { - return nil, err + return nil, errors.New("pkcs8: invalid KDF parameters") } - typedKey, ok := key.(*ecdsa.PrivateKey) + return params, nil +} + +func parseEncryptionScheme(encryptionScheme pkix.AlgorithmIdentifier) (Cipher, []byte, error) { + oid := encryptionScheme.Algorithm.String() + newCipher, ok := ciphers[oid] if !ok { - return nil, errors.New("key block is not of type ECDSA") + return nil, nil, fmt.Errorf("pkcs8: unsupported cipher (OID: %s)", oid) } - return typedKey, nil + cipher := newCipher() + var iv []byte + if _, err := asn1.Unmarshal(encryptionScheme.Parameters.FullBytes, &iv); err != nil { + return nil, nil, errors.New("pkcs8: invalid cipher parameters") + } + return cipher, iv, nil } -// ParsePKCS8PrivateKey parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. -// -// The function can decrypt the private key encrypted with AES-256-CBC mode, and stored in PKCS #5 v2.0 format. -func ParsePKCS8PrivateKey(der []byte, v ...[]byte) (interface{}, error) { +// ParsePrivateKey parses a DER-encoded PKCS#8 private key. +// Password can be nil. +// This is equivalent to ParsePKCS8PrivateKey. +func ParsePrivateKey(der []byte, password []byte) (interface{}, KDFParameters, error) { // No password provided, assume the private key is unencrypted - if v == nil { - return x509.ParsePKCS8PrivateKey(der) + if len(password) == 0 { + privateKey, err := x509.ParsePKCS8PrivateKey(der) + return privateKey, nil, err } // Use the password provided to decrypt the private key - password := v[0] var privKey encryptedPrivateKeyInfo if _, err := asn1.Unmarshal(der, &privKey); err != nil { - return nil, errors.New("pkcs8: only PKCS #5 v2.0 supported") + return nil, nil, errors.New("pkcs8: only PKCS #5 v2.0 supported") } - if !privKey.EncryptionAlgorithm.IdPBES2.Equal(oidPBES2) { - return nil, errors.New("pkcs8: only PBES2 supported") + if !privKey.EncryptionAlgorithm.Algorithm.Equal(oidPBES2) { + return nil, nil, errors.New("pkcs8: only PBES2 supported") } - if !privKey.EncryptionAlgorithm.PBES2Params.KeyDerivationFunc.IdPBKDF2.Equal(oidPKCS5PBKDF2) { - return nil, errors.New("pkcs8: only PBKDF2 supported") + var params pbes2Params + if _, err := asn1.Unmarshal(privKey.EncryptionAlgorithm.Parameters.FullBytes, ¶ms); err != nil { + return nil, nil, errors.New("pkcs8: invalid PBES2 parameters") } - encParam := privKey.EncryptionAlgorithm.PBES2Params.EncryptionScheme - kdfParam := privKey.EncryptionAlgorithm.PBES2Params.KeyDerivationFunc.PBKDF2Params + cipher, iv, err := parseEncryptionScheme(params.EncryptionScheme) + if err != nil { + return nil, nil, err + } - iv := encParam.IV - salt := kdfParam.Salt - iter := kdfParam.IterationCount - keyHash := sha1.New - if kdfParam.PrfParam.IdPRF.Equal(oidHMACWithSHA256) { - keyHash = sha256.New + kdfParams, err := parseKeyDerivationFunc(params.KeyDerivationFunc) + if err != nil { + return nil, nil, err } - encryptedKey := privKey.EncryptedData - var symkey []byte - var block cipher.Block - var err error - switch { - case encParam.EncryAlgo.Equal(oidAES128CBC): - symkey = pbkdf2.Key(password, salt, iter, 16, keyHash) - block, err = aes.NewCipher(symkey) - case encParam.EncryAlgo.Equal(oidAES256CBC): - symkey = pbkdf2.Key(password, salt, iter, 32, keyHash) - block, err = aes.NewCipher(symkey) - case encParam.EncryAlgo.Equal(oidDESEDE3CBC): - symkey = pbkdf2.Key(password, salt, iter, 24, keyHash) - block, err = des.NewTripleDESCipher(symkey) - default: - return nil, errors.New("pkcs8: only AES-256-CBC, AES-128-CBC and DES-EDE3-CBC are supported") + keySize := cipher.KeySize() + symkey, err := kdfParams.DeriveKey(password, keySize) + if err != nil { + return nil, nil, err } + + encryptedKey := privKey.EncryptedData + decryptedKey, err := cipher.Decrypt(symkey, iv, encryptedKey) if err != nil { - return nil, err + return nil, nil, err } - mode := cipher.NewCBCDecrypter(block, iv) - mode.CryptBlocks(encryptedKey, encryptedKey) - key, err := x509.ParsePKCS8PrivateKey(encryptedKey) + key, err := x509.ParsePKCS8PrivateKey(decryptedKey) if err != nil { - return nil, errors.New("pkcs8: incorrect password") + return nil, nil, errors.New("pkcs8: incorrect password") } - return key, nil + return key, kdfParams, nil } -func convertPrivateKeyToPKCS8(priv interface{}) ([]byte, error) { - var pkey privateKeyInfo - - switch priv := priv.(type) { - case *ecdsa.PrivateKey: - eckey, err := x509.MarshalECPrivateKey(priv) - if err != nil { - return nil, err - } - - oidNamedCurve, ok := oidFromNamedCurve(priv.Curve) - if !ok { - return nil, errors.New("pkcs8: unknown elliptic curve") - } - - // Per RFC5958, if publicKey is present, then version is set to v2(1) else version is set to v1(0). - // But openssl set to v1 even publicKey is present - pkey.Version = 1 - pkey.PrivateKeyAlgorithm = make([]asn1.ObjectIdentifier, 2) - pkey.PrivateKeyAlgorithm[0] = oidPublicKeyECDSA - pkey.PrivateKeyAlgorithm[1] = oidNamedCurve - pkey.PrivateKey = eckey - case *rsa.PrivateKey: - - // Per RFC5958, if publicKey is present, then version is set to v2(1) else version is set to v1(0). - // But openssl set to v1 even publicKey is present - pkey.Version = 0 - pkey.PrivateKeyAlgorithm = make([]asn1.ObjectIdentifier, 1) - pkey.PrivateKeyAlgorithm[0] = oidPublicKeyRSA - pkey.PrivateKey = x509.MarshalPKCS1PrivateKey(priv) - } - - return asn1.Marshal(pkey) -} +// MarshalPrivateKey encodes a private key into DER-encoded PKCS#8 with the given options. +// Password can be nil. +func MarshalPrivateKey(priv interface{}, password []byte, opts *Opts) ([]byte, error) { + if len(password) == 0 { + return x509.MarshalPKCS8PrivateKey(priv) + } + + if opts == nil { + opts = DefaultOpts + } -func convertPrivateKeyToPKCS8Encrypted(priv interface{}, password []byte) ([]byte, error) { // Convert private key into PKCS8 format - pkey, err := convertPrivateKeyToPKCS8(priv) + pkey, err := x509.MarshalPKCS8PrivateKey(priv) if err != nil { return nil, err } - // Calculate key from password based on PKCS5 algorithm - // Use 8 byte salt, 16 byte IV, and 2048 iteration - iter := 2048 - salt := make([]byte, 8) - iv := make([]byte, 16) + encAlg := opts.Cipher + salt := make([]byte, opts.KDFOpts.GetSaltSize()) _, err = rand.Read(salt) if err != nil { return nil, err } + iv := make([]byte, encAlg.IVSize()) _, err = rand.Read(iv) if err != nil { return nil, err } + key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize()) + if err != nil { + return nil, err + } - key := pbkdf2.Key(password, salt, iter, 32, sha256.New) + encryptedKey, err := encAlg.Encrypt(key, iv, pkey) + if err != nil { + return nil, err + } - // Use AES256-CBC mode, pad plaintext with PKCS5 padding scheme - padding := aes.BlockSize - len(pkey)%aes.BlockSize - if padding > 0 { - n := len(pkey) - pkey = append(pkey, make([]byte, padding)...) - for i := 0; i < padding; i++ { - pkey[n+i] = byte(padding) - } + marshalledParams, err := asn1.Marshal(kdfParams) + if err != nil { + return nil, err + } + keyDerivationFunc := pkix.AlgorithmIdentifier{ + Algorithm: opts.KDFOpts.OID(), + Parameters: asn1.RawValue{FullBytes: marshalledParams}, + } + marshalledIV, err := asn1.Marshal(iv) + if err != nil { + return nil, err + } + encryptionScheme := pkix.AlgorithmIdentifier{ + Algorithm: encAlg.OID(), + Parameters: asn1.RawValue{FullBytes: marshalledIV}, } - encryptedKey := make([]byte, len(pkey)) - block, err := aes.NewCipher(key) + encryptionAlgorithmParams := pbes2Params{ + EncryptionScheme: encryptionScheme, + KeyDerivationFunc: keyDerivationFunc, + } + marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams) if err != nil { return nil, err } - mode := cipher.NewCBCEncrypter(block, iv) - mode.CryptBlocks(encryptedKey, pkey) + encryptionAlgorithm := pkix.AlgorithmIdentifier{ + Algorithm: oidPBES2, + Parameters: asn1.RawValue{FullBytes: marshalledEncryptionAlgorithmParams}, + } - // pbkdf2algo := pbkdf2Algorithms{oidPKCS5PBKDF2, pbkdf2Params{salt, iter, prfParam{oidHMACWithSHA256}}} + encryptedPkey := encryptedPrivateKeyInfo{ + EncryptionAlgorithm: encryptionAlgorithm, + EncryptedData: encryptedKey, + } - pbkdf2algo := pbkdf2Algorithms{oidPKCS5PBKDF2, pbkdf2Params{salt, iter, prfParam{oidHMACWithSHA256, asn1.RawValue{Tag: asn1.TagNull}}}} - pbkdf2encs := pbkdf2Encs{oidAES256CBC, iv} - pbes2algo := pbes2Algorithms{oidPBES2, pbes2Params{pbkdf2algo, pbkdf2encs}} + return asn1.Marshal(encryptedPkey) +} - encryptedPkey := encryptedPrivateKeyInfo{pbes2algo, encryptedKey} +// ParsePKCS8PrivateKey parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. +func ParsePKCS8PrivateKey(der []byte, v ...[]byte) (interface{}, error) { + var password []byte + if len(v) > 0 { + password = v[0] + } + privateKey, _, err := ParsePrivateKey(der, password) + return privateKey, err +} - return asn1.Marshal(encryptedPkey) +// ParsePKCS8PrivateKeyRSA parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. +func ParsePKCS8PrivateKeyRSA(der []byte, v ...[]byte) (*rsa.PrivateKey, error) { + key, err := ParsePKCS8PrivateKey(der, v...) + if err != nil { + return nil, err + } + typedKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("key block is not of type RSA") + } + return typedKey, nil +} + +// ParsePKCS8PrivateKeyECDSA parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. +func ParsePKCS8PrivateKeyECDSA(der []byte, v ...[]byte) (*ecdsa.PrivateKey, error) { + key, err := ParsePKCS8PrivateKey(der, v...) + if err != nil { + return nil, err + } + typedKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + return nil, errors.New("key block is not of type ECDSA") + } + return typedKey, nil } // ConvertPrivateKeyToPKCS8 converts the private key into PKCS#8 format. // To encrypt the private key, the password of []byte type should be provided as the second parameter. // -// The only supported key types are RSA and ECDSA (*rsa.PublicKey or *ecdsa.PublicKey for priv) +// The only supported key types are RSA and ECDSA (*rsa.PrivateKey or *ecdsa.PrivateKey for priv) func ConvertPrivateKeyToPKCS8(priv interface{}, v ...[]byte) ([]byte, error) { - if v == nil { - return convertPrivateKeyToPKCS8(priv) + var password []byte + if len(v) > 0 { + password = v[0] } - - password := string(v[0]) - return convertPrivateKeyToPKCS8Encrypted(priv, []byte(password)) + return MarshalPrivateKey(priv, password, nil) } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bson.go b/vendor/go.mongodb.org/mongo-driver/bson/bson.go index 95ffc10..a0d8185 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bson.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bson.go @@ -27,7 +27,7 @@ type Zeroer interface { // // Example usage: // -// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} +// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} type D = primitive.D // E represents a BSON element for a D. It is usually used inside a D. @@ -39,12 +39,12 @@ type E = primitive.E // // Example usage: // -// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} +// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} type M = primitive.M // An A is an ordered representation of a BSON array. // // Example usage: // -// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}} +// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}} type A = primitive.A diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go index 4e24f9e..652aa48 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go @@ -14,17 +14,22 @@ import ( ) // ArrayCodec is the Codec used for bsoncore.Array values. +// +// Deprecated: ArrayCodec will not be directly accessible in Go Driver 2.0. type ArrayCodec struct{} var defaultArrayCodec = NewArrayCodec() // NewArrayCodec returns an ArrayCodec. +// +// Deprecated: NewArrayCodec will not be available in Go Driver 2.0. See +// [ArrayCodec] for more details. func NewArrayCodec() *ArrayCodec { return &ArrayCodec{} } // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *ArrayCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +func (ac *ArrayCodec) EncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } @@ -34,7 +39,7 @@ func (ac *ArrayCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val r } // DecodeValue is the ValueDecoder for bsoncore.Array values. -func (ac *ArrayCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +func (ac *ArrayCodec) DecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreArray { return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go index 96195bc..0693bd4 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go @@ -13,6 +13,7 @@ import ( "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" ) var ( @@ -22,6 +23,8 @@ var ( // Marshaler is an interface implemented by types that can marshal themselves // into a BSON document represented as bytes. The bytes returned must be a valid // BSON document if the error is nil. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Marshaler] instead. type Marshaler interface { MarshalBSON() ([]byte, error) } @@ -30,6 +33,8 @@ type Marshaler interface { // themselves into a BSON value as bytes. The type must be the valid type for // the bytes returned. The bytes and byte type together must be valid if the // error is nil. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.ValueMarshaler] instead. type ValueMarshaler interface { MarshalBSONValue() (bsontype.Type, []byte, error) } @@ -38,6 +43,8 @@ type ValueMarshaler interface { // document representation of themselves. The BSON bytes can be assumed to be // valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data // after returning. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Unmarshaler] instead. type Unmarshaler interface { UnmarshalBSON([]byte) error } @@ -46,6 +53,8 @@ type Unmarshaler interface { // BSON value representation of themselves. The BSON bytes and type can be // assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it // wishes to retain the data after returning. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.ValueUnmarshaler] instead. type ValueUnmarshaler interface { UnmarshalBSONValue(bsontype.Type, []byte) error } @@ -110,23 +119,176 @@ func (vde ValueDecoderError) Error() string { // value. type EncodeContext struct { *Registry + + // MinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) + // that can represent the integer value. + // + // Deprecated: Use bson.Encoder.IntMinSize instead. MinSize bool + + errorOnInlineDuplicates bool + stringifyMapKeysWithFmt bool + nilMapAsEmpty bool + nilSliceAsEmpty bool + nilByteSliceAsEmpty bool + omitZeroStruct bool + useJSONStructTags bool +} + +// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in +// the marshaled BSON when the "inline" struct tag option is set. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.ErrorOnInlineDuplicates] instead. +func (ec *EncodeContext) ErrorOnInlineDuplicates() { + ec.errorOnInlineDuplicates = true +} + +// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name +// strings using fmt.Sprintf() instead of the default string conversion logic. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.StringifyMapKeysWithFmt] instead. +func (ec *EncodeContext) StringifyMapKeysWithFmt() { + ec.stringifyMapKeysWithFmt = true +} + +// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON +// null. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilMapAsEmpty] instead. +func (ec *EncodeContext) NilMapAsEmpty() { + ec.nilMapAsEmpty = true +} + +// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON +// null. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilSliceAsEmpty] instead. +func (ec *EncodeContext) NilSliceAsEmpty() { + ec.nilSliceAsEmpty = true +} + +// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values +// instead of BSON null. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilByteSliceAsEmpty] instead. +func (ec *EncodeContext) NilByteSliceAsEmpty() { + ec.nilByteSliceAsEmpty = true +} + +// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) +// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. +// +// Note that the Encoder only examines exported struct fields when determining if a struct is the +// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.OmitZeroStruct] instead. +func (ec *EncodeContext) OmitZeroStruct() { + ec.omitZeroStruct = true +} + +// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" +// struct tag is not specified. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.UseJSONStructTags] instead. +func (ec *EncodeContext) UseJSONStructTags() { + ec.useJSONStructTags = true } // DecodeContext is the contextual information required for a Codec to decode a // value. type DecodeContext struct { *Registry + + // Truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" + // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to + // BSON "decimal128" values. + // + // Deprecated: Use bson.Decoder.AllowTruncatingDoubles instead. Truncate bool + // Ancestor is the type of a containing document. This is mainly used to determine what type // should be used when decoding an embedded document into an empty interface. For example, if // Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface // will be decoded into a bson.M. + // + // Deprecated: Use bson.Decoder.DefaultDocumentM or bson.Decoder.DefaultDocumentD instead. Ancestor reflect.Type + + // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the + // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is + // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an + // error. DocumentType overrides the Ancestor field. + defaultDocumentType reflect.Type + + binaryAsSlice bool + useJSONStructTags bool + useLocalTimeZone bool + zeroMaps bool + zeroStructs bool +} + +// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or +// "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.BinaryAsSlice] instead. +func (dc *DecodeContext) BinaryAsSlice() { + dc.binaryAsSlice = true } -// ValueCodec is the interface that groups the methods to encode and decode +// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" +// struct tag is not specified. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseJSONStructTags] instead. +func (dc *DecodeContext) UseJSONStructTags() { + dc.useJSONStructTags = true +} + +// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead +// of the UTC timezone. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseLocalTimeZone] instead. +func (dc *DecodeContext) UseLocalTimeZone() { + dc.useLocalTimeZone = true +} + +// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value +// passed to Decode before unmarshaling BSON documents into them. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroMaps] instead. +func (dc *DecodeContext) ZeroMaps() { + dc.zeroMaps = true +} + +// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination +// value passed to Decode before unmarshaling BSON documents into them. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroStructs] instead. +func (dc *DecodeContext) ZeroStructs() { + dc.zeroStructs = true +} + +// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentM] instead. +func (dc *DecodeContext) DefaultDocumentM() { + dc.defaultDocumentType = reflect.TypeOf(primitive.M{}) +} + +// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentD] instead. +func (dc *DecodeContext) DefaultDocumentD() { + dc.defaultDocumentType = reflect.TypeOf(primitive.D{}) +} + +// ValueCodec is an interface for encoding and decoding a reflect.Value. // values. +// +// Deprecated: Use [ValueEncoder] and [ValueDecoder] instead. type ValueCodec interface { ValueEncoder ValueDecoder @@ -211,6 +373,10 @@ func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext // CodecZeroer is the interface implemented by Codecs that can also determine if // a value of the type that would be encoded is zero. +// +// Deprecated: Defining custom rules for the zero/empty value will not be supported in Go Driver +// 2.0. Users who want to omit empty complex values should use a pointer field and set the value to +// nil instead. type CodecZeroer interface { IsTypeZero(interface{}) bool } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go index 5a916cc..0134b5a 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go @@ -16,18 +16,45 @@ import ( ) // ByteSliceCodec is the Codec used for []byte values. +// +// Deprecated: ByteSliceCodec will not be directly configurable in Go Driver +// 2.0. To configure the byte slice encode and decode behavior, use the +// configuration methods on a [go.mongodb.org/mongo-driver/bson.Encoder] or +// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the byte slice +// encode and decode behavior for a mongo.Client, use +// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions]. +// +// For example, to configure a mongo.Client to encode nil byte slices as empty +// BSON binary values, use: +// +// opt := options.Client().SetBSONOptions(&options.BSONOptions{ +// NilByteSliceAsEmpty: true, +// }) +// +// See the deprecation notice for each field in ByteSliceCodec for the +// corresponding settings. type ByteSliceCodec struct { + // EncodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values + // instead of BSON null. + // + // Deprecated: Use bson.Encoder.NilByteSliceAsEmpty or options.BSONOptions.NilByteSliceAsEmpty + // instead. EncodeNilAsEmpty bool } var ( defaultByteSliceCodec = NewByteSliceCodec() - _ ValueCodec = defaultByteSliceCodec + // Assert that defaultByteSliceCodec satisfies the typeDecoder interface, which allows it to be + // used by collection type decoders (e.g. map, slice, etc) to set individual values in a + // collection. _ typeDecoder = defaultByteSliceCodec ) -// NewByteSliceCodec returns a StringCodec with options opts. +// NewByteSliceCodec returns a ByteSliceCodec with options opts. +// +// Deprecated: NewByteSliceCodec will not be available in Go Driver 2.0. See +// [ByteSliceCodec] for more details. func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec { byteSliceOpt := bsonoptions.MergeByteSliceCodecOptions(opts...) codec := ByteSliceCodec{} @@ -42,13 +69,13 @@ func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } - if val.IsNil() && !bsc.EncodeNilAsEmpty { + if val.IsNil() && !bsc.EncodeNilAsEmpty && !ec.nilByteSliceAsEmpty { return vw.WriteNull() } return vw.WriteBinary(val.Interface().([]byte)) } -func (bsc *ByteSliceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (bsc *ByteSliceCodec) decodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ Name: "ByteSliceDecodeValue", diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/codec_cache.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/codec_cache.go new file mode 100644 index 0000000..844b502 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/codec_cache.go @@ -0,0 +1,166 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "reflect" + "sync" + "sync/atomic" +) + +// Runtime check that the kind encoder and decoder caches can store any valid +// reflect.Kind constant. +func init() { + if s := reflect.Kind(len(kindEncoderCache{}.entries)).String(); s != "kind27" { + panic("The capacity of kindEncoderCache is too small.\n" + + "This is due to a new type being added to reflect.Kind.") + } +} + +// statically assert array size +var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer] +var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer] + +type typeEncoderCache struct { + cache sync.Map // map[reflect.Type]ValueEncoder +} + +func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) { + c.cache.Store(rt, enc) +} + +func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(ValueEncoder), true + } + return nil, false +} + +func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder { + if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { + enc = v.(ValueEncoder) + } + return enc +} + +func (c *typeEncoderCache) Clone() *typeEncoderCache { + cc := new(typeEncoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + +type typeDecoderCache struct { + cache sync.Map // map[reflect.Type]ValueDecoder +} + +func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) { + c.cache.Store(rt, dec) +} + +func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(ValueDecoder), true + } + return nil, false +} + +func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder { + if v, loaded := c.cache.LoadOrStore(rt, dec); loaded { + dec = v.(ValueDecoder) + } + return dec +} + +func (c *typeDecoderCache) Clone() *typeDecoderCache { + cc := new(typeDecoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + +// atomic.Value requires that all calls to Store() have the same concrete type +// so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type +// is always the same (since different concrete types may implement the +// ValueEncoder interface). +type kindEncoderCacheEntry struct { + enc ValueEncoder +} + +type kindEncoderCache struct { + entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry +} + +func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) { + if enc != nil && rt < reflect.Kind(len(c.entries)) { + c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc}) + } +} + +func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) { + if rt < reflect.Kind(len(c.entries)) { + if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok { + return ent.enc, ent.enc != nil + } + } + return nil, false +} + +func (c *kindEncoderCache) Clone() *kindEncoderCache { + cc := new(kindEncoderCache) + for i, v := range c.entries { + if val := v.Load(); val != nil { + cc.entries[i].Store(val) + } + } + return cc +} + +// atomic.Value requires that all calls to Store() have the same concrete type +// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type +// is always the same (since different concrete types may implement the +// ValueDecoder interface). +type kindDecoderCacheEntry struct { + dec ValueDecoder +} + +type kindDecoderCache struct { + entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry +} + +func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) { + if rt < reflect.Kind(len(c.entries)) { + c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec}) + } +} + +func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) { + if rt < reflect.Kind(len(c.entries)) { + if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok { + return ent.dec, ent.dec != nil + } + } + return nil, false +} + +func (c *kindDecoderCache) Clone() *kindDecoderCache { + cc := new(kindDecoderCache) + for i, v := range c.entries { + if val := v.Load(); val != nil { + cc.entries[i].Store(val) + } + } + return cc +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go index 20f4797..8702d6d 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go @@ -24,7 +24,7 @@ import ( var ( defaultValueDecoders DefaultValueDecoders - errCannotTruncate = errors.New("float64 can only be truncated to an integer type when truncation is enabled") + errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled") ) type decodeBinaryError struct { @@ -41,13 +41,16 @@ func newDefaultStructCodec() *StructCodec { if err != nil { // This function is called from the codec registration path, so errors can't be propagated. If there's an error // constructing the StructCodec, we panic to avoid losing it. - panic(fmt.Errorf("error creating default StructCodec: %v", err)) + panic(fmt.Errorf("error creating default StructCodec: %w", err)) } return codec } // DefaultValueDecoders is a namespace type for the default ValueDecoders used // when creating a registry. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. type DefaultValueDecoders struct{} // RegisterDefaultDecoders will register the decoder methods attached to DefaultValueDecoders with @@ -56,6 +59,9 @@ type DefaultValueDecoders struct{} // There is no support for decoding map[string]interface{} because there is no decoder for // interface{}, so users must either register this decoder themselves or use the // EmptyInterfaceDecoder available in the bson package. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { if rb == nil { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) @@ -132,6 +138,9 @@ func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { } // DDecodeValue is the ValueDecoderFunc for primitive.D instances. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Type() != tD { return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} @@ -169,7 +178,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe for { key, elemVr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } else if err != nil { return err @@ -188,7 +197,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe return nil } -func (dvd DefaultValueDecoders) booleanDecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.Bool { return emptyValue, ValueDecoderError{ Name: "BooleanDecodeValue", @@ -235,6 +244,9 @@ func (dvd DefaultValueDecoders) booleanDecodeType(dctx DecodeContext, vr bsonrw. } // BooleanDecodeValue is the ValueDecoderFunc for bool types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} @@ -318,7 +330,7 @@ func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr bsonrw.ValueReade case reflect.Int64: return reflect.ValueOf(i64), nil case reflect.Int: - if int64(int(i64)) != i64 { // Can we fit this inside of an int + if i64 > math.MaxInt { // Can we fit this inside of an int return emptyValue, fmt.Errorf("%d overflows int", i64) } @@ -333,6 +345,9 @@ func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr bsonrw.ValueReade } // IntDecodeValue is the ValueDecoderFunc for int types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ @@ -419,7 +434,7 @@ func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr bsonrw.Valu return fmt.Errorf("%d overflows uint64", i64) } case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + if i64 < 0 || uint64(i64) > uint64(math.MaxUint) { // Can we fit this inside of an uint return fmt.Errorf("%d overflows uint", i64) } default: @@ -434,7 +449,7 @@ func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr bsonrw.Valu return nil } -func (dvd DefaultValueDecoders) floatDecodeType(ec DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { var f float64 var err error switch vrType := vr.Type(); vrType { @@ -477,7 +492,7 @@ func (dvd DefaultValueDecoders) floatDecodeType(ec DecodeContext, vr bsonrw.Valu switch t.Kind() { case reflect.Float32: - if !ec.Truncate && float64(float32(f)) != f { + if !dc.Truncate && float64(float32(f)) != f { return emptyValue, errCannotTruncate } @@ -494,6 +509,9 @@ func (dvd DefaultValueDecoders) floatDecodeType(ec DecodeContext, vr bsonrw.Valu } // FloatDecodeValue is the ValueDecoderFunc for float types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ @@ -515,7 +533,7 @@ func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr bsonrw.Val // StringDecodeValue is the ValueDecoderFunc for string types. // // Deprecated: StringDecodeValue is not registered by default. Use StringCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) StringDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +func (dvd DefaultValueDecoders) StringDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { var str string var err error switch vr.Type() { @@ -536,7 +554,7 @@ func (dvd DefaultValueDecoders) StringDecodeValue(dctx DecodeContext, vr bsonrw. return nil } -func (DefaultValueDecoders) javaScriptDecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJavaScript { return emptyValue, ValueDecoderError{ Name: "JavaScriptDecodeValue", @@ -565,6 +583,9 @@ func (DefaultValueDecoders) javaScriptDecodeType(dctx DecodeContext, vr bsonrw.V } // JavaScriptDecodeValue is the ValueDecoderFunc for the primitive.JavaScript type. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJavaScript { return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} @@ -579,7 +600,7 @@ func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr bso return nil } -func (DefaultValueDecoders) symbolDecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tSymbol { return emptyValue, ValueDecoderError{ Name: "SymbolDecodeValue", @@ -620,6 +641,9 @@ func (DefaultValueDecoders) symbolDecodeType(dctx DecodeContext, vr bsonrw.Value } // SymbolDecodeValue is the ValueDecoderFunc for the primitive.Symbol type. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tSymbol { return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} @@ -634,7 +658,7 @@ func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr bsonrw. return nil } -func (DefaultValueDecoders) binaryDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tBinary { return emptyValue, ValueDecoderError{ Name: "BinaryDecodeValue", @@ -664,6 +688,9 @@ func (DefaultValueDecoders) binaryDecodeType(dc DecodeContext, vr bsonrw.ValueRe } // BinaryDecodeValue is the ValueDecoderFunc for Binary. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tBinary { return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val} @@ -678,7 +705,7 @@ func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr bsonrw.Va return nil } -func (DefaultValueDecoders) undefinedDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tUndefined { return emptyValue, ValueDecoderError{ Name: "UndefinedDecodeValue", @@ -704,6 +731,9 @@ func (DefaultValueDecoders) undefinedDecodeType(dc DecodeContext, vr bsonrw.Valu } // UndefinedDecodeValue is the ValueDecoderFunc for Undefined. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tUndefined { return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val} @@ -719,7 +749,7 @@ func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr bsonrw } // Accept both 12-byte string and pretty-printed 24-byte hex string formats. -func (dvd DefaultValueDecoders) objectIDDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tOID { return emptyValue, ValueDecoderError{ Name: "ObjectIDDecodeValue", @@ -765,6 +795,9 @@ func (dvd DefaultValueDecoders) objectIDDecodeType(dc DecodeContext, vr bsonrw.V } // ObjectIDDecodeValue is the ValueDecoderFunc for primitive.ObjectID. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tOID { return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} @@ -779,7 +812,7 @@ func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr bsonrw. return nil } -func (DefaultValueDecoders) dateTimeDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDateTime { return emptyValue, ValueDecoderError{ Name: "DateTimeDecodeValue", @@ -808,6 +841,9 @@ func (DefaultValueDecoders) dateTimeDecodeType(dc DecodeContext, vr bsonrw.Value } // DateTimeDecodeValue is the ValueDecoderFunc for DateTime. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDateTime { return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} @@ -822,7 +858,7 @@ func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr bsonrw. return nil } -func (DefaultValueDecoders) nullDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tNull { return emptyValue, ValueDecoderError{ Name: "NullDecodeValue", @@ -848,6 +884,9 @@ func (DefaultValueDecoders) nullDecodeType(dc DecodeContext, vr bsonrw.ValueRead } // NullDecodeValue is the ValueDecoderFunc for Null. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tNull { return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} @@ -862,7 +901,7 @@ func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr bsonrw.Valu return nil } -func (DefaultValueDecoders) regexDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tRegex { return emptyValue, ValueDecoderError{ Name: "RegexDecodeValue", @@ -891,6 +930,9 @@ func (DefaultValueDecoders) regexDecodeType(dc DecodeContext, vr bsonrw.ValueRea } // RegexDecodeValue is the ValueDecoderFunc for Regex. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRegex { return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} @@ -905,7 +947,7 @@ func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr bsonrw.Val return nil } -func (DefaultValueDecoders) dBPointerDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDBPointer { return emptyValue, ValueDecoderError{ Name: "DBPointerDecodeValue", @@ -935,6 +977,9 @@ func (DefaultValueDecoders) dBPointerDecodeType(dc DecodeContext, vr bsonrw.Valu } // DBPointerDecodeValue is the ValueDecoderFunc for DBPointer. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDBPointer { return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} @@ -949,7 +994,7 @@ func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr bsonrw return nil } -func (DefaultValueDecoders) timestampDecodeType(dc DecodeContext, vr bsonrw.ValueReader, reflectType reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr bsonrw.ValueReader, reflectType reflect.Type) (reflect.Value, error) { if reflectType != tTimestamp { return emptyValue, ValueDecoderError{ Name: "TimestampDecodeValue", @@ -978,6 +1023,9 @@ func (DefaultValueDecoders) timestampDecodeType(dc DecodeContext, vr bsonrw.Valu } // TimestampDecodeValue is the ValueDecoderFunc for Timestamp. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTimestamp { return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} @@ -992,7 +1040,7 @@ func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr bsonrw return nil } -func (DefaultValueDecoders) minKeyDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMinKey { return emptyValue, ValueDecoderError{ Name: "MinKeyDecodeValue", @@ -1020,6 +1068,9 @@ func (DefaultValueDecoders) minKeyDecodeType(dc DecodeContext, vr bsonrw.ValueRe } // MinKeyDecodeValue is the ValueDecoderFunc for MinKey. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMinKey { return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} @@ -1034,7 +1085,7 @@ func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr bsonrw.Va return nil } -func (DefaultValueDecoders) maxKeyDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMaxKey { return emptyValue, ValueDecoderError{ Name: "MaxKeyDecodeValue", @@ -1062,6 +1113,9 @@ func (DefaultValueDecoders) maxKeyDecodeType(dc DecodeContext, vr bsonrw.ValueRe } // MaxKeyDecodeValue is the ValueDecoderFunc for MaxKey. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMaxKey { return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} @@ -1076,7 +1130,7 @@ func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr bsonrw.Va return nil } -func (dvd DefaultValueDecoders) decimal128DecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDecimal { return emptyValue, ValueDecoderError{ Name: "Decimal128DecodeValue", @@ -1105,6 +1159,9 @@ func (dvd DefaultValueDecoders) decimal128DecodeType(dctx DecodeContext, vr bson } // Decimal128DecodeValue is the ValueDecoderFunc for primitive.Decimal128. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDecimal { return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} @@ -1119,7 +1176,7 @@ func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr bso return nil } -func (dvd DefaultValueDecoders) jsonNumberDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJSONNumber { return emptyValue, ValueDecoderError{ Name: "JSONNumberDecodeValue", @@ -1164,6 +1221,9 @@ func (dvd DefaultValueDecoders) jsonNumberDecodeType(dc DecodeContext, vr bsonrw } // JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJSONNumber { return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} @@ -1178,7 +1238,7 @@ func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr bsonr return nil } -func (dvd DefaultValueDecoders) urlDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t != tURL { return emptyValue, ValueDecoderError{ Name: "URLDecodeValue", @@ -1213,6 +1273,9 @@ func (dvd DefaultValueDecoders) urlDecodeType(dc DecodeContext, vr bsonrw.ValueR } // URLDecodeValue is the ValueDecoderFunc for url.URL. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tURL { return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} @@ -1230,7 +1293,7 @@ func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr bsonrw.Value // TimeDecodeValue is the ValueDecoderFunc for time.Time. // // Deprecated: TimeDecodeValue is not registered by default. Use TimeCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) TimeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +func (dvd DefaultValueDecoders) TimeDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if vr.Type() != bsontype.DateTime { return fmt.Errorf("cannot decode %v into a time.Time", vr.Type()) } @@ -1251,7 +1314,7 @@ func (dvd DefaultValueDecoders) TimeDecodeValue(dc DecodeContext, vr bsonrw.Valu // ByteSliceDecodeValue is the ValueDecoderFunc for []byte. // // Deprecated: ByteSliceDecodeValue is not registered by default. Use ByteSliceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) ByteSliceDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +func (dvd DefaultValueDecoders) ByteSliceDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if vr.Type() != bsontype.Binary && vr.Type() != bsontype.Null { return fmt.Errorf("cannot decode %v into a []byte", vr.Type()) } @@ -1316,7 +1379,7 @@ func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr bsonrw.Value keyType := val.Type().Key() for { key, vr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } if err != nil { @@ -1336,6 +1399,9 @@ func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr bsonrw.Value } // ArrayDecodeValue is the ValueDecoderFunc for array types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueDecoderError{Name: "ArrayDecodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} @@ -1447,11 +1513,26 @@ func (dvd DefaultValueDecoders) SliceDecodeValue(dc DecodeContext, vr bsonrw.Val } // ValueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. -func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tValueUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tValueUnmarshaler)) { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } + // If BSON value is null and the go value is a pointer, then don't call + // UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e., + // non-nil), encountering null in BSON will result in the pointer being + // directly set to nil here. Since the pointer is being replaced with nil, + // there is no opportunity (or reason) for the custom UnmarshalBSONValue logic + // to be called. + if vr.Type() == bsontype.Null && val.Kind() == reflect.Ptr { + val.Set(reflect.Zero(val.Type())) + + return vr.ReadNull() + } + if val.Kind() == reflect.Ptr && val.IsNil() { if !val.CanSet() { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} @@ -1463,7 +1544,7 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(dc DecodeContext, vr if !val.CanAddr() { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - val = val.Addr() // If they type doesn't implement the interface, a pointer to it must. + val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. } t, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) @@ -1471,16 +1552,19 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(dc DecodeContext, vr return err } - fn := val.Convert(tValueUnmarshaler).MethodByName("UnmarshalBSONValue") - errVal := fn.Call([]reflect.Value{reflect.ValueOf(t), reflect.ValueOf(src)})[0] - if !errVal.IsNil() { - return errVal.Interface().(error) + m, ok := val.Interface().(ValueUnmarshaler) + if !ok { + // NB: this error should be unreachable due to the above checks + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - return nil + return m.UnmarshalBSONValue(t, src) } // UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. -func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tUnmarshaler)) { return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } @@ -1492,13 +1576,6 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(dc DecodeContext, vr bson val.Set(reflect.New(val.Type().Elem())) } - if !val.Type().Implements(tUnmarshaler) { - if !val.CanAddr() { - return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} - } - val = val.Addr() // If they type doesn't implement the interface, a pointer to it must. - } - _, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) if err != nil { return err @@ -1516,12 +1593,19 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(dc DecodeContext, vr bson return nil } - fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON") - errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0] - if !errVal.IsNil() { - return errVal.Interface().(error) + if !val.Type().Implements(tUnmarshaler) { + if !val.CanAddr() { + return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} + } + val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. } - return nil + + m, ok := val.Interface().(Unmarshaler) + if !ok { + // NB: this error should be unreachable due to the above checks + return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} + } + return m.UnmarshalBSON(src) } // EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. @@ -1565,7 +1649,10 @@ func (dvd DefaultValueDecoders) EmptyInterfaceDecodeValue(dc DecodeContext, vr b } // CoreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. -func (DefaultValueDecoders) CoreDocumentDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreDocument { return ValueDecoderError{Name: "CoreDocumentDecodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -1600,7 +1687,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR idx := 0 for { vr, err := ar.ReadValue() - if err == bsonrw.ErrEOA { + if errors.Is(err, bsonrw.ErrEOA) { break } if err != nil { @@ -1671,6 +1758,9 @@ func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr bso } // CodeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCodeWithScope { return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} @@ -1709,7 +1799,7 @@ func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr b elems := make([]reflect.Value, 0) for { key, vr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } if err != nil { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go index 6bdb43c..4751ae9 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go @@ -58,10 +58,16 @@ func encodeElement(ec EncodeContext, dw bsonrw.DocumentWriter, e primitive.E) er // DefaultValueEncoders is a namespace type for the default ValueEncoders used // when creating a registry. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. type DefaultValueEncoders struct{} // RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with // the provided RegistryBuilder. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { if rb == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) @@ -113,7 +119,10 @@ func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { } // BooleanEncodeValue is the ValueEncoderFunc for bool types. -func (dve DefaultValueEncoders) BooleanEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) BooleanEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Bool { return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } @@ -125,6 +134,9 @@ func fitsIn32Bits(i int64) bool { } // IntEncodeValue is the ValueEncoderFunc for int types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32: @@ -176,7 +188,10 @@ func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw bsonrw.Valu } // FloatEncodeValue is the ValueEncoderFunc for float types. -func (dve DefaultValueEncoders) FloatEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) FloatEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: return vw.WriteDouble(val.Float()) @@ -188,7 +203,7 @@ func (dve DefaultValueEncoders) FloatEncodeValue(ec EncodeContext, vw bsonrw.Val // StringEncodeValue is the ValueEncoderFunc for string types. // // Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead. -func (dve DefaultValueEncoders) StringEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +func (dve DefaultValueEncoders) StringEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", @@ -201,7 +216,10 @@ func (dve DefaultValueEncoders) StringEncodeValue(ectx EncodeContext, vw bsonrw. } // ObjectIDEncodeValue is the ValueEncoderFunc for primitive.ObjectID. -func (dve DefaultValueEncoders) ObjectIDEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) ObjectIDEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} } @@ -209,7 +227,10 @@ func (dve DefaultValueEncoders) ObjectIDEncodeValue(ec EncodeContext, vw bsonrw. } // Decimal128EncodeValue is the ValueEncoderFunc for primitive.Decimal128. -func (dve DefaultValueEncoders) Decimal128EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) Decimal128EncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} } @@ -217,6 +238,9 @@ func (dve DefaultValueEncoders) Decimal128EncodeValue(ec EncodeContext, vw bsonr } // JSONNumberEncodeValue is the ValueEncoderFunc for json.Number. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} @@ -237,7 +261,10 @@ func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw bsonr } // URLEncodeValue is the ValueEncoderFunc for url.URL. -func (dve DefaultValueEncoders) URLEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) URLEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tURL { return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} } @@ -248,7 +275,7 @@ func (dve DefaultValueEncoders) URLEncodeValue(ec EncodeContext, vw bsonrw.Value // TimeEncodeValue is the ValueEncoderFunc for time.TIme. // // Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead. -func (dve DefaultValueEncoders) TimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +func (dve DefaultValueEncoders) TimeEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } @@ -260,7 +287,7 @@ func (dve DefaultValueEncoders) TimeEncodeValue(ec EncodeContext, vw bsonrw.Valu // ByteSliceEncodeValue is the ValueEncoderFunc for []byte. // // Deprecated: ByteSliceEncodeValue is not registered by default. Use ByteSliceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) ByteSliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +func (dve DefaultValueEncoders) ByteSliceEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } @@ -316,7 +343,7 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum } currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key)) - if lookupErr != nil && lookupErr != errInvalidValue { + if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -325,7 +352,7 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -343,6 +370,9 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum } // ArrayEncodeValue is the ValueEncoderFunc for array types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} @@ -388,7 +418,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val for idx := 0; idx < val.Len(); idx++ { currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) - if lookupErr != nil && lookupErr != errInvalidValue { + if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -397,7 +427,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -457,7 +487,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val for idx := 0; idx < val.Len(); idx++ { currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) - if lookupErr != nil && lookupErr != errInvalidValue { + if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -466,7 +496,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -515,7 +545,10 @@ func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw b } // ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. -func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement ValueMarshaler switch { case !val.IsValid(): @@ -531,17 +564,22 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw b return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val} } - fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue") - returns := fn.Call(nil) - if !returns[2].IsNil() { - return returns[2].Interface().(error) + m, ok := val.Interface().(ValueMarshaler) + if !ok { + return vw.WriteNull() + } + t, data, err := m.MarshalBSONValue() + if err != nil { + return err } - t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte) return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data) } // MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. -func (dve DefaultValueEncoders) MarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Marshaler switch { case !val.IsValid(): @@ -557,16 +595,21 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(ec EncodeContext, vw bsonrw return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val} } - fn := val.Convert(tMarshaler).MethodByName("MarshalBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Marshaler) + if !ok { + return vw.WriteNull() + } + data, err := m.MarshalBSON() + if err != nil { + return err } - data := returns[0].Interface().([]byte) return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data) } // ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Proxy switch { @@ -583,27 +626,38 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.Val return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} } - fn := val.Convert(tProxy).MethodByName("ProxyBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Proxy) + if !ok { + return vw.WriteNull() } - data := returns[0] - var encoder ValueEncoder - var err error - if data.Elem().IsValid() { - encoder, err = ec.LookupEncoder(data.Elem().Type()) - } else { - encoder, err = ec.LookupEncoder(nil) + v, err := m.ProxyBSON() + if err != nil { + return err } + if v == nil { + encoder, err := ec.LookupEncoder(nil) + if err != nil { + return err + } + return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil)) + } + vv := reflect.ValueOf(v) + switch vv.Kind() { + case reflect.Ptr, reflect.Interface: + vv = vv.Elem() + } + encoder, err := ec.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, data.Elem()) + return encoder.EncodeValue(ec, vw, vv) } // JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type. -func (DefaultValueEncoders) JavaScriptEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJavaScript { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } @@ -612,7 +666,10 @@ func (DefaultValueEncoders) JavaScriptEncodeValue(ectx EncodeContext, vw bsonrw. } // SymbolEncodeValue is the ValueEncoderFunc for the primitive.Symbol type. -func (DefaultValueEncoders) SymbolEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} } @@ -621,7 +678,10 @@ func (DefaultValueEncoders) SymbolEncodeValue(ectx EncodeContext, vw bsonrw.Valu } // BinaryEncodeValue is the ValueEncoderFunc for Binary. -func (DefaultValueEncoders) BinaryEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} } @@ -631,7 +691,10 @@ func (DefaultValueEncoders) BinaryEncodeValue(ec EncodeContext, vw bsonrw.ValueW } // UndefinedEncodeValue is the ValueEncoderFunc for Undefined. -func (DefaultValueEncoders) UndefinedEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} } @@ -640,7 +703,10 @@ func (DefaultValueEncoders) UndefinedEncodeValue(ec EncodeContext, vw bsonrw.Val } // DateTimeEncodeValue is the ValueEncoderFunc for DateTime. -func (DefaultValueEncoders) DateTimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} } @@ -649,7 +715,10 @@ func (DefaultValueEncoders) DateTimeEncodeValue(ec EncodeContext, vw bsonrw.Valu } // NullEncodeValue is the ValueEncoderFunc for Null. -func (DefaultValueEncoders) NullEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} } @@ -658,7 +727,10 @@ func (DefaultValueEncoders) NullEncodeValue(ec EncodeContext, vw bsonrw.ValueWri } // RegexEncodeValue is the ValueEncoderFunc for Regex. -func (DefaultValueEncoders) RegexEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} } @@ -669,7 +741,10 @@ func (DefaultValueEncoders) RegexEncodeValue(ec EncodeContext, vw bsonrw.ValueWr } // DBPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func (DefaultValueEncoders) DBPointerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } @@ -680,7 +755,10 @@ func (DefaultValueEncoders) DBPointerEncodeValue(ec EncodeContext, vw bsonrw.Val } // TimestampEncodeValue is the ValueEncoderFunc for Timestamp. -func (DefaultValueEncoders) TimestampEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } @@ -691,7 +769,10 @@ func (DefaultValueEncoders) TimestampEncodeValue(ec EncodeContext, vw bsonrw.Val } // MinKeyEncodeValue is the ValueEncoderFunc for MinKey. -func (DefaultValueEncoders) MinKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} } @@ -700,7 +781,10 @@ func (DefaultValueEncoders) MinKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueW } // MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func (DefaultValueEncoders) MaxKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } @@ -709,7 +793,10 @@ func (DefaultValueEncoders) MaxKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueW } // CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func (DefaultValueEncoders) CoreDocumentEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -720,6 +807,9 @@ func (DefaultValueEncoders) CoreDocumentEncodeValue(ec EncodeContext, vw bsonrw. } // CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/doc.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/doc.go index c1e20f9..4613e5a 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/doc.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/doc.go @@ -1,3 +1,9 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + // Package bsoncodec provides a system for encoding values to BSON representations and decoding // values from BSON representations. This package considers both binary BSON and ExtendedJSON as // BSON representations. The types in this package enable a flexible system for handling this @@ -11,7 +17,7 @@ // 2) A Registry that holds these ValueEncoders and ValueDecoders and provides methods for // retrieving them. // -// ValueEncoders and ValueDecoders +// # ValueEncoders and ValueDecoders // // The ValueEncoder interface is implemented by types that can encode a provided Go type to BSON. // The value to encode is provided as a reflect.Value and a bsonrw.ValueWriter is used within the @@ -25,55 +31,60 @@ // allow the use of a function with the correct signature as a ValueDecoder. A DecodeContext // instance is provided and serves similar functionality to the EncodeContext. // -// Registry and RegistryBuilder +// # Registry // -// A Registry is an immutable store for ValueEncoders, ValueDecoders, and a type map. See the Registry type -// documentation for examples of registering various custom encoders and decoders. A Registry can be constructed using a -// RegistryBuilder, which handles three main types of codecs: +// A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type +// documentation for examples of registering various custom encoders and decoders. A Registry can +// have three main types of codecs: // -// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and RegisterTypeDecoder methods. -// The registered codec will be invoked when encoding/decoding a value whose type matches the registered type exactly. -// If the registered type is an interface, the codec will be invoked when encoding or decoding values whose type is the -// interface, but not for values with concrete types that implement the interface. +// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and +// RegisterTypeDecoder methods. The registered codec will be invoked when encoding/decoding a value +// whose type matches the registered type exactly. +// If the registered type is an interface, the codec will be invoked when encoding or decoding +// values whose type is the interface, but not for values with concrete types that implement the +// interface. // -// 2. Hook encoders/decoders - These can be registered using the RegisterHookEncoder and RegisterHookDecoder methods. -// These methods only accept interface types and the registered codecs will be invoked when encoding or decoding values -// whose types implement the interface. An example of a hook defined by the driver is bson.Marshaler. The driver will -// call the MarshalBSON method for any value whose type implements bson.Marshaler, regardless of the value's concrete -// type. +// 2. Hook encoders/decoders - These can be registered using the RegisterHookEncoder and +// RegisterHookDecoder methods. These methods only accept interface types and the registered codecs +// will be invoked when encoding or decoding values whose types implement the interface. An example +// of a hook defined by the driver is bson.Marshaler. The driver will call the MarshalBSON method +// for any value whose type implements bson.Marshaler, regardless of the value's concrete type. // -// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type associations are used when -// decoding into a bson.D/bson.M or a struct field of type interface{}. For example, by default, BSON int32 and int64 -// values decode as Go int32 and int64 instances, respectively, when decoding into a bson.D. The following code would -// change the behavior so these values decode as Go int instances instead: +// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type +// associations are used when decoding into a bson.D/bson.M or a struct field of type interface{}. +// For example, by default, BSON int32 and int64 values decode as Go int32 and int64 instances, +// respectively, when decoding into a bson.D. The following code would change the behavior so these +// values decode as Go int instances instead: // -// intType := reflect.TypeOf(int(0)) -// registryBuilder.RegisterTypeMapEntry(bsontype.Int32, intType).RegisterTypeMapEntry(bsontype.Int64, intType) +// intType := reflect.TypeOf(int(0)) +// registry.RegisterTypeMapEntry(bsontype.Int32, intType).RegisterTypeMapEntry(bsontype.Int64, intType) // -// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and RegisterDefaultDecoder -// methods. The registered codec will be invoked when encoding or decoding values whose reflect.Kind matches the -// registered reflect.Kind as long as the value's type doesn't match a registered type or hook encoder/decoder first. -// These methods should be used to change the behavior for all values for a specific kind. +// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and +// RegisterDefaultDecoder methods. The registered codec will be invoked when encoding or decoding +// values whose reflect.Kind matches the registered reflect.Kind as long as the value's type doesn't +// match a registered type or hook encoder/decoder first. These methods should be used to change the +// behavior for all values for a specific kind. // -// Registry Lookup Procedure +// # Registry Lookup Procedure // // When looking up an encoder in a Registry, the precedence rules are as follows: // // 1. A type encoder registered for the exact type of the value. // -// 2. A hook encoder registered for an interface that is implemented by the value or by a pointer to the value. If the -// value matches multiple hooks (e.g. the type implements bsoncodec.Marshaler and bsoncodec.ValueMarshaler), the first -// one registered will be selected. Note that registries constructed using bson.NewRegistryBuilder have driver-defined -// hooks registered for the bsoncodec.Marshaler, bsoncodec.ValueMarshaler, and bsoncodec.Proxy interfaces, so those -// will take precedence over any new hooks. +// 2. A hook encoder registered for an interface that is implemented by the value or by a pointer to +// the value. If the value matches multiple hooks (e.g. the type implements bsoncodec.Marshaler and +// bsoncodec.ValueMarshaler), the first one registered will be selected. Note that registries +// constructed using bson.NewRegistry have driver-defined hooks registered for the +// bsoncodec.Marshaler, bsoncodec.ValueMarshaler, and bsoncodec.Proxy interfaces, so those will take +// precedence over any new hooks. // // 3. A kind encoder registered for the value's kind. // -// If all of these lookups fail to find an encoder, an error of type ErrNoEncoder is returned. The same precedence -// rules apply for decoders, with the exception that an error of type ErrNoDecoder will be returned if no decoder is -// found. +// If all of these lookups fail to find an encoder, an error of type ErrNoEncoder is returned. The +// same precedence rules apply for decoders, with the exception that an error of type ErrNoDecoder +// will be returned if no decoder is found. // -// DefaultValueEncoders and DefaultValueDecoders +// # DefaultValueEncoders and DefaultValueDecoders // // The DefaultValueEncoders and DefaultValueDecoders types provide a full set of ValueEncoders and // ValueDecoders for handling a wide range of Go types, including all of the types within the diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go index a15636d..098368f 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go @@ -16,18 +16,44 @@ import ( ) // EmptyInterfaceCodec is the Codec used for interface{} values. +// +// Deprecated: EmptyInterfaceCodec will not be directly configurable in Go +// Driver 2.0. To configure the empty interface encode and decode behavior, use +// the configuration methods on a [go.mongodb.org/mongo-driver/bson.Encoder] or +// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the empty interface +// encode and decode behavior for a mongo.Client, use +// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions]. +// +// For example, to configure a mongo.Client to unmarshal BSON binary field +// values as a Go byte slice, use: +// +// opt := options.Client().SetBSONOptions(&options.BSONOptions{ +// BinaryAsSlice: true, +// }) +// +// See the deprecation notice for each field in EmptyInterfaceCodec for the +// corresponding settings. type EmptyInterfaceCodec struct { + // DecodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the + // "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. + // + // Deprecated: Use bson.Decoder.BinaryAsSlice or options.BSONOptions.BinaryAsSlice instead. DecodeBinaryAsSlice bool } var ( defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec() - _ ValueCodec = defaultEmptyInterfaceCodec + // Assert that defaultEmptyInterfaceCodec satisfies the typeDecoder interface, which allows it + // to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a + // collection. _ typeDecoder = defaultEmptyInterfaceCodec ) // NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts. +// +// Deprecated: NewEmptyInterfaceCodec will not be available in Go Driver 2.0. See +// [EmptyInterfaceCodec] for more details. func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec { interfaceOpt := bsonoptions.MergeEmptyInterfaceCodecOptions(opts...) @@ -57,11 +83,18 @@ func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWrit func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType bsontype.Type) (reflect.Type, error) { isDocument := valueType == bsontype.Type(0) || valueType == bsontype.EmbeddedDocument - if isDocument && dc.Ancestor != nil { - // Using ancestor information rather than looking up the type map entry forces consistent decoding. - // If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry - // has been registered. - return dc.Ancestor, nil + if isDocument { + if dc.defaultDocumentType != nil { + // If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return + // that type. + return dc.defaultDocumentType, nil + } + if dc.Ancestor != nil { + // Using ancestor information rather than looking up the type map entry forces consistent decoding. + // If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry + // has been registered. + return dc.Ancestor, nil + } } rtype, err := dc.LookupTypeMapEntry(valueType) @@ -114,7 +147,7 @@ func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReade return emptyValue, err } - if eic.DecodeBinaryAsSlice && rtype == tBinary { + if (eic.DecodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary { binElem := elem.Interface().(primitive.Binary) if binElem.Subtype == bsontype.BinaryGeneric || binElem.Subtype == bsontype.BinaryBinaryOld { elem = reflect.ValueOf(binElem.Data) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go index 1f7acbc..d7e00ff 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go @@ -7,6 +7,8 @@ package bsoncodec import ( + "encoding" + "errors" "fmt" "reflect" "strconv" @@ -19,14 +21,44 @@ import ( var defaultMapCodec = NewMapCodec() // MapCodec is the Codec used for map values. +// +// Deprecated: MapCodec will not be directly configurable in Go Driver 2.0. To +// configure the map encode and decode behavior, use the configuration methods +// on a [go.mongodb.org/mongo-driver/bson.Encoder] or +// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the map encode and +// decode behavior for a mongo.Client, use +// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions]. +// +// For example, to configure a mongo.Client to marshal nil Go maps as empty BSON +// documents, use: +// +// opt := options.Client().SetBSONOptions(&options.BSONOptions{ +// NilMapAsEmpty: true, +// }) +// +// See the deprecation notice for each field in MapCodec for the corresponding +// settings. type MapCodec struct { - DecodeZerosMap bool - EncodeNilAsEmpty bool + // DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination + // value passed to Decode before unmarshaling BSON documents into them. + // + // Deprecated: Use bson.Decoder.ZeroMaps or options.BSONOptions.ZeroMaps instead. + DecodeZerosMap bool + + // EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of + // BSON null. + // + // Deprecated: Use bson.Encoder.NilMapAsEmpty or options.BSONOptions.NilMapAsEmpty instead. + EncodeNilAsEmpty bool + + // EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name + // strings using fmt.Sprintf() instead of the default string conversion logic. + // + // Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt or + // options.BSONOptions.StringifyMapKeysWithFmt instead. EncodeKeysWithStringer bool } -var _ ValueCodec = &MapCodec{} - // KeyMarshaler is the interface implemented by an object that can marshal itself into a string key. // This applies to types used as map keys and is similar to encoding.TextMarshaler. type KeyMarshaler interface { @@ -44,6 +76,9 @@ type KeyUnmarshaler interface { } // NewMapCodec returns a MapCodec with options opts. +// +// Deprecated: NewMapCodec will not be available in Go Driver 2.0. See +// [MapCodec] for more details. func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec { mapOpt := bsonoptions.MergeMapCodecOptions(opts...) @@ -66,7 +101,7 @@ func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val ref return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } - if val.IsNil() && !mc.EncodeNilAsEmpty { + if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty { // If we have a nil map but we can't WriteNull, that means we're probably trying to encode // to a TopLevel document. We can't currently tell if this is what actually happened, but if // there's a deeper underlying problem, the error will also be returned from WriteDocument, @@ -99,7 +134,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v keys := val.MapKeys() for _, key := range keys { - keyStr, err := mc.encodeKey(key) + keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt) if err != nil { return err } @@ -109,7 +144,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v } currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key)) - if lookupErr != nil && lookupErr != errInvalidValue { + if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -118,7 +153,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -162,7 +197,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref val.Set(reflect.MakeMap(val.Type())) } - if val.Len() > 0 && mc.DecodeZerosMap { + if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) { clearMap(val) } @@ -181,7 +216,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref for { key, vr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } if err != nil { @@ -210,8 +245,8 @@ func clearMap(m reflect.Value) { } } -func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) { - if mc.EncodeKeysWithStringer { +func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) { + if mc.EncodeKeysWithStringer || encodeKeysWithStringer { return fmt.Sprint(val), nil } @@ -230,6 +265,19 @@ func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) { } return "", err } + // keys implement encoding.TextMarshaler are marshaled. + if km, ok := val.Interface().(encoding.TextMarshaler); ok { + if val.Kind() == reflect.Ptr && val.IsNil() { + return "", nil + } + + buf, err := km.MarshalText() + if err != nil { + return "", err + } + + return string(buf), nil + } switch val.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -241,6 +289,7 @@ func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) { } var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem() +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { keyVal := reflect.ValueOf(key) @@ -252,6 +301,12 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, v := keyVal.Interface().(KeyUnmarshaler) err = v.UnmarshalKey(key) keyVal = keyVal.Elem() + // Try to decode encoding.TextUnmarshalers. + case reflect.PtrTo(keyType).Implements(textUnmarshalerType): + keyVal = reflect.New(keyType) + v := keyVal.Interface().(encoding.TextUnmarshaler) + err = v.UnmarshalText([]byte(key)) + keyVal = keyVal.Elem() // Otherwise, go to type specific behavior default: switch keyType.Kind() { @@ -274,7 +329,7 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, if mc.EncodeKeysWithStringer { parsed, err := strconv.ParseFloat(key, 64) if err != nil { - return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err) + return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err) } keyVal = reflect.ValueOf(parsed) break diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/pointer_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/pointer_codec.go index 616a3e7..ddfa4a3 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/pointer_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/pointer_codec.go @@ -8,7 +8,6 @@ package bsoncodec import ( "reflect" - "sync" "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" @@ -18,18 +17,28 @@ var _ ValueEncoder = &PointerCodec{} var _ ValueDecoder = &PointerCodec{} // PointerCodec is the Codec used for pointers. +// +// Deprecated: PointerCodec will not be directly accessible in Go Driver 2.0. To +// override the default pointer encode and decode behavior, create a new registry +// with [go.mongodb.org/mongo-driver/bson.NewRegistry] and register a new +// encoder and decoder for pointers. +// +// For example, +// +// reg := bson.NewRegistry() +// reg.RegisterKindEncoder(reflect.Ptr, myPointerEncoder) +// reg.RegisterKindDecoder(reflect.Ptr, myPointerDecoder) type PointerCodec struct { - ecache map[reflect.Type]ValueEncoder - dcache map[reflect.Type]ValueDecoder - l sync.RWMutex + ecache typeEncoderCache + dcache typeDecoderCache } // NewPointerCodec returns a PointerCodec that has been initialized. +// +// Deprecated: NewPointerCodec will not be available in Go Driver 2.0. See +// [PointerCodec] for more details. func NewPointerCodec() *PointerCodec { - return &PointerCodec{ - ecache: make(map[reflect.Type]ValueEncoder), - dcache: make(map[reflect.Type]ValueDecoder), - } + return &PointerCodec{} } // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil @@ -46,24 +55,19 @@ func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val return vw.WriteNull() } - pc.l.RLock() - enc, ok := pc.ecache[val.Type()] - pc.l.RUnlock() - if ok { - if enc == nil { - return ErrNoEncoder{Type: val.Type()} + typ := val.Type() + if v, ok := pc.ecache.Load(typ); ok { + if v == nil { + return ErrNoEncoder{Type: typ} } - return enc.EncodeValue(ec, vw, val.Elem()) + return v.EncodeValue(ec, vw, val.Elem()) } - - enc, err := ec.LookupEncoder(val.Type().Elem()) - pc.l.Lock() - pc.ecache[val.Type()] = enc - pc.l.Unlock() + // TODO(charlie): handle concurrent requests for the same type + enc, err := ec.LookupEncoder(typ.Elem()) + enc = pc.ecache.LoadOrStore(typ, enc) if err != nil { return err } - return enc.EncodeValue(ec, vw, val.Elem()) } @@ -74,36 +78,31 @@ func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } + typ := val.Type() if vr.Type() == bsontype.Null { - val.Set(reflect.Zero(val.Type())) + val.Set(reflect.Zero(typ)) return vr.ReadNull() } if vr.Type() == bsontype.Undefined { - val.Set(reflect.Zero(val.Type())) + val.Set(reflect.Zero(typ)) return vr.ReadUndefined() } if val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) + val.Set(reflect.New(typ.Elem())) } - pc.l.RLock() - dec, ok := pc.dcache[val.Type()] - pc.l.RUnlock() - if ok { - if dec == nil { - return ErrNoDecoder{Type: val.Type()} + if v, ok := pc.dcache.Load(typ); ok { + if v == nil { + return ErrNoDecoder{Type: typ} } - return dec.DecodeValue(dc, vr, val.Elem()) + return v.DecodeValue(dc, vr, val.Elem()) } - - dec, err := dc.LookupDecoder(val.Type().Elem()) - pc.l.Lock() - pc.dcache[val.Type()] = dec - pc.l.Unlock() + // TODO(charlie): handle concurrent requests for the same type + dec, err := dc.LookupDecoder(typ.Elem()) + dec = pc.dcache.LoadOrStore(typ, dec) if err != nil { return err } - return dec.DecodeValue(dc, vr, val.Elem()) } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/registry.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/registry.go index f6f3800..196c491 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/registry.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/registry.go @@ -16,12 +16,18 @@ import ( ) // ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder. +// +// Deprecated: ErrNilType will not be supported in Go Driver 2.0. var ErrNilType = errors.New("cannot perform a decoder lookup on ") // ErrNotPointer is returned when a non-pointer type is provided to LookupDecoder. +// +// Deprecated: ErrNotPointer will not be supported in Go Driver 2.0. var ErrNotPointer = errors.New("non-pointer provided to LookupDecoder") // ErrNoEncoder is returned when there wasn't an encoder available for a type. +// +// Deprecated: ErrNoEncoder will not be supported in Go Driver 2.0. type ErrNoEncoder struct { Type reflect.Type } @@ -34,6 +40,8 @@ func (ene ErrNoEncoder) Error() string { } // ErrNoDecoder is returned when there wasn't a decoder available for a type. +// +// Deprecated: ErrNoDecoder will not be supported in Go Driver 2.0. type ErrNoDecoder struct { Type reflect.Type } @@ -43,6 +51,8 @@ func (end ErrNoDecoder) Error() string { } // ErrNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type. +// +// Deprecated: ErrNoTypeMapEntry will not be supported in Go Driver 2.0. type ErrNoTypeMapEntry struct { Type bsontype.Type } @@ -52,63 +62,30 @@ func (entme ErrNoTypeMapEntry) Error() string { } // ErrNotInterface is returned when the provided type is not an interface. +// +// Deprecated: ErrNotInterface will not be supported in Go Driver 2.0. var ErrNotInterface = errors.New("The provided type is not an interface") // A RegistryBuilder is used to build a Registry. This type is not goroutine // safe. +// +// Deprecated: Use Registry instead. type RegistryBuilder struct { - typeEncoders map[reflect.Type]ValueEncoder - interfaceEncoders []interfaceValueEncoder - kindEncoders map[reflect.Kind]ValueEncoder - - typeDecoders map[reflect.Type]ValueDecoder - interfaceDecoders []interfaceValueDecoder - kindDecoders map[reflect.Kind]ValueDecoder - - typeMap map[bsontype.Type]reflect.Type -} - -// A Registry is used to store and retrieve codecs for types and interfaces. This type is the main -// typed passed around and Encoders and Decoders are constructed from it. -type Registry struct { - typeEncoders map[reflect.Type]ValueEncoder - typeDecoders map[reflect.Type]ValueDecoder - - interfaceEncoders []interfaceValueEncoder - interfaceDecoders []interfaceValueDecoder - - kindEncoders map[reflect.Kind]ValueEncoder - kindDecoders map[reflect.Kind]ValueDecoder - - typeMap map[bsontype.Type]reflect.Type - - mu sync.RWMutex + registry *Registry } // NewRegistryBuilder creates a new empty RegistryBuilder. +// +// Deprecated: Use NewRegistry instead. func NewRegistryBuilder() *RegistryBuilder { return &RegistryBuilder{ - typeEncoders: make(map[reflect.Type]ValueEncoder), - typeDecoders: make(map[reflect.Type]ValueDecoder), - - interfaceEncoders: make([]interfaceValueEncoder, 0), - interfaceDecoders: make([]interfaceValueDecoder, 0), - - kindEncoders: make(map[reflect.Kind]ValueEncoder), - kindDecoders: make(map[reflect.Kind]ValueDecoder), - - typeMap: make(map[bsontype.Type]reflect.Type), + registry: NewRegistry(), } } -func buildDefaultRegistry() *Registry { - rb := NewRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(rb) - defaultValueDecoders.RegisterDefaultDecoders(rb) - return rb.Build() -} - // RegisterCodec will register the provided ValueCodec for the provided type. +// +// Deprecated: Use Registry.RegisterTypeEncoder and Registry.RegisterTypeDecoder instead. func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder { rb.RegisterTypeEncoder(t, codec) rb.RegisterTypeDecoder(t, codec) @@ -120,31 +97,22 @@ func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *Regi // The type will be used directly, so an encoder can be registered for a type and a different encoder can be registered // for a pointer to that type. // -// If the given type is an interface, the encoder will be called when marshalling a type that is that interface. It -// will not be called when marshalling a non-interface type that implements the interface. +// If the given type is an interface, the encoder will be called when marshaling a type that is that interface. It +// will not be called when marshaling a non-interface type that implements the interface. +// +// Deprecated: Use Registry.RegisterTypeEncoder instead. func (rb *RegistryBuilder) RegisterTypeEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - rb.typeEncoders[t] = enc + rb.registry.RegisterTypeEncoder(t, enc) return rb } // RegisterHookEncoder will register an encoder for the provided interface type t. This encoder will be called when -// marshalling a type if the type implements t or a pointer to the type implements t. If the provided type is not +// marshaling a type if the type implements t or a pointer to the type implements t. If the provided type is not // an interface (i.e. t.Kind() != reflect.Interface), this method will panic. +// +// Deprecated: Use Registry.RegisterInterfaceEncoder instead. func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - if t.Kind() != reflect.Interface { - panicStr := fmt.Sprintf("RegisterHookEncoder expects a type with kind reflect.Interface, "+ - "got type %s with kind %s", t, t.Kind()) - panic(panicStr) - } - - for idx, encoder := range rb.interfaceEncoders { - if encoder.i == t { - rb.interfaceEncoders[idx].ve = enc - return rb - } - } - - rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc}) + rb.registry.RegisterInterfaceEncoder(t, enc) return rb } @@ -153,97 +121,78 @@ func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) // The type will be used directly, so a decoder can be registered for a type and a different decoder can be registered // for a pointer to that type. // -// If the given type is an interface, the decoder will be called when unmarshalling into a type that is that interface. -// It will not be called when unmarshalling into a non-interface type that implements the interface. +// If the given type is an interface, the decoder will be called when unmarshaling into a type that is that interface. +// It will not be called when unmarshaling into a non-interface type that implements the interface. +// +// Deprecated: Use Registry.RegisterTypeDecoder instead. func (rb *RegistryBuilder) RegisterTypeDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.typeDecoders[t] = dec + rb.registry.RegisterTypeDecoder(t, dec) return rb } // RegisterHookDecoder will register an decoder for the provided interface type t. This decoder will be called when -// unmarshalling into a type if the type implements t or a pointer to the type implements t. If the provided type is not +// unmarshaling into a type if the type implements t or a pointer to the type implements t. If the provided type is not // an interface (i.e. t.Kind() != reflect.Interface), this method will panic. +// +// Deprecated: Use Registry.RegisterInterfaceDecoder instead. func (rb *RegistryBuilder) RegisterHookDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - if t.Kind() != reflect.Interface { - panicStr := fmt.Sprintf("RegisterHookDecoder expects a type with kind reflect.Interface, "+ - "got type %s with kind %s", t, t.Kind()) - panic(panicStr) - } - - for idx, decoder := range rb.interfaceDecoders { - if decoder.i == t { - rb.interfaceDecoders[idx].vd = dec - return rb - } - } - - rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec}) + rb.registry.RegisterInterfaceDecoder(t, dec) return rb } // RegisterEncoder registers the provided type and encoder pair. // -// Deprecated: Use RegisterTypeEncoder or RegisterHookEncoder instead. +// Deprecated: Use Registry.RegisterTypeEncoder or Registry.RegisterInterfaceEncoder instead. func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { if t == tEmpty { - rb.typeEncoders[t] = enc + rb.registry.RegisterTypeEncoder(t, enc) return rb } switch t.Kind() { case reflect.Interface: - for idx, ir := range rb.interfaceEncoders { - if ir.i == t { - rb.interfaceEncoders[idx].ve = enc - return rb - } - } - - rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc}) + rb.registry.RegisterInterfaceEncoder(t, enc) default: - rb.typeEncoders[t] = enc + rb.registry.RegisterTypeEncoder(t, enc) } return rb } // RegisterDecoder registers the provided type and decoder pair. // -// Deprecated: Use RegisterTypeDecoder or RegisterHookDecoder instead. +// Deprecated: Use Registry.RegisterTypeDecoder or Registry.RegisterInterfaceDecoder instead. func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { if t == nil { - rb.typeDecoders[nil] = dec + rb.registry.RegisterTypeDecoder(t, dec) return rb } if t == tEmpty { - rb.typeDecoders[t] = dec + rb.registry.RegisterTypeDecoder(t, dec) return rb } switch t.Kind() { case reflect.Interface: - for idx, ir := range rb.interfaceDecoders { - if ir.i == t { - rb.interfaceDecoders[idx].vd = dec - return rb - } - } - - rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec}) + rb.registry.RegisterInterfaceDecoder(t, dec) default: - rb.typeDecoders[t] = dec + rb.registry.RegisterTypeDecoder(t, dec) } return rb } -// RegisterDefaultEncoder will registr the provided ValueEncoder to the provided +// RegisterDefaultEncoder will register the provided ValueEncoder to the provided // kind. +// +// Deprecated: Use Registry.RegisterKindEncoder instead. func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder { - rb.kindEncoders[kind] = enc + rb.registry.RegisterKindEncoder(kind, enc) return rb } // RegisterDefaultDecoder will register the provided ValueDecoder to the // provided kind. +// +// Deprecated: Use Registry.RegisterKindDecoder instead. func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { - rb.kindDecoders[kind] = dec + rb.registry.RegisterKindDecoder(kind, dec) return rb } @@ -254,121 +203,235 @@ func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDe // By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON // documents, a type map entry for bsontype.EmbeddedDocument should be registered. For example, to force BSON documents // to decode to bson.Raw, use the following code: +// // rb.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.Raw{})) +// +// Deprecated: Use Registry.RegisterTypeMapEntry instead. func (rb *RegistryBuilder) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) *RegistryBuilder { - rb.typeMap[bt] = rt + rb.registry.RegisterTypeMapEntry(bt, rt) return rb } // Build creates a Registry from the current state of this RegistryBuilder. +// +// Deprecated: Use NewRegistry instead. func (rb *RegistryBuilder) Build() *Registry { - registry := new(Registry) - - registry.typeEncoders = make(map[reflect.Type]ValueEncoder) - for t, enc := range rb.typeEncoders { - registry.typeEncoders[t] = enc + r := &Registry{ + interfaceEncoders: append([]interfaceValueEncoder(nil), rb.registry.interfaceEncoders...), + interfaceDecoders: append([]interfaceValueDecoder(nil), rb.registry.interfaceDecoders...), + typeEncoders: rb.registry.typeEncoders.Clone(), + typeDecoders: rb.registry.typeDecoders.Clone(), + kindEncoders: rb.registry.kindEncoders.Clone(), + kindDecoders: rb.registry.kindDecoders.Clone(), } + rb.registry.typeMap.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + r.typeMap.Store(k, v) + } + return true + }) + return r +} + +// A Registry is used to store and retrieve codecs for types and interfaces. This type is the main +// typed passed around and Encoders and Decoders are constructed from it. +type Registry struct { + interfaceEncoders []interfaceValueEncoder + interfaceDecoders []interfaceValueDecoder + typeEncoders *typeEncoderCache + typeDecoders *typeDecoderCache + kindEncoders *kindEncoderCache + kindDecoders *kindDecoderCache + typeMap sync.Map // map[bsontype.Type]reflect.Type +} - registry.typeDecoders = make(map[reflect.Type]ValueDecoder) - for t, dec := range rb.typeDecoders { - registry.typeDecoders[t] = dec +// NewRegistry creates a new empty Registry. +func NewRegistry() *Registry { + return &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), } +} + +// RegisterTypeEncoder registers the provided ValueEncoder for the provided type. +// +// The type will be used as provided, so an encoder can be registered for a type and a different +// encoder can be registered for a pointer to that type. +// +// If the given type is an interface, the encoder will be called when marshaling a type that is +// that interface. It will not be called when marshaling a non-interface type that implements the +// interface. To get the latter behavior, call RegisterHookEncoder instead. +// +// RegisterTypeEncoder should not be called concurrently with any other Registry method. +func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { + r.typeEncoders.Store(valueType, enc) +} + +// RegisterTypeDecoder registers the provided ValueDecoder for the provided type. +// +// The type will be used as provided, so a decoder can be registered for a type and a different +// decoder can be registered for a pointer to that type. +// +// If the given type is an interface, the decoder will be called when unmarshaling into a type that +// is that interface. It will not be called when unmarshaling into a non-interface type that +// implements the interface. To get the latter behavior, call RegisterHookDecoder instead. +// +// RegisterTypeDecoder should not be called concurrently with any other Registry method. +func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { + r.typeDecoders.Store(valueType, dec) +} - registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.interfaceEncoders)) - copy(registry.interfaceEncoders, rb.interfaceEncoders) +// RegisterKindEncoder registers the provided ValueEncoder for the provided kind. +// +// Use RegisterKindEncoder to register an encoder for any type with the same underlying kind. For +// example, consider the type MyInt defined as +// +// type MyInt int32 +// +// To define an encoder for MyInt and int32, use RegisterKindEncoder like +// +// reg.RegisterKindEncoder(reflect.Int32, myEncoder) +// +// RegisterKindEncoder should not be called concurrently with any other Registry method. +func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { + r.kindEncoders.Store(kind, enc) +} + +// RegisterKindDecoder registers the provided ValueDecoder for the provided kind. +// +// Use RegisterKindDecoder to register a decoder for any type with the same underlying kind. For +// example, consider the type MyInt defined as +// +// type MyInt int32 +// +// To define an decoder for MyInt and int32, use RegisterKindDecoder like +// +// reg.RegisterKindDecoder(reflect.Int32, myDecoder) +// +// RegisterKindDecoder should not be called concurrently with any other Registry method. +func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { + r.kindDecoders.Store(kind, dec) +} - registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.interfaceDecoders)) - copy(registry.interfaceDecoders, rb.interfaceDecoders) +// RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will +// be called when marshaling a type if the type implements iface or a pointer to the type +// implements iface. If the provided type is not an interface +// (i.e. iface.Kind() != reflect.Interface), this method will panic. +// +// RegisterInterfaceEncoder should not be called concurrently with any other Registry method. +func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) { + if iface.Kind() != reflect.Interface { + panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+ + "got type %s with kind %s", iface, iface.Kind()) + panic(panicStr) + } - registry.kindEncoders = make(map[reflect.Kind]ValueEncoder) - for kind, enc := range rb.kindEncoders { - registry.kindEncoders[kind] = enc + for idx, encoder := range r.interfaceEncoders { + if encoder.i == iface { + r.interfaceEncoders[idx].ve = enc + return + } } - registry.kindDecoders = make(map[reflect.Kind]ValueDecoder) - for kind, dec := range rb.kindDecoders { - registry.kindDecoders[kind] = dec + r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{i: iface, ve: enc}) +} + +// RegisterInterfaceDecoder registers an decoder for the provided interface type iface. This decoder will +// be called when unmarshaling into a type if the type implements iface or a pointer to the type +// implements iface. If the provided type is not an interface (i.e. iface.Kind() != reflect.Interface), +// this method will panic. +// +// RegisterInterfaceDecoder should not be called concurrently with any other Registry method. +func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) { + if iface.Kind() != reflect.Interface { + panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ + "got type %s with kind %s", iface, iface.Kind()) + panic(panicStr) } - registry.typeMap = make(map[bsontype.Type]reflect.Type) - for bt, rt := range rb.typeMap { - registry.typeMap[bt] = rt + for idx, decoder := range r.interfaceDecoders { + if decoder.i == iface { + r.interfaceDecoders[idx].vd = dec + return + } } - return registry + r.interfaceDecoders = append(r.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec}) +} + +// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this +// mapping is decoding situations where an empty interface is used and a default type needs to be +// created and decoded into. +// +// By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON +// documents, a type map entry for bsontype.EmbeddedDocument should be registered. For example, to force BSON documents +// to decode to bson.Raw, use the following code: +// +// reg.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.Raw{})) +func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) { + r.typeMap.Store(bt, rt) } -// LookupEncoder inspects the registry for an encoder for the given type. The lookup precedence works as follows: +// LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup +// order: // -// 1. An encoder registered for the exact type. If the given type represents an interface, an encoder registered using -// RegisterTypeEncoder for the interface will be selected. +// 1. An encoder registered for the exact type. If the given type is an interface, an encoder +// registered using RegisterTypeEncoder for that interface will be selected. // -// 2. An encoder registered using RegisterHookEncoder for an interface implemented by the type or by a pointer to the -// type. +// 2. An encoder registered using RegisterInterfaceEncoder for an interface implemented by the type +// or by a pointer to the type. // -// 3. An encoder registered for the reflect.Kind of the value. +// 3. An encoder registered using RegisterKindEncoder for the kind of value. // -// If no encoder is found, an error of type ErrNoEncoder is returned. -func (r *Registry) LookupEncoder(t reflect.Type) (ValueEncoder, error) { - encodererr := ErrNoEncoder{Type: t} - r.mu.RLock() - enc, found := r.lookupTypeEncoder(t) - r.mu.RUnlock() +// If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for +// concurrent use by multiple goroutines after all codecs and encoders are registered. +func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { + if valueType == nil { + return nil, ErrNoEncoder{Type: valueType} + } + enc, found := r.lookupTypeEncoder(valueType) if found { if enc == nil { - return nil, ErrNoEncoder{Type: t} + return nil, ErrNoEncoder{Type: valueType} } return enc, nil } - enc, found = r.lookupInterfaceEncoder(t, true) + enc, found = r.lookupInterfaceEncoder(valueType, true) if found { - r.mu.Lock() - r.typeEncoders[t] = enc - r.mu.Unlock() - return enc, nil + return r.typeEncoders.LoadOrStore(valueType, enc), nil } - if t == nil { - r.mu.Lock() - r.typeEncoders[t] = nil - r.mu.Unlock() - return nil, encodererr - } - - enc, found = r.kindEncoders[t.Kind()] - if !found { - r.mu.Lock() - r.typeEncoders[t] = nil - r.mu.Unlock() - return nil, encodererr + if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { + return r.storeTypeEncoder(valueType, v), nil } + return nil, ErrNoEncoder{Type: valueType} +} - r.mu.Lock() - r.typeEncoders[t] = enc - r.mu.Unlock() - return enc, nil +func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { + return r.typeEncoders.LoadOrStore(rt, enc) } -func (r *Registry) lookupTypeEncoder(t reflect.Type) (ValueEncoder, bool) { - enc, found := r.typeEncoders[t] - return enc, found +func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { + return r.typeEncoders.Load(rt) } -func (r *Registry) lookupInterfaceEncoder(t reflect.Type, allowAddr bool) (ValueEncoder, bool) { - if t == nil { +func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { + if valueType == nil { return nil, false } for _, ienc := range r.interfaceEncoders { - if t.Implements(ienc.i) { + if valueType.Implements(ienc.i) { return ienc.ve, true } - if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(ienc.i) { - // if *t implements an interface, this will catch if t implements an interface further ahead - // in interfaceEncoders - defaultEnc, found := r.lookupInterfaceEncoder(t, false) + if allowAddr && valueType.Kind() != reflect.Ptr && reflect.PtrTo(valueType).Implements(ienc.i) { + // if *t implements an interface, this will catch if t implements an interface further + // ahead in interfaceEncoders + defaultEnc, found := r.lookupInterfaceEncoder(valueType, false) if !found { - defaultEnc = r.kindEncoders[t.Kind()] + defaultEnc, _ = r.kindEncoders.Load(valueType.Kind()) } return newCondAddrEncoder(ienc.ve, defaultEnc), true } @@ -376,70 +439,61 @@ func (r *Registry) lookupInterfaceEncoder(t reflect.Type, allowAddr bool) (Value return nil, false } -// LookupDecoder inspects the registry for an decoder for the given type. The lookup precedence works as follows: +// LookupDecoder returns the first matching decoder in the Registry. It uses the following lookup +// order: // -// 1. A decoder registered for the exact type. If the given type represents an interface, a decoder registered using -// RegisterTypeDecoder for the interface will be selected. +// 1. A decoder registered for the exact type. If the given type is an interface, a decoder +// registered using RegisterTypeDecoder for that interface will be selected. // -// 2. A decoder registered using RegisterHookDecoder for an interface implemented by the type or by a pointer to the -// type. +// 2. A decoder registered using RegisterInterfaceDecoder for an interface implemented by the type or by +// a pointer to the type. // -// 3. A decoder registered for the reflect.Kind of the value. +// 3. A decoder registered using RegisterKindDecoder for the kind of value. // -// If no decoder is found, an error of type ErrNoDecoder is returned. -func (r *Registry) LookupDecoder(t reflect.Type) (ValueDecoder, error) { - if t == nil { +// If no decoder is found, an error of type ErrNoDecoder is returned. LookupDecoder is safe for +// concurrent use by multiple goroutines after all codecs and decoders are registered. +func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { + if valueType == nil { return nil, ErrNilType } - decodererr := ErrNoDecoder{Type: t} - r.mu.RLock() - dec, found := r.lookupTypeDecoder(t) - r.mu.RUnlock() + dec, found := r.lookupTypeDecoder(valueType) if found { if dec == nil { - return nil, ErrNoDecoder{Type: t} + return nil, ErrNoDecoder{Type: valueType} } return dec, nil } - dec, found = r.lookupInterfaceDecoder(t, true) + dec, found = r.lookupInterfaceDecoder(valueType, true) if found { - r.mu.Lock() - r.typeDecoders[t] = dec - r.mu.Unlock() - return dec, nil + return r.storeTypeDecoder(valueType, dec), nil } - dec, found = r.kindDecoders[t.Kind()] - if !found { - r.mu.Lock() - r.typeDecoders[t] = nil - r.mu.Unlock() - return nil, decodererr + if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { + return r.storeTypeDecoder(valueType, v), nil } + return nil, ErrNoDecoder{Type: valueType} +} - r.mu.Lock() - r.typeDecoders[t] = dec - r.mu.Unlock() - return dec, nil +func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { + return r.typeDecoders.Load(valueType) } -func (r *Registry) lookupTypeDecoder(t reflect.Type) (ValueDecoder, bool) { - dec, found := r.typeDecoders[t] - return dec, found +func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { + return r.typeDecoders.LoadOrStore(typ, dec) } -func (r *Registry) lookupInterfaceDecoder(t reflect.Type, allowAddr bool) (ValueDecoder, bool) { +func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { for _, idec := range r.interfaceDecoders { - if t.Implements(idec.i) { + if valueType.Implements(idec.i) { return idec.vd, true } - if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(idec.i) { - // if *t implements an interface, this will catch if t implements an interface further ahead - // in interfaceDecoders - defaultDec, found := r.lookupInterfaceDecoder(t, false) + if allowAddr && valueType.Kind() != reflect.Ptr && reflect.PtrTo(valueType).Implements(idec.i) { + // if *t implements an interface, this will catch if t implements an interface further + // ahead in interfaceDecoders + defaultDec, found := r.lookupInterfaceDecoder(valueType, false) if !found { - defaultDec = r.kindDecoders[t.Kind()] + defaultDec, _ = r.kindDecoders.Load(valueType.Kind()) } return newCondAddrDecoder(idec.vd, defaultDec), true } @@ -449,12 +503,14 @@ func (r *Registry) lookupInterfaceDecoder(t reflect.Type, allowAddr bool) (Value // LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON // type. If no type is found, ErrNoTypeMapEntry is returned. +// +// LookupTypeMapEntry should not be called concurrently with any other Registry method. func (r *Registry) LookupTypeMapEntry(bt bsontype.Type) (reflect.Type, error) { - t, ok := r.typeMap[bt] - if !ok || t == nil { + v, ok := r.typeMap.Load(bt) + if v == nil || !ok { return nil, ErrNoTypeMapEntry{Type: bt} } - return t, nil + return v.(reflect.Type), nil } type interfaceValueEncoder struct { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/slice_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/slice_codec.go index 3c1b6b8..14c9fd2 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/slice_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/slice_codec.go @@ -7,6 +7,7 @@ package bsoncodec import ( + "errors" "fmt" "reflect" @@ -19,13 +20,35 @@ import ( var defaultSliceCodec = NewSliceCodec() // SliceCodec is the Codec used for slice values. +// +// Deprecated: SliceCodec will not be directly configurable in Go Driver 2.0. To +// configure the slice encode and decode behavior, use the configuration methods +// on a [go.mongodb.org/mongo-driver/bson.Encoder] or +// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the slice encode and +// decode behavior for a mongo.Client, use +// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions]. +// +// For example, to configure a mongo.Client to marshal nil Go slices as empty +// BSON arrays, use: +// +// opt := options.Client().SetBSONOptions(&options.BSONOptions{ +// NilSliceAsEmpty: true, +// }) +// +// See the deprecation notice for each field in SliceCodec for the corresponding +// settings. type SliceCodec struct { + // EncodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of + // BSON null. + // + // Deprecated: Use bson.Encoder.NilSliceAsEmpty instead. EncodeNilAsEmpty bool } -var _ ValueCodec = &MapCodec{} - // NewSliceCodec returns a MapCodec with options opts. +// +// Deprecated: NewSliceCodec will not be available in Go Driver 2.0. See +// [SliceCodec] for more details. func NewSliceCodec(opts ...*bsonoptions.SliceCodecOptions) *SliceCodec { sliceOpt := bsonoptions.MergeSliceCodecOptions(opts...) @@ -42,21 +65,19 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } - if val.IsNil() && !sc.EncodeNilAsEmpty { + if val.IsNil() && !sc.EncodeNilAsEmpty && !ec.nilSliceAsEmpty { return vw.WriteNull() } // If we have a []byte we want to treat it as a binary instead of as an array. if val.Type().Elem() == tByte { - var byteSlice []byte - for idx := 0; idx < val.Len(); idx++ { - byteSlice = append(byteSlice, val.Index(idx).Interface().(byte)) - } + byteSlice := make([]byte, val.Len()) + reflect.Copy(reflect.ValueOf(byteSlice), val) return vw.WriteBinary(byteSlice) } // If we have a []primitive.E we want to treat it as a document instead of as an array. - if val.Type().ConvertibleTo(tD) { + if val.Type() == tD || val.Type().ConvertibleTo(tD) { d := val.Convert(tD).Interface().(primitive.D) dw, err := vw.WriteDocument() @@ -87,7 +108,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re for idx := 0; idx < val.Len(); idx++ { currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx)) - if lookupErr != nil && lookupErr != errInvalidValue { + if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -96,7 +117,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -145,11 +166,8 @@ func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val r if val.IsNil() { val.Set(reflect.MakeSlice(val.Type(), 0, len(data))) } - val.SetLen(0) - for _, elem := range data { - val.Set(reflect.Append(val, reflect.ValueOf(elem))) - } + val.Set(reflect.AppendSlice(val, reflect.ValueOf(data))) return nil case bsontype.String: if sliceType := val.Type().Elem(); sliceType != tByte { @@ -164,11 +182,8 @@ func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val r if val.IsNil() { val.Set(reflect.MakeSlice(val.Type(), 0, len(byteStr))) } - val.SetLen(0) - for _, elem := range byteStr { - val.Set(reflect.Append(val, reflect.ValueOf(elem))) - } + val.Set(reflect.AppendSlice(val, reflect.ValueOf(byteStr))) return nil default: return fmt.Errorf("cannot decode %v into a slice", vrType) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go index 5332b7c..a8f885a 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go @@ -15,26 +15,46 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -// StringCodec is the Codec used for struct values. +// StringCodec is the Codec used for string values. +// +// Deprecated: StringCodec will not be directly accessible in Go Driver 2.0. To +// override the default string encode and decode behavior, create a new registry +// with [go.mongodb.org/mongo-driver/bson.NewRegistry] and register a new +// encoder and decoder for strings. +// +// For example, +// +// reg := bson.NewRegistry() +// reg.RegisterKindEncoder(reflect.String, myStringEncoder) +// reg.RegisterKindDecoder(reflect.String, myStringDecoder) type StringCodec struct { + // DecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. + // If false, a string made from the raw object ID bytes will be used. Defaults to true. + // + // Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. DecodeObjectIDAsHex bool } var ( defaultStringCodec = NewStringCodec() - _ ValueCodec = defaultStringCodec + // Assert that defaultStringCodec satisfies the typeDecoder interface, which allows it to be + // used by collection type decoders (e.g. map, slice, etc) to set individual values in a + // collection. _ typeDecoder = defaultStringCodec ) // NewStringCodec returns a StringCodec with options opts. +// +// Deprecated: NewStringCodec will not be available in Go Driver 2.0. See +// [StringCodec] for more details. func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec { stringOpt := bsonoptions.MergeStringCodecOptions(opts...) return &StringCodec{*stringOpt.DecodeObjectIDAsHex} } // EncodeValue is the ValueEncoder for string types. -func (sc *StringCodec) EncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +func (sc *StringCodec) EncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", @@ -46,7 +66,7 @@ func (sc *StringCodec) EncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, va return vw.WriteString(val.String()) } -func (sc *StringCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { +func (sc *StringCodec) decodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.String { return emptyValue, ValueDecoderError{ Name: "StringDecodeValue", @@ -71,6 +91,7 @@ func (sc *StringCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t ref if sc.DecodeObjectIDAsHex { str = oid.Hex() } else { + // TODO(GODRIVER-2796): Return an error here instead of decoding to a garbled string. byteArray := [12]byte(oid) str = string(byteArray[:]) } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_codec.go index be3f208..f8d9690 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_codec.go @@ -59,14 +59,58 @@ type Zeroer interface { } // StructCodec is the Codec used for struct values. +// +// Deprecated: StructCodec will not be directly configurable in Go Driver 2.0. +// To configure the struct encode and decode behavior, use the configuration +// methods on a [go.mongodb.org/mongo-driver/bson.Encoder] or +// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the struct encode +// and decode behavior for a mongo.Client, use +// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions]. +// +// For example, to configure a mongo.Client to omit zero-value structs when +// using the "omitempty" struct tag, use: +// +// opt := options.Client().SetBSONOptions(&options.BSONOptions{ +// OmitZeroStruct: true, +// }) +// +// See the deprecation notice for each field in StructCodec for the corresponding +// settings. type StructCodec struct { - cache map[reflect.Type]*structDescription - l sync.RWMutex - parser StructTagParser - DecodeZeroStruct bool - DecodeDeepZeroInline bool - EncodeOmitDefaultStruct bool - AllowUnexportedFields bool + cache sync.Map // map[reflect.Type]*structDescription + parser StructTagParser + + // DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the + // destination value passed to Decode before unmarshaling BSON documents into them. + // + // Deprecated: Use bson.Decoder.ZeroStructs or options.BSONOptions.ZeroStructs instead. + DecodeZeroStruct bool + + // DecodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the + // destination value passed to Decode before unmarshaling BSON documents into them. + // + // Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. + DecodeDeepZeroInline bool + + // EncodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g. + // MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag + // option is set. + // + // Deprecated: Use bson.Encoder.OmitZeroStruct or options.BSONOptions.OmitZeroStruct instead. + EncodeOmitDefaultStruct bool + + // AllowUnexportedFields allows encoding and decoding values from un-exported struct fields. + // + // Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be + // supported in Go Driver 2.0. + AllowUnexportedFields bool + + // OverwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is + // a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The + // default value is true. + // + // Deprecated: Use bson.Encoder.ErrorOnInlineDuplicates or + // options.BSONOptions.ErrorOnInlineDuplicates instead. OverwriteDuplicatedInlinedFields bool } @@ -74,6 +118,9 @@ var _ ValueEncoder = &StructCodec{} var _ ValueDecoder = &StructCodec{} // NewStructCodec returns a StructCodec that uses p for struct tag parsing. +// +// Deprecated: NewStructCodec will not be available in Go Driver 2.0. See +// [StructCodec] for more details. func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { if p == nil { return nil, errors.New("a StructTagParser must be provided to NewStructCodec") @@ -82,7 +129,6 @@ func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) structOpt := bsonoptions.MergeStructCodecOptions(opts...) codec := &StructCodec{ - cache: make(map[reflect.Type]*structDescription), parser: p, } @@ -106,12 +152,12 @@ func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) } // EncodeValue handles encoding generic struct types. -func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } - sd, err := sc.describeStruct(r.Registry, val.Type()) + sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates) if err != nil { return err } @@ -131,13 +177,13 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r } } - desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv) + desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv) - if err != nil && err != errInvalidValue { + if err != nil && !errors.Is(err, errInvalidValue) { return err } - if err == errInvalidValue { + if errors.Is(err, errInvalidValue) { if desc.omitEmpty { continue } @@ -158,17 +204,17 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r encoder := desc.encoder - var isZero bool - rvInterface := rv.Interface() + var empty bool if cz, ok := encoder.(CodecZeroer); ok { - isZero = cz.IsTypeZero(rvInterface) + empty = cz.IsTypeZero(rv.Interface()) } else if rv.Kind() == reflect.Interface { - // sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately. - isZero = rv.IsNil() + // isEmpty will not treat an interface rv as an interface, so we need to check for the + // nil interface separately. + empty = rv.IsNil() } else { - isZero = sc.isZero(rvInterface) + empty = isEmpty(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) } - if desc.omitEmpty && isZero { + if desc.omitEmpty && empty { continue } @@ -177,7 +223,17 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r return err } - ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize} + ectx := EncodeContext{ + Registry: ec.Registry, + MinSize: desc.minSize || ec.MinSize, + errorOnInlineDuplicates: ec.errorOnInlineDuplicates, + stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt, + nilMapAsEmpty: ec.nilMapAsEmpty, + nilSliceAsEmpty: ec.nilSliceAsEmpty, + nilByteSliceAsEmpty: ec.nilByteSliceAsEmpty, + omitZeroStruct: ec.omitZeroStruct, + useJSONStructTags: ec.useJSONStructTags, + } err = encoder.EncodeValue(ectx, vw2, rv) if err != nil { return err @@ -191,15 +247,15 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r return exists } - return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn) + return defaultMapCodec.mapEncodeValue(ec, dw, rv, collisionFn) } return dw.WriteDocumentEnd() } func newDecodeError(key string, original error) error { - de, ok := original.(*DecodeError) - if !ok { + var de *DecodeError + if !errors.As(original, &de) { return &DecodeError{ keys: []string{key}, wrapped: original, @@ -213,7 +269,7 @@ func newDecodeError(key string, original error) error { // DecodeValue implements the Codec interface. // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. -func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Struct { return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } @@ -238,12 +294,12 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) } - sd, err := sc.describeStruct(r.Registry, val.Type()) + sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false) if err != nil { return err } - if sc.DecodeZeroStruct { + if sc.DecodeZeroStruct || dc.zeroStructs { val.Set(reflect.Zero(val.Type())) } if sc.DecodeDeepZeroInline && sd.inline { @@ -254,7 +310,7 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r var inlineMap reflect.Value if sd.inlineMap >= 0 { inlineMap = val.Field(sd.inlineMap) - decoder, err = r.LookupDecoder(inlineMap.Type().Elem()) + decoder, err = dc.LookupDecoder(inlineMap.Type().Elem()) if err != nil { return err } @@ -267,7 +323,7 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r for { name, vr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } if err != nil { @@ -298,8 +354,8 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r } elem := reflect.New(inlineMap.Type().Elem()).Elem() - r.Ancestor = inlineMap.Type() - err = decoder.DecodeValue(r, vr, elem) + dc.Ancestor = inlineMap.Type() + err = decoder.DecodeValue(dc, vr, elem) if err != nil { return err } @@ -326,7 +382,17 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r } field = field.Addr() - dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate} + dctx := DecodeContext{ + Registry: dc.Registry, + Truncate: fd.truncate || dc.Truncate, + defaultDocumentType: dc.defaultDocumentType, + binaryAsSlice: dc.binaryAsSlice, + useJSONStructTags: dc.useJSONStructTags, + useLocalTimeZone: dc.useLocalTimeZone, + zeroMaps: dc.zeroMaps, + zeroStructs: dc.zeroStructs, + } + if fd.decoder == nil { return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) } @@ -340,51 +406,35 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r return nil } -func (sc *StructCodec) isZero(i interface{}) bool { - v := reflect.ValueOf(i) - - // check the value validity - if !v.IsValid() { - return true - } - - if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { - return z.IsZero() +func isEmpty(v reflect.Value, omitZeroStruct bool) bool { + kind := v.Kind() + if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) { + return v.Interface().(Zeroer).IsZero() } - - switch v.Kind() { + switch kind { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: return v.Len() == 0 - case reflect.Bool: - return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0 - case reflect.Interface, reflect.Ptr: - return v.IsNil() case reflect.Struct: - if sc.EncodeOmitDefaultStruct { - vt := v.Type() - if vt == tTime { - return v.Interface().(time.Time).IsZero() + if !omitZeroStruct { + return false + } + vt := v.Type() + if vt == tTime { + return v.Interface().(time.Time).IsZero() + } + numField := vt.NumField() + for i := 0; i < numField; i++ { + ff := vt.Field(i) + if ff.PkgPath != "" && !ff.Anonymous { + continue // Private field } - for i := 0; i < v.NumField(); i++ { - if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { - continue // Private field - } - fld := v.Field(i) - if !sc.isZero(fld.Interface()) { - return false - } + if !isEmpty(v.Field(i), omitZeroStruct) { + return false } - return true } + return true } - - return false + return !v.IsValid() || v.IsZero() } type structDescription struct { @@ -435,16 +485,35 @@ func (bi byIndex) Less(i, j int) bool { return len(bi[i].inline) < len(bi[j].inline) } -func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) { +func (sc *StructCodec) describeStruct( + r *Registry, + t reflect.Type, + useJSONStructTags bool, + errorOnDuplicates bool, +) (*structDescription, error) { // We need to analyze the struct, including getting the tags, collecting // information about inlining, and create a map of the field name to the field. - sc.l.RLock() - ds, exists := sc.cache[t] - sc.l.RUnlock() - if exists { - return ds, nil + if v, ok := sc.cache.Load(t); ok { + return v.(*structDescription), nil } + // TODO(charlie): Only describe the struct once when called + // concurrently with the same type. + ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates) + if err != nil { + return nil, err + } + if v, loaded := sc.cache.LoadOrStore(t, ds); loaded { + ds = v.(*structDescription) + } + return ds, nil +} +func (sc *StructCodec) describeStructSlow( + r *Registry, + t reflect.Type, + useJSONStructTags bool, + errorOnDuplicates bool, +) (*structDescription, error) { numFields := t.NumField() sd := &structDescription{ fm: make(map[string]fieldDescription, numFields), @@ -477,7 +546,14 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr decoder: decoder, } - stags, err := sc.parser.ParseStructTags(sf) + var stags StructTags + // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser + // instead of the parser defined on the codec. + if useJSONStructTags { + stags, err = JSONFallbackStructTagParser.ParseStructTags(sf) + } else { + stags, err = sc.parser.ParseStructTags(sf) + } if err != nil { return nil, err } @@ -507,7 +583,7 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr } fallthrough case reflect.Struct: - inlinesf, err := sc.describeStruct(r, sfType) + inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates) if err != nil { return nil, err } @@ -559,7 +635,7 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr continue } dominant, ok := dominantField(fields[i : i+advance]) - if !ok || !sc.OverwriteDuplicatedInlinedFields { + if !ok || !sc.OverwriteDuplicatedInlinedFields || errorOnDuplicates { return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name) } sd.fl = append(sd.fl, dominant) @@ -568,10 +644,6 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr sort.Sort(byIndex(sd.fl)) - sc.l.Lock() - sc.cache[t] = sd - sc.l.Unlock() - return sd, nil } @@ -629,21 +701,21 @@ func getInlineField(val reflect.Value, index []int) (reflect.Value, error) { // DeepZero returns recursive zero object func deepZero(st reflect.Type) (result reflect.Value) { - result = reflect.Indirect(reflect.New(st)) - - if result.Kind() == reflect.Struct { - for i := 0; i < result.NumField(); i++ { - if f := result.Field(i); f.Kind() == reflect.Ptr { - if f.CanInterface() { - if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct { - result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem()))) - } + if st.Kind() == reflect.Struct { + numField := st.NumField() + for i := 0; i < numField; i++ { + if result == emptyValue { + result = reflect.Indirect(reflect.New(st)) + } + f := result.Field(i) + if f.CanInterface() { + if f.Type().Kind() == reflect.Struct { + result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem()))) } } } } - - return + return result } // recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go index 6f406c1..18d85bf 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go @@ -12,12 +12,16 @@ import ( ) // StructTagParser returns the struct tags for a given struct field. +// +// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. type StructTagParser interface { ParseStructTags(reflect.StructField) (StructTags, error) } // StructTagParserFunc is an adapter that allows a generic function to be used // as a StructTagParser. +// +// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. type StructTagParserFunc func(reflect.StructField) (StructTags, error) // ParseStructTags implements the StructTagParser interface. @@ -34,23 +38,23 @@ func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructT // // The properties are defined below: // -// OmitEmpty Only include the field if it's not set to the zero value for the type or to -// empty slices or maps. +// OmitEmpty Only include the field if it's not set to the zero value for the type or to +// empty slices or maps. // -// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's -// feasible while preserving the numeric value. +// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's +// feasible while preserving the numeric value. // -// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within -// a float32. +// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within +// a float32. // -// Inline Inline the field, which must be a struct or a map, causing all of its fields -// or keys to be processed as if they were part of the outer struct. For maps, -// keys must not conflict with the bson keys of other struct fields. +// Inline Inline the field, which must be a struct or a map, causing all of its fields +// or keys to be processed as if they were part of the outer struct. For maps, +// keys must not conflict with the bson keys of other struct fields. // -// Skip This struct field should be skipped. This is usually denoted by parsing a "-" -// for the name. +// Skip This struct field should be skipped. This is usually denoted by parsing a "-" +// for the name. // -// TODO(skriptble): Add tags for undefined as nil and for null as nil. +// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. type StructTags struct { Name string OmitEmpty bool @@ -67,24 +71,26 @@ type StructTags struct { // If there is no name in the struct tag fields, the struct field name is lowercased. // The tag formats accepted are: // -// "[][,[,]]" +// "[][,[,]]" // -// `(...) bson:"[][,[,]]" (...)` +// `(...) bson:"[][,[,]]" (...)` // // An example: // -// type T struct { -// A bool -// B int "myb" -// C string "myc,omitempty" -// D string `bson:",omitempty" json:"jsonkey"` -// E int64 ",minsize" -// F int64 "myf,omitempty,minsize" -// } +// type T struct { +// A bool +// B int "myb" +// C string "myc,omitempty" +// D string `bson:",omitempty" json:"jsonkey"` +// E int64 ",minsize" +// F int64 "myf,omitempty,minsize" +// } // // A struct tag either consisting entirely of '-' or with a bson key with a // value consisting entirely of '-' will return a StructTags with Skip true and // the remaining fields will be their default values. +// +// Deprecated: DefaultStructTagParser will be removed in Go Driver 2.0. var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { key := strings.ToLower(sf.Name) tag, ok := sf.Tag.Lookup("bson") @@ -125,6 +131,9 @@ func parseTags(key string, tag string) (StructTags, error) { // JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser // but will also fallback to parsing the json tag instead on a field where the // bson tag isn't available. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.UseJSONStructTags] and +// [go.mongodb.org/mongo-driver/bson.Decoder.UseJSONStructTags] instead. var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { key := strings.ToLower(sf.Name) tag, ok := sf.Tag.Lookup("bson") diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go index ec7e30f..22fb762 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go @@ -22,18 +22,42 @@ const ( ) // TimeCodec is the Codec used for time.Time values. +// +// Deprecated: TimeCodec will not be directly configurable in Go Driver 2.0. +// To configure the time.Time encode and decode behavior, use the configuration +// methods on a [go.mongodb.org/mongo-driver/bson.Encoder] or +// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the time.Time encode +// and decode behavior for a mongo.Client, use +// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions]. +// +// For example, to configure a mongo.Client to ..., use: +// +// opt := options.Client().SetBSONOptions(&options.BSONOptions{ +// UseLocalTimeZone: true, +// }) +// +// See the deprecation notice for each field in TimeCodec for the corresponding +// settings. type TimeCodec struct { + // UseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. + // + // Deprecated: Use bson.Decoder.UseLocalTimeZone or options.BSONOptions.UseLocalTimeZone + // instead. UseLocalTimeZone bool } var ( defaultTimeCodec = NewTimeCodec() - _ ValueCodec = defaultTimeCodec + // Assert that defaultTimeCodec satisfies the typeDecoder interface, which allows it to be used + // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. _ typeDecoder = defaultTimeCodec ) // NewTimeCodec returns a TimeCodec with options opts. +// +// Deprecated: NewTimeCodec will not be available in Go Driver 2.0. See +// [TimeCodec] for more details. func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec { timeOpt := bsonoptions.MergeTimeCodecOptions(opts...) @@ -95,7 +119,7 @@ func (tc *TimeCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t refle return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType) } - if !tc.UseLocalTimeZone { + if !tc.UseLocalTimeZone && !dc.useLocalTimeZone { timeVal = timeVal.UTC() } return reflect.ValueOf(timeVal), nil @@ -117,7 +141,7 @@ func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val re } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *TimeCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +func (tc *TimeCodec) EncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go index 07f4b70..6ade17b 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go @@ -34,6 +34,7 @@ var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem() var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem() var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem() var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem() +var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem() var tBinary = reflect.TypeOf(primitive.Binary{}) var tUndefined = reflect.TypeOf(primitive.Undefined{}) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go index 0b21ce9..39b0713 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go @@ -17,18 +17,43 @@ import ( ) // UIntCodec is the Codec used for uint values. +// +// Deprecated: UIntCodec will not be directly configurable in Go Driver 2.0. To +// configure the uint encode and decode behavior, use the configuration methods +// on a [go.mongodb.org/mongo-driver/bson.Encoder] or +// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the uint encode and +// decode behavior for a mongo.Client, use +// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions]. +// +// For example, to configure a mongo.Client to marshal Go uint values as the +// minimum BSON int size that can represent the value, use: +// +// opt := options.Client().SetBSONOptions(&options.BSONOptions{ +// IntMinSize: true, +// }) +// +// See the deprecation notice for each field in UIntCodec for the corresponding +// settings. type UIntCodec struct { + // EncodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the + // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. + // + // Deprecated: Use bson.Encoder.IntMinSize or options.BSONOptions.IntMinSize instead. EncodeToMinSize bool } var ( defaultUIntCodec = NewUIntCodec() - _ ValueCodec = defaultUIntCodec + // Assert that defaultUIntCodec satisfies the typeDecoder interface, which allows it to be used + // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. _ typeDecoder = defaultUIntCodec ) // NewUIntCodec returns a UIntCodec with options opts. +// +// Deprecated: NewUIntCodec will not be available in Go Driver 2.0. See +// [UIntCodec] for more details. func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec { uintOpt := bsonoptions.MergeUIntCodecOptions(opts...) @@ -139,11 +164,15 @@ func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t refl return reflect.ValueOf(uint64(i64)), nil case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + if i64 < 0 { + return emptyValue, fmt.Errorf("%d overflows uint", i64) + } + v := uint64(i64) + if v > math.MaxUint { // Can we fit this inside of an uint return emptyValue, fmt.Errorf("%d overflows uint", i64) } - return reflect.ValueOf(uint(i64)), nil + return reflect.ValueOf(uint(v)), nil default: return emptyValue, ValueDecoderError{ Name: "UintDecodeValue", diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/byte_slice_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/byte_slice_codec_options.go index b1256a4..996bd17 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/byte_slice_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/byte_slice_codec_options.go @@ -7,22 +7,33 @@ package bsonoptions // ByteSliceCodecOptions represents all possible options for byte slice encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type ByteSliceCodecOptions struct { EncodeNilAsEmpty *bool // Specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. } // ByteSliceCodec creates a new *ByteSliceCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func ByteSliceCodec() *ByteSliceCodecOptions { return &ByteSliceCodecOptions{} } // SetEncodeNilAsEmpty specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilByteSliceAsEmpty] instead. func (bs *ByteSliceCodecOptions) SetEncodeNilAsEmpty(b bool) *ByteSliceCodecOptions { bs.EncodeNilAsEmpty = &b return bs } // MergeByteSliceCodecOptions combines the given *ByteSliceCodecOptions into a single *ByteSliceCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeByteSliceCodecOptions(opts ...*ByteSliceCodecOptions) *ByteSliceCodecOptions { bs := ByteSliceCodec() for _, opt := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/doc.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/doc.go new file mode 100644 index 0000000..c40973c --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/doc.go @@ -0,0 +1,8 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package bsonoptions defines the optional configurations for the BSON codecs. +package bsonoptions diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/empty_interface_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/empty_interface_codec_options.go index 6caaa00..f522c7e 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/empty_interface_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/empty_interface_codec_options.go @@ -7,22 +7,33 @@ package bsonoptions // EmptyInterfaceCodecOptions represents all possible options for interface{} encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type EmptyInterfaceCodecOptions struct { DecodeBinaryAsSlice *bool // Specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. } // EmptyInterfaceCodec creates a new *EmptyInterfaceCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func EmptyInterfaceCodec() *EmptyInterfaceCodecOptions { return &EmptyInterfaceCodecOptions{} } // SetDecodeBinaryAsSlice specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.BinaryAsSlice] instead. func (e *EmptyInterfaceCodecOptions) SetDecodeBinaryAsSlice(b bool) *EmptyInterfaceCodecOptions { e.DecodeBinaryAsSlice = &b return e } // MergeEmptyInterfaceCodecOptions combines the given *EmptyInterfaceCodecOptions into a single *EmptyInterfaceCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeEmptyInterfaceCodecOptions(opts ...*EmptyInterfaceCodecOptions) *EmptyInterfaceCodecOptions { e := EmptyInterfaceCodec() for _, opt := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/map_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/map_codec_options.go index 7a6a880..a7a7c1d 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/map_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/map_codec_options.go @@ -7,6 +7,9 @@ package bsonoptions // MapCodecOptions represents all possible options for map encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type MapCodecOptions struct { DecodeZerosMap *bool // Specifies if the map should be zeroed before decoding into it. Defaults to false. EncodeNilAsEmpty *bool // Specifies if a nil map should encode as an empty document instead of null. Defaults to false. @@ -19,17 +22,24 @@ type MapCodecOptions struct { } // MapCodec creates a new *MapCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func MapCodec() *MapCodecOptions { return &MapCodecOptions{} } // SetDecodeZerosMap specifies if the map should be zeroed before decoding into it. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroMaps] instead. func (t *MapCodecOptions) SetDecodeZerosMap(b bool) *MapCodecOptions { t.DecodeZerosMap = &b return t } // SetEncodeNilAsEmpty specifies if a nil map should encode as an empty document instead of null. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilMapAsEmpty] instead. func (t *MapCodecOptions) SetEncodeNilAsEmpty(b bool) *MapCodecOptions { t.EncodeNilAsEmpty = &b return t @@ -40,12 +50,17 @@ func (t *MapCodecOptions) SetEncodeNilAsEmpty(b bool) *MapCodecOptions { // type must either be a string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with // fmt.Sprint() and the encoding key type must be a string, an integer type, or a float. If true, the use of Stringer // will override TextMarshaler/TextUnmarshaler. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.StringifyMapKeysWithFmt] instead. func (t *MapCodecOptions) SetEncodeKeysWithStringer(b bool) *MapCodecOptions { t.EncodeKeysWithStringer = &b return t } // MergeMapCodecOptions combines the given *MapCodecOptions into a single *MapCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions { s := MapCodec() for _, opt := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/slice_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/slice_codec_options.go index ef965e4..3c1e4f3 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/slice_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/slice_codec_options.go @@ -7,22 +7,33 @@ package bsonoptions // SliceCodecOptions represents all possible options for slice encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type SliceCodecOptions struct { EncodeNilAsEmpty *bool // Specifies if a nil slice should encode as an empty array instead of null. Defaults to false. } // SliceCodec creates a new *SliceCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func SliceCodec() *SliceCodecOptions { return &SliceCodecOptions{} } // SetEncodeNilAsEmpty specifies if a nil slice should encode as an empty array instead of null. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilSliceAsEmpty] instead. func (s *SliceCodecOptions) SetEncodeNilAsEmpty(b bool) *SliceCodecOptions { s.EncodeNilAsEmpty = &b return s } // MergeSliceCodecOptions combines the given *SliceCodecOptions into a single *SliceCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeSliceCodecOptions(opts ...*SliceCodecOptions) *SliceCodecOptions { s := SliceCodec() for _, opt := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/string_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/string_codec_options.go index 65964f4..f8b76f9 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/string_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/string_codec_options.go @@ -9,23 +9,34 @@ package bsonoptions var defaultDecodeOIDAsHex = true // StringCodecOptions represents all possible options for string encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type StringCodecOptions struct { DecodeObjectIDAsHex *bool // Specifies if we should decode ObjectID as the hex value. Defaults to true. } // StringCodec creates a new *StringCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func StringCodec() *StringCodecOptions { return &StringCodecOptions{} } // SetDecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. If false, a string made // from the raw object ID bytes will be used. Defaults to true. +// +// Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. func (t *StringCodecOptions) SetDecodeObjectIDAsHex(b bool) *StringCodecOptions { t.DecodeObjectIDAsHex = &b return t } // MergeStringCodecOptions combines the given *StringCodecOptions into a single *StringCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeStringCodecOptions(opts ...*StringCodecOptions) *StringCodecOptions { s := &StringCodecOptions{&defaultDecodeOIDAsHex} for _, opt := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/struct_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/struct_codec_options.go index 78d1dd8..1cbfa32 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/struct_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/struct_codec_options.go @@ -9,6 +9,9 @@ package bsonoptions var defaultOverwriteDuplicatedInlinedFields = true // StructCodecOptions represents all possible options for struct encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type StructCodecOptions struct { DecodeZeroStruct *bool // Specifies if structs should be zeroed before decoding into them. Defaults to false. DecodeDeepZeroInline *bool // Specifies if structs should be recursively zeroed when a inline value is decoded. Defaults to false. @@ -18,17 +21,24 @@ type StructCodecOptions struct { } // StructCodec creates a new *StructCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func StructCodec() *StructCodecOptions { return &StructCodecOptions{} } // SetDecodeZeroStruct specifies if structs should be zeroed before decoding into them. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroStructs] instead. func (t *StructCodecOptions) SetDecodeZeroStruct(b bool) *StructCodecOptions { t.DecodeZeroStruct = &b return t } // SetDecodeDeepZeroInline specifies if structs should be zeroed before decoding into them. Defaults to false. +// +// Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. func (t *StructCodecOptions) SetDecodeDeepZeroInline(b bool) *StructCodecOptions { t.DecodeDeepZeroInline = &b return t @@ -36,6 +46,8 @@ func (t *StructCodecOptions) SetDecodeDeepZeroInline(b bool) *StructCodecOptions // SetEncodeOmitDefaultStruct specifies if default structs should be considered empty by omitempty. A default struct has all // its values set to their default value. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.OmitZeroStruct] instead. func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOptions { t.EncodeOmitDefaultStruct = &b return t @@ -45,18 +57,26 @@ func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOpti // same bson key. When true and decoding, values will be written to the outermost struct with a matching key, and when // encoding, keys will have the value of the top-most matching field. When false, decoding and encoding will error if // there are duplicate keys after the struct is inlined. Defaults to true. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.ErrorOnInlineDuplicates] instead. func (t *StructCodecOptions) SetOverwriteDuplicatedInlinedFields(b bool) *StructCodecOptions { t.OverwriteDuplicatedInlinedFields = &b return t } // SetAllowUnexportedFields specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. +// +// Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be +// supported in Go Driver 2.0. func (t *StructCodecOptions) SetAllowUnexportedFields(b bool) *StructCodecOptions { t.AllowUnexportedFields = &b return t } // MergeStructCodecOptions combines the given *StructCodecOptions into a single *StructCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions { s := &StructCodecOptions{ OverwriteDuplicatedInlinedFields: &defaultOverwriteDuplicatedInlinedFields, diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/time_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/time_codec_options.go index 13496d1..3f38433 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/time_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/time_codec_options.go @@ -7,22 +7,33 @@ package bsonoptions // TimeCodecOptions represents all possible options for time.Time encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type TimeCodecOptions struct { UseLocalTimeZone *bool // Specifies if we should decode into the local time zone. Defaults to false. } // TimeCodec creates a new *TimeCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func TimeCodec() *TimeCodecOptions { return &TimeCodecOptions{} } // SetUseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseLocalTimeZone] instead. func (t *TimeCodecOptions) SetUseLocalTimeZone(b bool) *TimeCodecOptions { t.UseLocalTimeZone = &b return t } // MergeTimeCodecOptions combines the given *TimeCodecOptions into a single *TimeCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeTimeCodecOptions(opts ...*TimeCodecOptions) *TimeCodecOptions { t := TimeCodec() for _, opt := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/uint_codec_options.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/uint_codec_options.go index e08b7f1..5091e4d 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/uint_codec_options.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonoptions/uint_codec_options.go @@ -7,22 +7,33 @@ package bsonoptions // UIntCodecOptions represents all possible options for uint encoding and decoding. +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. type UIntCodecOptions struct { EncodeToMinSize *bool // Specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. } // UIntCodec creates a new *UIntCodecOptions +// +// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal +// and unmarshal behavior instead. func UIntCodec() *UIntCodecOptions { return &UIntCodecOptions{} } // SetEncodeToMinSize specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.IntMinSize] instead. func (u *UIntCodecOptions) SetEncodeToMinSize(b bool) *UIntCodecOptions { u.EncodeToMinSize = &b return u } // MergeUIntCodecOptions combines the given *UIntCodecOptions into a single *UIntCodecOptions in a last one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeUIntCodecOptions(opts ...*UIntCodecOptions) *UIntCodecOptions { u := UIntCodec() for _, opt := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go index 5cdf646..1e25570 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go @@ -7,6 +7,7 @@ package bsonrw import ( + "errors" "fmt" "io" @@ -17,20 +18,32 @@ import ( // Copier is a type that allows copying between ValueReaders, ValueWriters, and // []byte values. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. type Copier struct{} // NewCopier creates a new copier with the given registry. If a nil registry is provided // a default registry is used. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func NewCopier() Copier { return Copier{} } // CopyDocument handles copying a document from src to dst. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func CopyDocument(dst ValueWriter, src ValueReader) error { return Copier{}.CopyDocument(dst, src) } // CopyDocument handles copying one document from the src to the dst. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error { dr, err := src.ReadDocument() if err != nil { @@ -47,6 +60,9 @@ func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error { // CopyArrayFromBytes copies the values from a BSON array represented as a // []byte to a ValueWriter. +// +// Deprecated: Copying BSON arrays using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) CopyArrayFromBytes(dst ValueWriter, src []byte) error { aw, err := dst.WriteArray() if err != nil { @@ -63,6 +79,9 @@ func (c Copier) CopyArrayFromBytes(dst ValueWriter, src []byte) error { // CopyDocumentFromBytes copies the values from a BSON document represented as a // []byte to a ValueWriter. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error { dw, err := dst.WriteDocument() if err != nil { @@ -81,6 +100,9 @@ type writeElementFn func(key string) (ValueWriter, error) // CopyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an // ArrayWriter. +// +// Deprecated: Copying BSON arrays using the ArrayWriter interface will not be supported in Go +// Driver 2.0. func (c Copier) CopyBytesToArrayWriter(dst ArrayWriter, src []byte) error { wef := func(_ string) (ValueWriter, error) { return dst.WriteArrayElement() @@ -91,6 +113,9 @@ func (c Copier) CopyBytesToArrayWriter(dst ArrayWriter, src []byte) error { // CopyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a // DocumentWriter. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error { wef := func(key string) (ValueWriter, error) { return dst.WriteDocumentElement(key) @@ -100,7 +125,7 @@ func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error } func (c Copier) copyBytesToValueWriter(src []byte, wef writeElementFn) error { - // TODO(skriptble): Create errors types here. Anything thats a tag should be a property. + // TODO(skriptble): Create errors types here. Anything that is a tag should be a property. length, rem, ok := bsoncore.ReadLength(src) if !ok { return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src)) @@ -150,12 +175,18 @@ func (c Copier) copyBytesToValueWriter(src []byte, wef writeElementFn) error { // CopyDocumentToBytes copies an entire document from the ValueReader and // returns it as bytes. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) CopyDocumentToBytes(src ValueReader) ([]byte, error) { return c.AppendDocumentBytes(nil, src) } // AppendDocumentBytes functions the same as CopyDocumentToBytes, but will // append the result to dst. +// +// Deprecated: Copying BSON documents using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) { if br, ok := src.(BytesReader); ok { _, dst, err := br.ReadValueBytes(dst) @@ -163,7 +194,7 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) vw.reset(dst) @@ -173,6 +204,9 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) } // AppendArrayBytes copies an array from the ValueReader to dst. +// +// Deprecated: Copying BSON arrays using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { if br, ok := src.(BytesReader); ok { _, dst, err := br.ReadValueBytes(dst) @@ -180,7 +214,7 @@ func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) vw.reset(dst) @@ -190,6 +224,8 @@ func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { } // CopyValueFromBytes will write the value represtend by t and src to dst. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.UnmarshalValue] instead. func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) error { if wvb, ok := dst.(BytesWriter); ok { return wvb.WriteValueBytes(t, src) @@ -206,19 +242,24 @@ func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) // CopyValueToBytes copies a value from src and returns it as a bsontype.Type and a // []byte. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.MarshalValue] instead. func (c Copier) CopyValueToBytes(src ValueReader) (bsontype.Type, []byte, error) { return c.AppendValueBytes(nil, src) } // AppendValueBytes functions the same as CopyValueToBytes, but will append the // result to dst. +// +// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go +// Driver 2.0. func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []byte, error) { if br, ok := src.(BytesReader); ok { return br.ReadValueBytes(dst) } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) start := len(dst) @@ -234,6 +275,9 @@ func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, [] } // CopyValue will copy a single value from src to dst. +// +// Deprecated: Copying BSON values using the ValueWriter and ValueReader interfaces will not be +// supported in Go Driver 2.0. func (c Copier) CopyValue(dst ValueWriter, src ValueReader) error { var err error switch src.Type() { @@ -399,7 +443,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error { for { vr, err := ar.ReadValue() - if err == ErrEOA { + if errors.Is(err, ErrEOA) { break } if err != nil { @@ -423,7 +467,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error { func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error { for { key, vr, err := dr.ReadElement() - if err == ErrEOD { + if errors.Is(err, ErrEOD) { break } if err != nil { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go index 54c76bf..f0702d9 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go @@ -305,7 +305,7 @@ func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) { } // remove hyphens - uuidNoHyphens := strings.Replace(uuid, "-", "", -1) + uuidNoHyphens := strings.ReplaceAll(uuid, "-", "") if len(uuidNoHyphens) != 32 { return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens") } @@ -313,7 +313,7 @@ func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) { // convert hex to bytes bytes, err := hex.DecodeString(uuidNoHyphens) if err != nil { - return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err) + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err) } ejp.advanceState() diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_reader.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_reader.go index 35832d7..59ddfc4 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_reader.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_reader.go @@ -7,6 +7,7 @@ package bsonrw import ( + "errors" "fmt" "io" "sync" @@ -16,11 +17,15 @@ import ( ) // ExtJSONValueReaderPool is a pool for ValueReaders that read ExtJSON. +// +// Deprecated: ExtJSONValueReaderPool will not be supported in Go Driver 2.0. type ExtJSONValueReaderPool struct { pool sync.Pool } // NewExtJSONValueReaderPool instantiates a new ExtJSONValueReaderPool. +// +// Deprecated: ExtJSONValueReaderPool will not be supported in Go Driver 2.0. func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool { return &ExtJSONValueReaderPool{ pool: sync.Pool{ @@ -32,6 +37,8 @@ func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool { } // Get retrieves a ValueReader from the pool and uses src as the underlying ExtJSON. +// +// Deprecated: ExtJSONValueReaderPool will not be supported in Go Driver 2.0. func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReader, error) { vr := bvrp.pool.Get().(*extJSONValueReader) return vr.reset(r, canonical) @@ -39,6 +46,8 @@ func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReade // Put inserts a ValueReader into the pool. If the ValueReader is not a ExtJSON ValueReader nothing // is inserted into the pool and ok will be false. +// +// Deprecated: ExtJSONValueReaderPool will not be supported in Go Driver 2.0. func (bvrp *ExtJSONValueReaderPool) Put(vr ValueReader) (ok bool) { bvr, ok := vr.(*extJSONValueReader) if !ok { @@ -605,7 +614,7 @@ func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) { name, t, err := ejvr.p.readKey() if err != nil { - if err == ErrEOD { + if errors.Is(err, ErrEOD) { if ejvr.stack[ejvr.frame].mode == mCodeWithScope { _, err := ejvr.p.peekType() if err != nil { @@ -632,7 +641,7 @@ func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) { t, err := ejvr.p.peekType() if err != nil { - if err == ErrEOA { + if errors.Is(err, ErrEOA) { ejvr.pop() } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_wrappers.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_wrappers.go index 9695704..af6ae7b 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_wrappers.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_wrappers.go @@ -95,9 +95,9 @@ func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) { return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t) } - i, err := strconv.ParseInt(val.v.(string), 16, 64) + i, err := strconv.ParseUint(val.v.(string), 16, 8) if err != nil { - return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string)) + return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err) } subType = byte(i) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_writer.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_writer.go index 99ed524..86a2935 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_writer.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_writer.go @@ -23,11 +23,15 @@ import ( ) // ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters. +// +// Deprecated: ExtJSONValueWriterPool will not be supported in Go Driver 2.0. type ExtJSONValueWriterPool struct { pool sync.Pool } // NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON. +// +// Deprecated: ExtJSONValueWriterPool will not be supported in Go Driver 2.0. func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool { return &ExtJSONValueWriterPool{ pool: sync.Pool{ @@ -39,6 +43,8 @@ func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool { } // Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination. +// +// Deprecated: ExtJSONValueWriterPool will not be supported in Go Driver 2.0. func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter { vw := bvwp.pool.Get().(*extJSONValueWriter) if writer, ok := w.(*SliceWriter); ok { @@ -53,6 +59,8 @@ func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) // Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing // happens and ok will be false. +// +// Deprecated: ExtJSONValueWriterPool will not be supported in Go Driver 2.0. func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) { bvw, ok := vw.(*extJSONValueWriter) if !ok { @@ -80,6 +88,7 @@ type extJSONValueWriter struct { frame int64 canonical bool escapeHTML bool + newlines bool } // NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w. @@ -88,10 +97,13 @@ func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter return nil, errNilWriter } - return newExtJSONWriter(w, canonical, escapeHTML), nil + // Enable newlines for all Extended JSON value writers created by NewExtJSONValueWriter. We + // expect these value writers to be used with an Encoder, which should add newlines after + // encoded Extended JSON documents. + return newExtJSONWriter(w, canonical, escapeHTML, true), nil } -func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter { +func newExtJSONWriter(w io.Writer, canonical, escapeHTML, newlines bool) *extJSONValueWriter { stack := make([]ejvwState, 1, 5) stack[0] = ejvwState{mode: mTopLevel} @@ -101,6 +113,7 @@ func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWrit stack: stack, canonical: canonical, escapeHTML: escapeHTML, + newlines: newlines, } } @@ -455,12 +468,13 @@ func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error return err } + options = sortStringAlphebeticAscending(options) var buf bytes.Buffer buf.WriteString(`{"$regularExpression":{"pattern":`) writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML) - buf.WriteString(`,"options":"`) - buf.WriteString(sortStringAlphebeticAscending(options)) - buf.WriteString(`"}},`) + buf.WriteString(`,"options":`) + writeStringWithEscapes(options, &buf, ejvw.escapeHTML) + buf.WriteString(`}},`) ejvw.buf = append(ejvw.buf, buf.Bytes()...) @@ -564,6 +578,12 @@ func (ejvw *extJSONValueWriter) WriteDocumentEnd() error { case mDocument: ejvw.buf = append(ejvw.buf, ',') case mTopLevel: + // If the value writer has newlines enabled, end top-level documents with a newline so that + // multiple documents encoded to the same writer are separated by newlines. That matches the + // Go json.Encoder behavior and also works with bsonrw.NewExtJSONValueReader. + if ejvw.newlines { + ejvw.buf = append(ejvw.buf, '\n') + } if ejvw.w != nil { if _, err := ejvw.w.Write(ejvw.buf); err != nil { return err @@ -609,13 +629,14 @@ func (ejvw *extJSONValueWriter) WriteArrayEnd() error { func formatDouble(f float64) string { var s string - if math.IsInf(f, 1) { + switch { + case math.IsInf(f, 1): s = "Infinity" - } else if math.IsInf(f, -1) { + case math.IsInf(f, -1): s = "-Infinity" - } else if math.IsNaN(f) { + case math.IsNaN(f): s = "NaN" - } else { + default: // Print exactly one decimalType place for integers; otherwise, print as many are necessary to // perfectly represent it. s = strconv.FormatFloat(f, 'G', -1, 64) @@ -720,9 +741,7 @@ func (ss sortableString) Less(i, j int) bool { } func (ss sortableString) Swap(i, j int) { - oldI := ss[i] - ss[i] = ss[j] - ss[j] = oldI + ss[i], ss[j] = ss[j], ss[i] } func sortStringAlphebeticAscending(s string) string { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/json_scanner.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/json_scanner.go index cd4843a..9782891 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/json_scanner.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/json_scanner.go @@ -58,7 +58,7 @@ func (js *jsonScanner) nextToken() (*jsonToken, error) { c, err = js.readNextByte() } - if err == io.EOF { + if errors.Is(err, io.EOF) { return &jsonToken{t: jttEOF}, nil } else if err != nil { return nil, err @@ -82,12 +82,13 @@ func (js *jsonScanner) nextToken() (*jsonToken, error) { return js.scanString() default: // check if it's a number - if c == '-' || isDigit(c) { + switch { + case c == '-' || isDigit(c): return js.scanNumber(c) - } else if c == 't' || c == 'f' || c == 'n' { + case c == 't' || c == 'f' || c == 'n': // maybe a literal return js.scanLiteral(c) - } else { + default: return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c) } } @@ -174,7 +175,7 @@ func getu4(s []byte) rune { for _, c := range s[:4] { switch { case '0' <= c && c <= '9': - c = c - '0' + c -= '0' case 'a' <= c && c <= 'f': c = c - 'a' + 10 case 'A' <= c && c <= 'F': @@ -198,7 +199,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { for { c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -209,7 +210,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { case '\\': c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -248,7 +249,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { if utf16.IsSurrogate(rn) { c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -264,7 +265,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -325,17 +326,18 @@ func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) { c5, err := js.readNextByte() - if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) { + switch { + case bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)): js.pos = int(math.Max(0, float64(js.pos-1))) return &jsonToken{t: jttBool, v: true, p: p}, nil - } else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) { + case bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)): js.pos = int(math.Max(0, float64(js.pos-1))) return &jsonToken{t: jttNull, v: nil, p: p}, nil - } else if bytes.Equal([]byte("fals"), lit) { + case bytes.Equal([]byte("fals"), lit): if c5 == 'e' { c5, err = js.readNextByte() - if isValueTerminator(c5) || err == io.EOF { + if isValueTerminator(c5) || errors.Is(err, io.EOF) { js.pos = int(math.Max(0, float64(js.pos-1))) return &jsonToken{t: jttBool, v: false, p: p}, nil } @@ -384,7 +386,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { for { c, err = js.readNextByte() - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return nil, err } @@ -413,7 +415,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + if isWhiteSpace(c) || errors.Is(err, io.EOF) { s = nssDone } else { s = nssInvalid @@ -430,12 +432,13 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + switch { + case isWhiteSpace(c) || errors.Is(err, io.EOF): s = nssDone - } else if isDigit(c) { + case isDigit(c): s = nssSawIntegerDigits b.WriteByte(c) - } else { + default: s = nssInvalid } } @@ -455,12 +458,13 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + switch { + case isWhiteSpace(c) || errors.Is(err, io.EOF): s = nssDone - } else if isDigit(c) { + case isDigit(c): s = nssSawFractionDigits b.WriteByte(c) - } else { + default: s = nssInvalid } } @@ -490,12 +494,13 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + switch { + case isWhiteSpace(c) || errors.Is(err, io.EOF): s = nssDone - } else if isDigit(c) { + case isDigit(c): s = nssSawExponentDigits b.WriteByte(c) - } else { + default: s = nssInvalid } } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/reader.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/reader.go index 0b8fa28..324b10b 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/reader.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/reader.go @@ -58,6 +58,8 @@ type ValueReader interface { // types that implement ValueReader may also implement this interface. // // The bytes of the value will be appended to dst. +// +// Deprecated: BytesReader will not be supported in Go Driver 2.0. type BytesReader interface { ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_reader.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_reader.go index 458588b..0e07d50 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_reader.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_reader.go @@ -28,11 +28,15 @@ var vrPool = sync.Pool{ } // BSONValueReaderPool is a pool for ValueReaders that read BSON. +// +// Deprecated: BSONValueReaderPool will not be supported in Go Driver 2.0. type BSONValueReaderPool struct { pool sync.Pool } // NewBSONValueReaderPool instantiates a new BSONValueReaderPool. +// +// Deprecated: BSONValueReaderPool will not be supported in Go Driver 2.0. func NewBSONValueReaderPool() *BSONValueReaderPool { return &BSONValueReaderPool{ pool: sync.Pool{ @@ -44,6 +48,8 @@ func NewBSONValueReaderPool() *BSONValueReaderPool { } // Get retrieves a ValueReader from the pool and uses src as the underlying BSON. +// +// Deprecated: BSONValueReaderPool will not be supported in Go Driver 2.0. func (bvrp *BSONValueReaderPool) Get(src []byte) ValueReader { vr := bvrp.pool.Get().(*valueReader) vr.reset(src) @@ -52,6 +58,8 @@ func (bvrp *BSONValueReaderPool) Get(src []byte) ValueReader { // Put inserts a ValueReader into the pool. If the ValueReader is not a BSON ValueReader nothing // is inserted into the pool and ok will be false. +// +// Deprecated: BSONValueReaderPool will not be supported in Go Driver 2.0. func (bvrp *BSONValueReaderPool) Put(vr ValueReader) (ok bool) { bvr, ok := vr.(*valueReader) if !ok { @@ -86,12 +94,11 @@ type valueReader struct { // NewBSONDocumentReader returns a ValueReader using b for the underlying BSON // representation. Parameter b must be a BSON Document. -// -// TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes -// a []byte while the writer takes an io.Writer. We should have two versions of each, one that takes -// a []byte and one that takes an io.Reader or io.Writer. The []byte version will need to return a -// thing that can return the finished []byte since it might be reallocated when appended to. func NewBSONDocumentReader(b []byte) ValueReader { + // TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes a []byte while the + // TODO writer takes an io.Writer. We should have two versions of each, one that takes a []byte and one that takes an + // TODO io.Reader or io.Writer. The []byte version will need to return a thing that can return the finished []byte since + // TODO it might be reallocated when appended to. return newValueReader(b) } @@ -732,8 +739,7 @@ func (vr *valueReader) ReadValue() (ValueReader, error) { return nil, ErrEOA } - _, err = vr.readCString() - if err != nil { + if err := vr.skipCString(); err != nil { return nil, err } @@ -787,6 +793,15 @@ func (vr *valueReader) readByte() (byte, error) { return vr.d[vr.offset-1], nil } +func (vr *valueReader) skipCString() error { + idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) + if idx < 0 { + return io.EOF + } + vr.offset += int64(idx) + 1 + return nil +} + func (vr *valueReader) readCString() (string, error) { idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) if idx < 0 { @@ -827,7 +842,7 @@ func (vr *valueReader) peekLength() (int32, error) { } idx := vr.offset - return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil + return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil } func (vr *valueReader) readLength() (int32, error) { return vr.readi32() } @@ -839,7 +854,7 @@ func (vr *valueReader) readi32() (int32, error) { idx := vr.offset vr.offset += 4 - return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil + return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil } func (vr *valueReader) readu32() (uint32, error) { @@ -849,7 +864,7 @@ func (vr *valueReader) readu32() (uint32, error) { idx := vr.offset vr.offset += 4 - return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil + return binary.LittleEndian.Uint32(vr.d[idx:]), nil } func (vr *valueReader) readi64() (int64, error) { @@ -859,8 +874,7 @@ func (vr *valueReader) readi64() (int64, error) { idx := vr.offset vr.offset += 8 - return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 | - int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil + return int64(binary.LittleEndian.Uint64(vr.d[idx:])), nil } func (vr *valueReader) readu64() (uint64, error) { @@ -870,6 +884,5 @@ func (vr *valueReader) readu64() (uint64, error) { idx := vr.offset vr.offset += 8 - return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 | - uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil + return binary.LittleEndian.Uint64(vr.d[idx:]), nil } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_writer.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_writer.go index f95a08a..501c6d7 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_writer.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/value_writer.go @@ -28,12 +28,23 @@ var vwPool = sync.Pool{ }, } +func putValueWriter(vw *valueWriter) { + if vw != nil { + vw.w = nil // don't leak the writer + vwPool.Put(vw) + } +} + // BSONValueWriterPool is a pool for BSON ValueWriters. +// +// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0. type BSONValueWriterPool struct { pool sync.Pool } // NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON. +// +// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0. func NewBSONValueWriterPool() *BSONValueWriterPool { return &BSONValueWriterPool{ pool: sync.Pool{ @@ -45,6 +56,8 @@ func NewBSONValueWriterPool() *BSONValueWriterPool { } // Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination. +// +// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0. func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter { vw := bvwp.pool.Get().(*valueWriter) @@ -56,6 +69,8 @@ func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter { } // GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination. +// +// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0. func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher { vw := bvwp.Get(w).(*valueWriter) vw.push(mElement) @@ -64,6 +79,8 @@ func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlushe // Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing // happens and ok will be false. +// +// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0. func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) { bvw, ok := vw.(*valueWriter) if !ok { @@ -139,32 +156,21 @@ type valueWriter struct { } func (vw *valueWriter) advanceFrame() { - if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack - length := len(vw.stack) - if length+1 >= cap(vw.stack) { - // double it - buf := make([]vwState, 2*cap(vw.stack)+1) - copy(buf, vw.stack) - vw.stack = buf - } - vw.stack = vw.stack[:length+1] - } vw.frame++ + if vw.frame >= int64(len(vw.stack)) { + vw.stack = append(vw.stack, vwState{}) + } } func (vw *valueWriter) push(m mode) { vw.advanceFrame() // Clean the stack - vw.stack[vw.frame].mode = m - vw.stack[vw.frame].key = "" - vw.stack[vw.frame].arrkey = 0 - vw.stack[vw.frame].start = 0 + vw.stack[vw.frame] = vwState{mode: m} - vw.stack[vw.frame].mode = m switch m { case mDocument, mArray, mCodeWithScope: - vw.reserveLength() + vw.reserveLength() // WARN: this is not needed } } @@ -203,6 +209,7 @@ func newValueWriter(w io.Writer) *valueWriter { return vw } +// TODO: only used in tests func newValueWriterFromSlice(buf []byte) *valueWriter { vw := new(valueWriter) stack := make([]vwState, 1, 5) @@ -239,17 +246,16 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod } func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { - switch vw.stack[vw.frame].mode { + frame := &vw.stack[vw.frame] + switch frame.mode { case mElement: - key := vw.stack[vw.frame].key + key := frame.key if !isValidCString(key) { return errors.New("BSON element key cannot contain null bytes") } - - vw.buf = bsoncore.AppendHeader(vw.buf, t, key) + vw.appendHeader(t, key) case mValue: - // TODO: Do this with a cache of the first 1000 or so array keys. - vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) + vw.appendIntHeader(t, frame.arrkey) default: modes := []mode{mElement, mValue} if addmodes != nil { @@ -591,9 +597,11 @@ func (vw *valueWriter) writeLength() error { if length > maxSize { return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))} } - length = length - int(vw.stack[vw.frame].start) - start := vw.stack[vw.frame].start + frame := &vw.stack[vw.frame] + length -= int(frame.start) + start := frame.start + _ = vw.buf[start+3] // BCE vw.buf[start+0] = byte(length) vw.buf[start+1] = byte(length >> 8) vw.buf[start+2] = byte(length >> 16) @@ -602,5 +610,31 @@ func (vw *valueWriter) writeLength() error { } func isValidCString(cs string) bool { - return !strings.ContainsRune(cs, '\x00') + // Disallow the zero byte in a cstring because the zero byte is used as the + // terminating character. + // + // It's safe to check bytes instead of runes because all multibyte UTF-8 + // code points start with (binary) 11xxxxxx or 10xxxxxx, so 00000000 (i.e. + // 0) will never be part of a multibyte UTF-8 code point. This logic is the + // same as the "r < utf8.RuneSelf" case in strings.IndexRune but can be + // inlined. + // + // https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127 + return strings.IndexByte(cs, 0) == -1 +} + +// appendHeader is the same as bsoncore.AppendHeader but does not check if the +// key is a valid C string since the caller has already checked for that. +// +// The caller of this function must check if key is a valid C string. +func (vw *valueWriter) appendHeader(t bsontype.Type, key string) { + vw.buf = bsoncore.AppendType(vw.buf, t) + vw.buf = append(vw.buf, key...) + vw.buf = append(vw.buf, 0x00) +} + +func (vw *valueWriter) appendIntHeader(t bsontype.Type, key int) { + vw.buf = bsoncore.AppendType(vw.buf, t) + vw.buf = strconv.AppendInt(vw.buf, int64(key), 10) + vw.buf = append(vw.buf, 0x00) } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/writer.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/writer.go index dff65f8..628f452 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/writer.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/writer.go @@ -56,6 +56,8 @@ type ValueWriter interface { } // ValueWriterFlusher is a superset of ValueWriter that exposes functionality to flush to the underlying buffer. +// +// Deprecated: ValueWriterFlusher will not be supported in Go Driver 2.0. type ValueWriterFlusher interface { ValueWriter Flush() error @@ -64,13 +66,20 @@ type ValueWriterFlusher interface { // BytesWriter is the interface used to write BSON bytes to a ValueWriter. // This interface is meant to be a superset of ValueWriter, so that types that // implement ValueWriter may also implement this interface. +// +// Deprecated: BytesWriter will not be supported in Go Driver 2.0. type BytesWriter interface { WriteValueBytes(t bsontype.Type, b []byte) error } // SliceWriter allows a pointer to a slice of bytes to be used as an io.Writer. +// +// Deprecated: SliceWriter will not be supported in Go Driver 2.0. type SliceWriter []byte +// Write writes the bytes to the underlying slice. +// +// Deprecated: SliceWriter will not be supported in Go Driver 2.0. func (sw *SliceWriter) Write(p []byte) (int, error) { written := len(p) *sw = append(*sw, p...) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsontype/bsontype.go b/vendor/go.mongodb.org/mongo-driver/bson/bsontype/bsontype.go index 7c91ae5..255d990 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsontype/bsontype.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsontype/bsontype.go @@ -8,7 +8,9 @@ // a stringifier for the Type to enable easier debugging when working with BSON. package bsontype // import "go.mongodb.org/mongo-driver/bson/bsontype" -// These constants uniquely refer to each BSON type. +// BSON element types as described in https://bsonspec.org/spec.html. +// +// Deprecated: Use bson.Type* constants instead. const ( Double Type = 0x01 String Type = 0x02 @@ -31,7 +33,12 @@ const ( Decimal128 Type = 0x13 MinKey Type = 0xFF MaxKey Type = 0x7F +) +// BSON binary element subtypes as described in https://bsonspec.org/spec.html. +// +// Deprecated: Use the bson.TypeBinary* constants instead. +const ( BinaryGeneric byte = 0x00 BinaryFunction byte = 0x01 BinaryBinaryOld byte = 0x02 @@ -40,6 +47,7 @@ const ( BinaryMD5 byte = 0x05 BinaryEncrypted byte = 0x06 BinaryColumn byte = 0x07 + BinarySensitive byte = 0x08 BinaryUserDefined byte = 0x80 ) @@ -95,3 +103,14 @@ func (bt Type) String() string { return "invalid" } } + +// IsValid will return true if the Type is valid. +func (bt Type) IsValid() bool { + switch bt { + case Double, String, EmbeddedDocument, Array, Binary, Undefined, ObjectID, Boolean, DateTime, Null, Regex, + DBPointer, JavaScript, Symbol, CodeWithScope, Int32, Timestamp, Int64, Decimal128, MinKey, MaxKey: + return true + default: + return false + } +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/decoder.go b/vendor/go.mongodb.org/mongo-driver/bson/decoder.go index 7f6b769..eac74cd 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/decoder.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/decoder.go @@ -33,6 +33,17 @@ var decPool = sync.Pool{ type Decoder struct { dc bsoncodec.DecodeContext vr bsonrw.ValueReader + + // We persist defaultDocumentM and defaultDocumentD on the Decoder to prevent overwriting from + // (*Decoder).SetContext. + defaultDocumentM bool + defaultDocumentD bool + + binaryAsSlice bool + useJSONStructTags bool + useLocalTimeZone bool + zeroMaps bool + zeroStructs bool } // NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. @@ -48,6 +59,9 @@ func NewDecoder(vr bsonrw.ValueReader) (*Decoder, error) { } // NewDecoderWithContext returns a new decoder that uses DecodeContext dc to read from vr. +// +// Deprecated: Use [NewDecoder] and use the Decoder configuration methods set the desired unmarshal +// behavior instead. func NewDecoderWithContext(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (*Decoder, error) { if dc.Registry == nil { dc.Registry = DefaultRegistry @@ -65,8 +79,7 @@ func NewDecoderWithContext(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (* // Decode reads the next BSON document from the stream and decodes it into the // value pointed to by val. // -// The documentation for Unmarshal contains details about of BSON into a Go -// value. +// See [Unmarshal] for details about BSON unmarshaling behavior. func (d *Decoder) Decode(val interface{}) error { if unmarshaler, ok := val.(Unmarshaler); ok { // TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method. @@ -95,24 +108,101 @@ func (d *Decoder) Decode(val interface{}) error { if err != nil { return err } + + if d.defaultDocumentM { + d.dc.DefaultDocumentM() + } + if d.defaultDocumentD { + d.dc.DefaultDocumentD() + } + if d.binaryAsSlice { + d.dc.BinaryAsSlice() + } + if d.useJSONStructTags { + d.dc.UseJSONStructTags() + } + if d.useLocalTimeZone { + d.dc.UseLocalTimeZone() + } + if d.zeroMaps { + d.dc.ZeroMaps() + } + if d.zeroStructs { + d.dc.ZeroStructs() + } + return decoder.DecodeValue(d.dc, d.vr, rval) } // Reset will reset the state of the decoder, using the same *DecodeContext used in // the original construction but using vr for reading. func (d *Decoder) Reset(vr bsonrw.ValueReader) error { + // TODO:(GODRIVER-2719): Remove error return value. d.vr = vr return nil } // SetRegistry replaces the current registry of the decoder with r. func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error { + // TODO:(GODRIVER-2719): Remove error return value. d.dc.Registry = r return nil } // SetContext replaces the current registry of the decoder with dc. +// +// Deprecated: Use the Decoder configuration methods to set the desired unmarshal behavior instead. func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error { + // TODO:(GODRIVER-2719): Remove error return value. d.dc = dc return nil } + +// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +func (d *Decoder) DefaultDocumentM() { + d.defaultDocumentM = true +} + +// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +func (d *Decoder) DefaultDocumentD() { + d.defaultDocumentD = true +} + +// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values +// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct +// field. The truncation logic does not apply to BSON "decimal128" values. +func (d *Decoder) AllowTruncatingDoubles() { + d.dc.Truncate = true +} + +// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or +// "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. +func (d *Decoder) BinaryAsSlice() { + d.binaryAsSlice = true +} + +// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" +// struct tag is not specified. +func (d *Decoder) UseJSONStructTags() { + d.useJSONStructTags = true +} + +// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead +// of the UTC timezone. +func (d *Decoder) UseLocalTimeZone() { + d.useLocalTimeZone = true +} + +// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value +// passed to Decode before unmarshaling BSON documents into them. +func (d *Decoder) ZeroMaps() { + d.zeroMaps = true +} + +// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination +// value passed to Decode before unmarshaling BSON documents into them. +func (d *Decoder) ZeroStructs() { + d.zeroStructs = true +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/doc.go b/vendor/go.mongodb.org/mongo-driver/bson/doc.go index 5e3825a..fb075b4 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/doc.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/doc.go @@ -6,24 +6,26 @@ // Package bson is a library for reading, writing, and manipulating BSON. BSON is a binary serialization format used to // store documents and make remote procedure calls in MongoDB. The BSON specification is located at https://bsonspec.org. -// The BSON library handles marshalling and unmarshalling of values through a configurable codec system. For a description -// of the codec system and examples of registering custom codecs, see the bsoncodec package. +// The BSON library handles marshaling and unmarshaling of values through a configurable codec system. For a description +// of the codec system and examples of registering custom codecs, see the bsoncodec package. For additional information +// and usage examples, check out the [Work with BSON] page in the Go Driver docs site. // -// Raw BSON +// # Raw BSON // // The Raw family of types is used to validate and retrieve elements from a slice of bytes. This // type is most useful when you want do lookups on BSON bytes without unmarshaling it into another // type. // // Example: -// var raw bson.Raw = ... // bytes from somewhere -// err := raw.Validate() -// if err != nil { return err } -// val := raw.Lookup("foo") -// i32, ok := val.Int32OK() -// // do something with i32... // -// Native Go Types +// var raw bson.Raw = ... // bytes from somewhere +// err := raw.Validate() +// if err != nil { return err } +// val := raw.Lookup("foo") +// i32, ok := val.Int32OK() +// // do something with i32... +// +// # Native Go Types // // The D and M types defined in this package can be used to build representations of BSON using native Go types. D is a // slice and M is a map. For more information about the use cases for these types, see the documentation on the type @@ -32,107 +34,109 @@ // Note that a D should not be constructed with duplicate key names, as that can cause undefined server behavior. // // Example: -// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} -// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} -// -// When decoding BSON to a D or M, the following type mappings apply when unmarshalling: -// -// 1. BSON int32 unmarshals to an int32. -// 2. BSON int64 unmarshals to an int64. -// 3. BSON double unmarshals to a float64. -// 4. BSON string unmarshals to a string. -// 5. BSON boolean unmarshals to a bool. -// 6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M). -// 7. BSON array unmarshals to a bson.A. -// 8. BSON ObjectId unmarshals to a primitive.ObjectID. -// 9. BSON datetime unmarshals to a primitive.DateTime. -// 10. BSON binary unmarshals to a primitive.Binary. -// 11. BSON regular expression unmarshals to a primitive.Regex. -// 12. BSON JavaScript unmarshals to a primitive.JavaScript. -// 13. BSON code with scope unmarshals to a primitive.CodeWithScope. -// 14. BSON timestamp unmarshals to an primitive.Timestamp. -// 15. BSON 128-bit decimal unmarshals to an primitive.Decimal128. -// 16. BSON min key unmarshals to an primitive.MinKey. -// 17. BSON max key unmarshals to an primitive.MaxKey. -// 18. BSON undefined unmarshals to a primitive.Undefined. -// 19. BSON null unmarshals to nil. -// 20. BSON DBPointer unmarshals to a primitive.DBPointer. -// 21. BSON symbol unmarshals to a primitive.Symbol. -// -// The above mappings also apply when marshalling a D or M to BSON. Some other useful marshalling mappings are: -// -// 1. time.Time marshals to a BSON datetime. -// 2. int8, int16, and int32 marshal to a BSON int32. -// 3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64 -// otherwise. -// 4. int64 marshals to BSON int64. -// 5. uint8 and uint16 marshal to a BSON int32. -// 6. uint, uint32, and uint64 marshal to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, -// inclusive, and BSON int64 otherwise. -// 7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. unmarshalling a BSON null or -// undefined value into a string will yield the empty string.). -// -// Structs -// -// Structs can be marshalled/unmarshalled to/from BSON or Extended JSON. When transforming structs to/from BSON or Extended +// +// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} +// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} +// +// When decoding BSON to a D or M, the following type mappings apply when unmarshaling: +// +// 1. BSON int32 unmarshals to an int32. +// 2. BSON int64 unmarshals to an int64. +// 3. BSON double unmarshals to a float64. +// 4. BSON string unmarshals to a string. +// 5. BSON boolean unmarshals to a bool. +// 6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M). +// 7. BSON array unmarshals to a bson.A. +// 8. BSON ObjectId unmarshals to a primitive.ObjectID. +// 9. BSON datetime unmarshals to a primitive.DateTime. +// 10. BSON binary unmarshals to a primitive.Binary. +// 11. BSON regular expression unmarshals to a primitive.Regex. +// 12. BSON JavaScript unmarshals to a primitive.JavaScript. +// 13. BSON code with scope unmarshals to a primitive.CodeWithScope. +// 14. BSON timestamp unmarshals to an primitive.Timestamp. +// 15. BSON 128-bit decimal unmarshals to an primitive.Decimal128. +// 16. BSON min key unmarshals to an primitive.MinKey. +// 17. BSON max key unmarshals to an primitive.MaxKey. +// 18. BSON undefined unmarshals to a primitive.Undefined. +// 19. BSON null unmarshals to nil. +// 20. BSON DBPointer unmarshals to a primitive.DBPointer. +// 21. BSON symbol unmarshals to a primitive.Symbol. +// +// The above mappings also apply when marshaling a D or M to BSON. Some other useful marshaling mappings are: +// +// 1. time.Time marshals to a BSON datetime. +// 2. int8, int16, and int32 marshal to a BSON int32. +// 3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64 +// otherwise. +// 4. int64 marshals to BSON int64 (unless [Encoder.IntMinSize] is set). +// 5. uint8 and uint16 marshal to a BSON int32. +// 6. uint, uint32, and uint64 marshal to a BSON int64 (unless [Encoder.IntMinSize] is set). +// 7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. unmarshaling a BSON null or +// undefined value into a string will yield the empty string.). +// +// # Structs +// +// Structs can be marshaled/unmarshaled to/from BSON or Extended JSON. When transforming structs to/from BSON or Extended // JSON, the following rules apply: // -// 1. Only exported fields in structs will be marshalled or unmarshalled. +// 1. Only exported fields in structs will be marshaled or unmarshaled. // -// 2. When marshalling a struct, each field will be lowercased to generate the key for the corresponding BSON element. +// 2. When marshaling a struct, each field will be lowercased to generate the key for the corresponding BSON element. // For example, a struct field named "Foo" will generate key "foo". This can be overridden via a struct tag (e.g. // `bson:"fooField"` to generate key "fooField" instead). // -// 3. An embedded struct field is marshalled as a subdocument. The key will be the lowercased name of the field's type. +// 3. An embedded struct field is marshaled as a subdocument. The key will be the lowercased name of the field's type. // -// 4. A pointer field is marshalled as the underlying type if the pointer is non-nil. If the pointer is nil, it is -// marshalled as a BSON null value. +// 4. A pointer field is marshaled as the underlying type if the pointer is non-nil. If the pointer is nil, it is +// marshaled as a BSON null value. // -// 5. When unmarshalling, a field of type interface{} will follow the D/M type mappings listed above. BSON documents -// unmarshalled into an interface{} field will be unmarshalled as a D. +// 5. When unmarshaling, a field of type interface{} will follow the D/M type mappings listed above. BSON documents +// unmarshaled into an interface{} field will be unmarshaled as a D. // // The encoding of each struct field can be customized by the "bson" struct tag. // // This tag behavior is configurable, and different struct tag behavior can be configured by initializing a new -// bsoncodec.StructCodec with the desired tag parser and registering that StructCodec onto the Registry. By default, JSON tags -// are not honored, but that can be enabled by creating a StructCodec with JSONFallbackStructTagParser, like below: +// bsoncodec.StructCodec with the desired tag parser and registering that StructCodec onto the Registry. By default, JSON +// tags are not honored, but that can be enabled by creating a StructCodec with JSONFallbackStructTagParser, like below: // // Example: -// structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.JSONFallbackStructTagParser) +// +// structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.JSONFallbackStructTagParser) // // The bson tag gives the name of the field, possibly followed by a comma-separated list of options. -// The name may be empty in order to specify options without overriding the default field name. The following options can be used -// to configure behavior: -// -// 1. omitempty: If the omitempty struct tag is specified on a field, the field will not be marshalled if it is set to -// the zero value. Fields with language primitive types such as integers, booleans, and strings are considered empty if -// their value is equal to the zero value for the type (i.e. 0 for integers, false for booleans, and "" for strings). -// Slices, maps, and arrays are considered empty if they are of length zero. Interfaces and pointers are considered -// empty if their value is nil. By default, structs are only considered empty if the struct type implements the -// bsoncodec.Zeroer interface and the IsZero method returns true. Struct fields whose types do not implement Zeroer are -// never considered empty and will be marshalled as embedded documents. -// NOTE: It is recommended that this tag be used for all slice and map fields. -// -// 2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of -// the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For other -// types, this tag is ignored. -// -// 3. truncate: If the truncate struct tag is specified on a field with a non-float numeric type, BSON doubles unmarshalled -// into that field will be truncated at the decimal point. For example, if 3.14 is unmarshalled into a field of type int, -// it will be unmarshalled as 3. If this tag is not specified, the decoder will throw an error if the value cannot be -// decoded without losing precision. For float64 or non-numeric types, this tag is ignored. -// -// 4. inline: If the inline struct tag is specified for a struct or map field, the field will be "flattened" when -// marshalling and "un-flattened" when unmarshalling. This means that all of the fields in that struct/map will be -// pulled up one level and will become top-level fields rather than being fields in a nested document. For example, if a -// map field named "Map" with value map[string]interface{}{"foo": "bar"} is inlined, the resulting document will be -// {"foo": "bar"} instead of {"map": {"foo": "bar"}}. There can only be one inlined map field in a struct. If there are -// duplicated fields in the resulting document when an inlined struct is marshalled, the inlined field will be overwritten. -// If there are duplicated fields in the resulting document when an inlined map is marshalled, an error will be returned. -// This tag can be used with fields that are pointers to structs. If an inlined pointer field is nil, it will not be -// marshalled. For fields that are not maps or structs, this tag is ignored. -// -// Marshalling and Unmarshalling -// -// Manually marshalling and unmarshalling can be done with the Marshal and Unmarshal family of functions. +// The name may be empty in order to specify options without overriding the default field name. The following options can +// be used to configure behavior: +// +// 1. omitempty: If the "omitempty" struct tag is specified on a field, the field will not be marshaled if it is set to +// an "empty" value. Numbers, booleans, and strings are considered empty if their value is equal to the zero value for +// the type (i.e. 0 for numbers, false for booleans, and "" for strings). Slices, maps, and arrays are considered +// empty if they are of length zero. Interfaces and pointers are considered empty if their value is nil. By default, +// structs are only considered empty if the struct type implements [bsoncodec.Zeroer] and the IsZero +// method returns true. Struct types that do not implement [bsoncodec.Zeroer] are never considered empty and will be +// marshaled as embedded documents. NOTE: It is recommended that this tag be used for all slice and map fields. +// +// 2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of +// the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For +// other types, this tag is ignored. +// +// 3. truncate: If the truncate struct tag is specified on a field with a non-float numeric type, BSON doubles +// unmarshaled into that field will be truncated at the decimal point. For example, if 3.14 is unmarshaled into a +// field of type int, it will be unmarshaled as 3. If this tag is not specified, the decoder will throw an error if +// the value cannot be decoded without losing precision. For float64 or non-numeric types, this tag is ignored. +// +// 4. inline: If the inline struct tag is specified for a struct or map field, the field will be "flattened" when +// marshaling and "un-flattened" when unmarshaling. This means that all of the fields in that struct/map will be +// pulled up one level and will become top-level fields rather than being fields in a nested document. For example, +// if a map field named "Map" with value map[string]interface{}{"foo": "bar"} is inlined, the resulting document will +// be {"foo": "bar"} instead of {"map": {"foo": "bar"}}. There can only be one inlined map field in a struct. If +// there are duplicated fields in the resulting document when an inlined struct is marshaled, the inlined field will +// be overwritten. If there are duplicated fields in the resulting document when an inlined map is marshaled, an +// error will be returned. This tag can be used with fields that are pointers to structs. If an inlined pointer field +// is nil, it will not be marshaled. For fields that are not maps or structs, this tag is ignored. +// +// # Marshaling and Unmarshaling +// +// Manually marshaling and unmarshaling can be done with the Marshal and Unmarshal family of functions. +// +// [Work with BSON]: https://www.mongodb.com/docs/drivers/go/current/fundamentals/bson/ package bson diff --git a/vendor/go.mongodb.org/mongo-driver/bson/encoder.go b/vendor/go.mongodb.org/mongo-driver/bson/encoder.go index fe5125d..0be2a97 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/encoder.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/encoder.go @@ -29,10 +29,20 @@ var encPool = sync.Pool{ type Encoder struct { ec bsoncodec.EncodeContext vw bsonrw.ValueWriter + + errorOnInlineDuplicates bool + intMinSize bool + stringifyMapKeysWithFmt bool + nilMapAsEmpty bool + nilSliceAsEmpty bool + nilByteSliceAsEmpty bool + omitZeroStruct bool + useJSONStructTags bool } // NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) { + // TODO:(GODRIVER-2719): Remove error return value. if vw == nil { return nil, errors.New("cannot create a new Encoder with a nil ValueWriter") } @@ -44,6 +54,9 @@ func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) { } // NewEncoderWithContext returns a new encoder that uses EncodeContext ec to write to vw. +// +// Deprecated: Use [NewEncoder] and use the Encoder configuration methods to set the desired marshal +// behavior instead. func NewEncoderWithContext(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter) (*Encoder, error) { if ec.Registry == nil { ec = bsoncodec.EncodeContext{Registry: DefaultRegistry} @@ -60,8 +73,7 @@ func NewEncoderWithContext(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter) (* // Encode writes the BSON encoding of val to the stream. // -// The documentation for Marshal contains details about the conversion of Go -// values to BSON. +// See [Marshal] for details about BSON marshaling behavior. func (e *Encoder) Encode(val interface{}) error { if marshaler, ok := val.(Marshaler); ok { // TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse? @@ -76,24 +88,112 @@ func (e *Encoder) Encode(val interface{}) error { if err != nil { return err } + + // Copy the configurations applied to the Encoder over to the EncodeContext, which actually + // communicates those configurations to the default ValueEncoders. + if e.errorOnInlineDuplicates { + e.ec.ErrorOnInlineDuplicates() + } + if e.intMinSize { + e.ec.MinSize = true + } + if e.stringifyMapKeysWithFmt { + e.ec.StringifyMapKeysWithFmt() + } + if e.nilMapAsEmpty { + e.ec.NilMapAsEmpty() + } + if e.nilSliceAsEmpty { + e.ec.NilSliceAsEmpty() + } + if e.nilByteSliceAsEmpty { + e.ec.NilByteSliceAsEmpty() + } + if e.omitZeroStruct { + e.ec.OmitZeroStruct() + } + if e.useJSONStructTags { + e.ec.UseJSONStructTags() + } + return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val)) } -// Reset will reset the state of the encoder, using the same *EncodeContext used in +// Reset will reset the state of the Encoder, using the same *EncodeContext used in // the original construction but using vw. func (e *Encoder) Reset(vw bsonrw.ValueWriter) error { + // TODO:(GODRIVER-2719): Remove error return value. e.vw = vw return nil } -// SetRegistry replaces the current registry of the encoder with r. +// SetRegistry replaces the current registry of the Encoder with r. func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error { + // TODO:(GODRIVER-2719): Remove error return value. e.ec.Registry = r return nil } -// SetContext replaces the current EncodeContext of the encoder with er. +// SetContext replaces the current EncodeContext of the encoder with ec. +// +// Deprecated: Use the Encoder configuration methods set the desired marshal behavior instead. func (e *Encoder) SetContext(ec bsoncodec.EncodeContext) error { + // TODO:(GODRIVER-2719): Remove error return value. e.ec = ec return nil } + +// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in +// the marshaled BSON when the "inline" struct tag option is set. +func (e *Encoder) ErrorOnInlineDuplicates() { + e.errorOnInlineDuplicates = true +} + +// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, +// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can +// represent the integer value. +func (e *Encoder) IntMinSize() { + e.intMinSize = true +} + +// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name +// strings using fmt.Sprint instead of the default string conversion logic. +func (e *Encoder) StringifyMapKeysWithFmt() { + e.stringifyMapKeysWithFmt = true +} + +// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON +// null. +func (e *Encoder) NilMapAsEmpty() { + e.nilMapAsEmpty = true +} + +// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON +// null. +func (e *Encoder) NilSliceAsEmpty() { + e.nilSliceAsEmpty = true +} + +// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values +// instead of BSON null. +func (e *Encoder) NilByteSliceAsEmpty() { + e.nilByteSliceAsEmpty = true +} + +// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported +// TODO struct fields once the logic is updated to also inspect private struct fields. + +// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) +// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. +// +// Note that the Encoder only examines exported struct fields when determining if a struct is the +// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. +func (e *Encoder) OmitZeroStruct() { + e.omitZeroStruct = true +} + +// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" +// struct tag is not specified. +func (e *Encoder) UseJSONStructTags() { + e.useJSONStructTags = true +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/marshal.go b/vendor/go.mongodb.org/mongo-driver/bson/marshal.go index db8d8ee..17ce669 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/marshal.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/marshal.go @@ -9,6 +9,7 @@ package bson import ( "bytes" "encoding/json" + "sync" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -20,17 +21,23 @@ const defaultDstCap = 256 var bvwPool = bsonrw.NewBSONValueWriterPool() var extjPool = bsonrw.NewExtJSONValueWriterPool() -// Marshaler is an interface implemented by types that can marshal themselves -// into a BSON document represented as bytes. The bytes returned must be a valid -// BSON document if the error is nil. +// Marshaler is the interface implemented by types that can marshal themselves +// into a valid BSON document. +// +// Implementations of Marshaler must return a full BSON document. To create +// custom BSON marshaling behavior for individual values in a BSON document, +// implement the ValueMarshaler interface instead. type Marshaler interface { MarshalBSON() ([]byte, error) } -// ValueMarshaler is an interface implemented by types that can marshal -// themselves into a BSON value as bytes. The type must be the valid type for -// the bytes returned. The bytes and byte type together must be valid if the -// error is nil. +// ValueMarshaler is the interface implemented by types that can marshal +// themselves into a valid BSON value. The format of the returned bytes must +// match the returned type. +// +// Implementations of ValueMarshaler must return an individual BSON value. To +// create custom BSON marshaling behavior for an entire BSON document, implement +// the Marshaler interface instead. type ValueMarshaler interface { MarshalBSONValue() (bsontype.Type, []byte, error) } @@ -48,12 +55,42 @@ func Marshal(val interface{}) ([]byte, error) { // MarshalAppend will encode val as a BSON document and append the bytes to dst. If dst is not large enough to hold the // bytes, it will be grown. If val is not a type that can be transformed into a document, MarshalValueAppend should be // used instead. +// +// Deprecated: Use [NewEncoder] and pass the dst byte slice (wrapped by a bytes.Buffer) into +// [bsonrw.NewBSONValueWriter]: +// +// buf := bytes.NewBuffer(dst) +// vw, err := bsonrw.NewBSONValueWriter(buf) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// +// See [Encoder] for more examples. func MarshalAppend(dst []byte, val interface{}) ([]byte, error) { return MarshalAppendWithRegistry(DefaultRegistry, dst, val) } // MarshalWithRegistry returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed // into a document, MarshalValueWithRegistry should be used instead. +// +// Deprecated: Use [NewEncoder] and specify the Registry by calling [Encoder.SetRegistry] instead: +// +// buf := new(bytes.Buffer) +// vw, err := bsonrw.NewBSONValueWriter(buf) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// enc.SetRegistry(reg) +// +// See [Encoder] for more examples. func MarshalWithRegistry(r *bsoncodec.Registry, val interface{}) ([]byte, error) { dst := make([]byte, 0) return MarshalAppendWithRegistry(r, dst, val) @@ -61,6 +98,22 @@ func MarshalWithRegistry(r *bsoncodec.Registry, val interface{}) ([]byte, error) // MarshalWithContext returns the BSON encoding of val as a BSON document using EncodeContext ec. If val is not a type // that can be transformed into a document, MarshalValueWithContext should be used instead. +// +// Deprecated: Use [NewEncoder] and use the Encoder configuration methods to set the desired marshal +// behavior instead: +// +// buf := bytes.NewBuffer(dst) +// vw, err := bsonrw.NewBSONValueWriter(buf) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// enc.IntMinSize() +// +// See [Encoder] for more examples. func MarshalWithContext(ec bsoncodec.EncodeContext, val interface{}) ([]byte, error) { dst := make([]byte, 0) return MarshalAppendWithContext(ec, dst, val) @@ -69,16 +122,74 @@ func MarshalWithContext(ec bsoncodec.EncodeContext, val interface{}) ([]byte, er // MarshalAppendWithRegistry will encode val as a BSON document using Registry r and append the bytes to dst. If dst is // not large enough to hold the bytes, it will be grown. If val is not a type that can be transformed into a document, // MarshalValueAppendWithRegistry should be used instead. +// +// Deprecated: Use [NewEncoder], and pass the dst byte slice (wrapped by a bytes.Buffer) into +// [bsonrw.NewBSONValueWriter], and specify the Registry by calling [Encoder.SetRegistry] instead: +// +// buf := bytes.NewBuffer(dst) +// vw, err := bsonrw.NewBSONValueWriter(buf) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// enc.SetRegistry(reg) +// +// See [Encoder] for more examples. func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) ([]byte, error) { return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) } +// Pool of buffers for marshalling BSON. +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + // MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the // bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be // transformed into a document, MarshalValueAppendWithContext should be used instead. +// +// Deprecated: Use [NewEncoder], pass the dst byte slice (wrapped by a bytes.Buffer) into +// [bsonrw.NewBSONValueWriter], and use the Encoder configuration methods to set the desired marshal +// behavior instead: +// +// buf := bytes.NewBuffer(dst) +// vw, err := bsonrw.NewBSONValueWriter(buf) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// enc.IntMinSize() +// +// See [Encoder] for more examples. func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) { - sw := new(bsonrw.SliceWriter) - *sw = dst + sw := bufPool.Get().(*bytes.Buffer) + defer func() { + // Proper usage of a sync.Pool requires each entry to have approximately + // the same memory cost. To obtain this property when the stored type + // contains a variably-sized buffer, we add a hard limit on the maximum + // buffer to place back in the pool. We limit the size to 16MiB because + // that's the maximum wire message size supported by any current MongoDB + // server. + // + // Comment based on + // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147 + // + // Recycle byte slices that are smaller than 16MiB and at least half + // occupied. + if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() { + bufPool.Put(sw) + } + }() + + sw.Reset() vw := bvwPool.Get(sw) defer bvwPool.Put(vw) @@ -99,7 +210,7 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf return nil, err } - return *sw, nil + return append(dst, sw.Bytes()...), nil } // MarshalValue returns the BSON encoding of val. @@ -112,17 +223,26 @@ func MarshalValue(val interface{}) (bsontype.Type, []byte, error) { // MarshalValueAppend will append the BSON encoding of val to dst. If dst is not large enough to hold the BSON encoding // of val, dst will be grown. +// +// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go +// Driver 2.0. func MarshalValueAppend(dst []byte, val interface{}) (bsontype.Type, []byte, error) { return MarshalValueAppendWithRegistry(DefaultRegistry, dst, val) } // MarshalValueWithRegistry returns the BSON encoding of val using Registry r. +// +// Deprecated: Using a custom registry to marshal individual BSON values will not be supported in Go +// Driver 2.0. func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) { dst := make([]byte, 0) return MarshalValueAppendWithRegistry(r, dst, val) } // MarshalValueWithContext returns the BSON encoding of val using EncodeContext ec. +// +// Deprecated: Using a custom EncodeContext to marshal individual BSON elements will not be +// supported in Go Driver 2.0. func MarshalValueWithContext(ec bsoncodec.EncodeContext, val interface{}) (bsontype.Type, []byte, error) { dst := make([]byte, 0) return MarshalValueAppendWithContext(ec, dst, val) @@ -130,12 +250,18 @@ func MarshalValueWithContext(ec bsoncodec.EncodeContext, val interface{}) (bsont // MarshalValueAppendWithRegistry will append the BSON encoding of val to dst using Registry r. If dst is not large // enough to hold the BSON encoding of val, dst will be grown. +// +// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go +// Driver 2.0. func MarshalValueAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) (bsontype.Type, []byte, error) { return MarshalValueAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) } // MarshalValueAppendWithContext will append the BSON encoding of val to dst using EncodeContext ec. If dst is not large // enough to hold the BSON encoding of val, dst will be grown. +// +// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go +// Driver 2.0. func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) (bsontype.Type, []byte, error) { // get a ValueWriter configured to write to dst sw := new(bsonrw.SliceWriter) @@ -173,17 +299,63 @@ func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) // MarshalExtJSONAppend will append the extended JSON encoding of val to dst. // If dst is not large enough to hold the extended JSON encoding of val, dst // will be grown. +// +// Deprecated: Use [NewEncoder] and pass the dst byte slice (wrapped by a bytes.Buffer) into +// [bsonrw.NewExtJSONValueWriter] instead: +// +// buf := bytes.NewBuffer(dst) +// vw, err := bsonrw.NewExtJSONValueWriter(buf, true, false) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// +// See [Encoder] for more examples. func MarshalExtJSONAppend(dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { return MarshalExtJSONAppendWithRegistry(DefaultRegistry, dst, val, canonical, escapeHTML) } // MarshalExtJSONWithRegistry returns the extended JSON encoding of val using Registry r. +// +// Deprecated: Use [NewEncoder] and specify the Registry by calling [Encoder.SetRegistry] instead: +// +// buf := new(bytes.Buffer) +// vw, err := bsonrw.NewBSONValueWriter(buf) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// enc.SetRegistry(reg) +// +// See [Encoder] for more examples. func MarshalExtJSONWithRegistry(r *bsoncodec.Registry, val interface{}, canonical, escapeHTML bool) ([]byte, error) { dst := make([]byte, 0, defaultDstCap) return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML) } // MarshalExtJSONWithContext returns the extended JSON encoding of val using Registry r. +// +// Deprecated: Use [NewEncoder] and use the Encoder configuration methods to set the desired marshal +// behavior instead: +// +// buf := new(bytes.Buffer) +// vw, err := bsonrw.NewBSONValueWriter(buf) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// enc.IntMinSize() +// +// See [Encoder] for more examples. func MarshalExtJSONWithContext(ec bsoncodec.EncodeContext, val interface{}, canonical, escapeHTML bool) ([]byte, error) { dst := make([]byte, 0, defaultDstCap) return MarshalExtJSONAppendWithContext(ec, dst, val, canonical, escapeHTML) @@ -192,6 +364,22 @@ func MarshalExtJSONWithContext(ec bsoncodec.EncodeContext, val interface{}, cano // MarshalExtJSONAppendWithRegistry will append the extended JSON encoding of // val to dst using Registry r. If dst is not large enough to hold the BSON // encoding of val, dst will be grown. +// +// Deprecated: Use [NewEncoder], pass the dst byte slice (wrapped by a bytes.Buffer) into +// [bsonrw.NewExtJSONValueWriter], and specify the Registry by calling [Encoder.SetRegistry] +// instead: +// +// buf := bytes.NewBuffer(dst) +// vw, err := bsonrw.NewExtJSONValueWriter(buf, true, false) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// +// See [Encoder] for more examples. func MarshalExtJSONAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML) } @@ -199,6 +387,23 @@ func MarshalExtJSONAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val int // MarshalExtJSONAppendWithContext will append the extended JSON encoding of // val to dst using Registry r. If dst is not large enough to hold the BSON // encoding of val, dst will be grown. +// +// Deprecated: Use [NewEncoder], pass the dst byte slice (wrapped by a bytes.Buffer) into +// [bsonrw.NewExtJSONValueWriter], and use the Encoder configuration methods to set the desired marshal +// behavior instead: +// +// buf := bytes.NewBuffer(dst) +// vw, err := bsonrw.NewExtJSONValueWriter(buf, true, false) +// if err != nil { +// panic(err) +// } +// enc, err := bson.NewEncoder(vw) +// if err != nil { +// panic(err) +// } +// enc.IntMinSize() +// +// See [Encoder] for more examples. func MarshalExtJSONAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { sw := new(bsonrw.SliceWriter) *sw = dst diff --git a/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go b/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go index ffe4eed..db8be74 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go @@ -70,7 +70,6 @@ func (d Decimal128) String() string { // Bits: 1*sign 2*ignored 14*exponent 111*significand. // Implicit 0b100 prefix in significand. exp = int(d.h >> 47 & (1<<14 - 1)) - //high = 4<<47 | d.h&(1<<47-1) // Spec says all of these values are out of range. high, low = 0, 0 } else { @@ -152,21 +151,17 @@ func (d Decimal128) BigInt() (*big.Int, int, error) { // Bits: 1*sign 2*ignored 14*exponent 111*significand. // Implicit 0b100 prefix in significand. exp = int(high >> 47 & (1<<14 - 1)) - //high = 4<<47 | d.h&(1<<47-1) // Spec says all of these values are out of range. high, low = 0, 0 } else { // Bits: 1*sign 14*exponent 113*significand exp = int(high >> 49 & (1<<14 - 1)) - high = high & (1<<49 - 1) + high &= (1<<49 - 1) } exp += MinDecimal128Exp // Would be handled by the logic below, but that's trivial and common. if high == 0 && low == 0 && exp == 0 { - if posSign { - return new(big.Int), 0, nil - } return new(big.Int), 0, nil } @@ -191,10 +186,9 @@ func (d Decimal128) IsNaN() bool { // IsInf returns: // -// +1 d == Infinity -// 0 other case -// -1 d == -Infinity -// +// +1 d == Infinity +// 0 other case +// -1 d == -Infinity func (d Decimal128) IsInf() int { if d.h>>58&(1<<5-1) != 0x1E { return 0 @@ -329,6 +323,7 @@ func ParseDecimal128(s string) (Decimal128, error) { return dErr(s) } + // Parse the significand (i.e. the non-exponent part) as a big.Int. bi, ok := new(big.Int).SetString(intPart+decPart, 10) if !ok { return dErr(s) @@ -355,12 +350,25 @@ var ( // ParseDecimal128FromBigInt attempts to parse the given significand and exponent into a valid Decimal128 value. func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) { - //copy + // copy bi = new(big.Int).Set(bi) q := new(big.Int) r := new(big.Int) + // If the significand is zero, the logical value will always be zero, independent of the + // exponent. However, the loops for handling out-of-range exponent values below may be extremely + // slow for zero values because the significand never changes. Limit the exponent value to the + // supported range here to prevent entering the loops below. + if bi.Cmp(zero) == 0 { + if exp > MaxDecimal128Exp { + exp = MaxDecimal128Exp + } + if exp < MinDecimal128Exp { + exp = MinDecimal128Exp + } + } + for bigIntCmpAbs(bi, maxS) == 1 { bi, _ = q.QuoRem(bi, ten, r) if r.Cmp(zero) != 0 { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go b/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go index 652898f..c130e3f 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go @@ -61,7 +61,9 @@ func (id ObjectID) Timestamp() time.Time { // Hex returns the hex encoding of the ObjectID as a string. func (id ObjectID) Hex() string { - return hex.EncodeToString(id[:]) + var buf [24]byte + hex.Encode(buf[:], id[:]) + return string(buf[:]) } func (id ObjectID) String() string { @@ -80,18 +82,18 @@ func ObjectIDFromHex(s string) (ObjectID, error) { return NilObjectID, ErrInvalidHex } - b, err := hex.DecodeString(s) + var oid [12]byte + _, err := hex.Decode(oid[:], []byte(s)) if err != nil { return NilObjectID, err } - var oid [12]byte - copy(oid[:], b) - return oid, nil } // IsValidObjectID returns true if the provided hex string represents a valid ObjectID and false if not. +// +// Deprecated: Use ObjectIDFromHex and check the error instead. func IsValidObjectID(s string) bool { _, err := ObjectIDFromHex(s) return err == nil @@ -181,7 +183,7 @@ func processUniqueBytes() [5]byte { var b [5]byte _, err := io.ReadFull(rand.Reader, b[:]) if err != nil { - panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err)) + panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err)) } return b @@ -191,7 +193,7 @@ func readRandomUint32() uint32 { var b [4]byte _, err := io.ReadFull(rand.Reader, b[:]) if err != nil { - panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err)) + panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err)) } return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/primitive/primitive.go b/vendor/go.mongodb.org/mongo-driver/bson/primitive/primitive.go index b3cba1b..65f4fbb 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/primitive/primitive.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/primitive/primitive.go @@ -45,7 +45,7 @@ var _ json.Unmarshaler = (*DateTime)(nil) // MarshalJSON marshal to time type. func (d DateTime) MarshalJSON() ([]byte, error) { - return json.Marshal(d.Time()) + return json.Marshal(d.Time().UTC()) } // UnmarshalJSON creates a primitive.DateTime from a JSON string. @@ -141,6 +141,16 @@ type Timestamp struct { I uint32 } +// After reports whether the time instant tp is after tp2. +func (tp Timestamp) After(tp2 Timestamp) bool { + return tp.T > tp2.T || (tp.T == tp2.T && tp.I > tp2.I) +} + +// Before reports whether the time instant tp is before tp2. +func (tp Timestamp) Before(tp2 Timestamp) bool { + return tp.T < tp2.T || (tp.T == tp2.T && tp.I < tp2.I) +} + // Equal compares tp to tp2 and returns true if they are equal. func (tp Timestamp) Equal(tp2 Timestamp) bool { return tp.T == tp2.T && tp.I == tp2.I @@ -151,24 +161,25 @@ func (tp Timestamp) IsZero() bool { return tp.T == 0 && tp.I == 0 } -// CompareTimestamp returns an integer comparing two Timestamps, where T is compared first, followed by I. -// Returns 0 if tp = tp2, 1 if tp > tp2, -1 if tp < tp2. -func CompareTimestamp(tp, tp2 Timestamp) int { - if tp.Equal(tp2) { +// Compare compares the time instant tp with tp2. If tp is before tp2, it returns -1; if tp is after +// tp2, it returns +1; if they're the same, it returns 0. +func (tp Timestamp) Compare(tp2 Timestamp) int { + switch { + case tp.Equal(tp2): return 0 - } - - if tp.T > tp2.T { - return 1 - } - if tp.T < tp2.T { + case tp.Before(tp2): return -1 + default: + return +1 } - // Compare I values because T values are equal - if tp.I > tp2.I { - return 1 - } - return -1 +} + +// CompareTimestamp compares the time instant tp with tp2. If tp is before tp2, it returns -1; if tp is after +// tp2, it returns +1; if they're the same, it returns 0. +// +// Deprecated: Use Timestamp.Compare instead. +func CompareTimestamp(tp, tp2 Timestamp) int { + return tp.Compare(tp2) } // MinKey represents the BSON minkey value. @@ -182,10 +193,13 @@ type MaxKey struct{} // // Example usage: // -// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} +// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} type D []E // Map creates a map from the elements of the D. +// +// Deprecated: Converting directly from a D to an M will not be supported in Go Driver 2.0. Instead, +// users should marshal the D to BSON using bson.Marshal and unmarshal it to M using bson.Unmarshal. func (d D) Map() M { m := make(M, len(d)) for _, e := range d { @@ -206,12 +220,12 @@ type E struct { // // Example usage: // -// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} +// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} type M map[string]interface{} // An A is an ordered representation of a BSON array. // // Example usage: // -// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}} +// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}} type A []interface{} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/primitive_codecs.go b/vendor/go.mongodb.org/mongo-driver/bson/primitive_codecs.go index 1cbe388..ff32a87 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/primitive_codecs.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/primitive_codecs.go @@ -8,6 +8,7 @@ package bson import ( "errors" + "fmt" "reflect" "go.mongodb.org/mongo-driver/bson/bsoncodec" @@ -21,10 +22,16 @@ var primitiveCodecs PrimitiveCodecs // PrimitiveCodecs is a namespace for all of the default bsoncodec.Codecs for the primitive types // defined in this package. +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders +// registered. type PrimitiveCodecs struct{} // RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs // with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created. +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders +// registered. func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) { if rb == nil { panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) @@ -38,18 +45,35 @@ func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) } // RawValueEncodeValue is the ValueEncoderFunc for RawValue. -func (PrimitiveCodecs) RawValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// If the RawValue's Type is "invalid" and the RawValue's Value is not empty or +// nil, then this method will return an error. +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive +// encoders and decoders registered. +func (PrimitiveCodecs) RawValueEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { - return bsoncodec.ValueEncoderError{Name: "RawValueEncodeValue", Types: []reflect.Type{tRawValue}, Received: val} + return bsoncodec.ValueEncoderError{ + Name: "RawValueEncodeValue", + Types: []reflect.Type{tRawValue}, + Received: val, + } } rawvalue := val.Interface().(RawValue) + if !rawvalue.Type.IsValid() { + return fmt.Errorf("the RawValue Type specifies an invalid BSON type: %#x", byte(rawvalue.Type)) + } + return bsonrw.Copier{}.CopyValueFromBytes(vw, rawvalue.Type, rawvalue.Value) } // RawValueDecodeValue is the ValueDecoderFunc for RawValue. -func (PrimitiveCodecs) RawValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders +// registered. +func (PrimitiveCodecs) RawValueDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRawValue { return bsoncodec.ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val} } @@ -64,7 +88,10 @@ func (PrimitiveCodecs) RawValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw } // RawEncodeValue is the ValueEncoderFunc for Reader. -func (PrimitiveCodecs) RawEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders +// registered. +func (PrimitiveCodecs) RawEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRaw { return bsoncodec.ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val} } @@ -75,7 +102,10 @@ func (PrimitiveCodecs) RawEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.Valu } // RawDecodeValue is the ValueDecoderFunc for Reader. -func (PrimitiveCodecs) RawDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders +// registered. +func (PrimitiveCodecs) RawDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRaw { return bsoncodec.ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/raw.go b/vendor/go.mongodb.org/mongo-driver/bson/raw.go index efd705d..130da61 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/raw.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/raw.go @@ -16,18 +16,27 @@ import ( // ErrNilReader indicates that an operation was attempted on a nil bson.Reader. var ErrNilReader = errors.New("nil reader") -// Raw is a wrapper around a byte slice. It will interpret the slice as a -// BSON document. This type is a wrapper around a bsoncore.Document. Errors returned from the -// methods on this type and associated types come from the bsoncore package. +// Raw is a raw encoded BSON document. It can be used to delay BSON document decoding or precompute +// a BSON encoded document. +// +// A Raw must be a full BSON document. Use the RawValue type for individual BSON values. type Raw []byte -// NewFromIOReader reads in a document from the given io.Reader and constructs a Raw from -// it. -func NewFromIOReader(r io.Reader) (Raw, error) { +// ReadDocument reads a BSON document from the io.Reader and returns it as a bson.Raw. If the +// reader contains multiple BSON documents, only the first document is read. +func ReadDocument(r io.Reader) (Raw, error) { doc, err := bsoncore.NewDocumentFromReader(r) return Raw(doc), err } +// NewFromIOReader reads a BSON document from the io.Reader and returns it as a bson.Raw. If the +// reader contains multiple BSON documents, only the first document is read. +// +// Deprecated: Use ReadDocument instead. +func NewFromIOReader(r io.Reader) (Raw, error) { + return ReadDocument(r) +} + // Validate validates the document. This method only validates the first document in // the slice, to validate other documents, the slice must be resliced. func (r Raw) Validate() (err error) { return bsoncore.Document(r).Validate() } @@ -51,12 +60,19 @@ func (r Raw) LookupErr(key ...string) (RawValue, error) { // elements. If the document is not valid, the elements up to the invalid point will be returned // along with an error. func (r Raw) Elements() ([]RawElement, error) { - elems, err := bsoncore.Document(r).Elements() + doc := bsoncore.Document(r) + if len(doc) == 0 { + return nil, nil + } + elems, err := doc.Elements() + if err != nil { + return nil, err + } relems := make([]RawElement, 0, len(elems)) for _, elem := range elems { relems = append(relems, RawElement(elem)) } - return relems, err + return relems, nil } // Values returns this document as a slice of values. The returned slice will contain valid values. @@ -81,5 +97,5 @@ func (r Raw) IndexErr(index uint) (RawElement, error) { return RawElement(elem), err } -// String implements the fmt.Stringer interface. +// String returns the BSON document encoded as Extended JSON. func (r Raw) String() string { return bsoncore.Document(r).String() } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/raw_element.go b/vendor/go.mongodb.org/mongo-driver/bson/raw_element.go index 006f503..8ce13c2 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/raw_element.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/raw_element.go @@ -10,10 +10,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -// RawElement represents a BSON element in byte form. This type provides a simple way to -// transform a slice of bytes into a BSON element and extract information from it. -// -// RawElement is a thin wrapper around a bsoncore.Element. +// RawElement is a raw encoded BSON document or array element. type RawElement []byte // Key returns the key for this element. If the element is not valid, this method returns an empty @@ -36,7 +33,7 @@ func (re RawElement) ValueErr() (RawValue, error) { // Validate ensures re is a valid BSON element. func (re RawElement) Validate() error { return bsoncore.Element(re).Validate() } -// String implements the fmt.Stringer interface. The output will be in extended JSON format. +// String returns the BSON element encoded as Extended JSON. func (re RawElement) String() string { doc := bsoncore.BuildDocument(nil, re) j, err := MarshalExtJSON(Raw(doc), true, false) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go b/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go index 75297f3..a8088e1 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go @@ -26,11 +26,10 @@ var ErrNilContext = errors.New("DecodeContext cannot be nil") // ErrNilRegistry is returned when the provided registry is nil. var ErrNilRegistry = errors.New("Registry cannot be nil") -// RawValue represents a BSON value in byte form. It can be used to hold unprocessed BSON or to -// defer processing of BSON. Type is the BSON type of the value and Value are the raw bytes that -// represent the element. +// RawValue is a raw encoded BSON value. It can be used to delay BSON value decoding or precompute +// BSON encoded value. Type is the BSON type of the value and Value is the raw encoded BSON value. // -// This type wraps bsoncore.Value for most of it's functionality. +// A RawValue must be an individual BSON value. Use the Raw type for full BSON documents. type RawValue struct { Type bsontype.Type Value []byte @@ -38,6 +37,12 @@ type RawValue struct { r *bsoncodec.Registry } +// IsZero reports whether the RawValue is zero, i.e. no data is present on +// the RawValue. It returns true if Type is 0 and Value is empty or nil. +func (rv RawValue) IsZero() bool { + return rv.Type == 0x00 && len(rv.Value) == 0 +} + // Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an // error is returned. This method will use the registry used to create the RawValue, if the RawValue // was created from partial BSON processing, or it will use the default registry. Users wishing to @@ -83,8 +88,12 @@ func (rv RawValue) UnmarshalWithRegistry(r *bsoncodec.Registry, val interface{}) return dec.DecodeValue(bsoncodec.DecodeContext{Registry: r}, vr, rval) } -// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext -// instead of the one attached or the default registry. +// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses +// the provided DecodeContext instead of the one attached or the default +// registry. +// +// Deprecated: Use [RawValue.UnmarshalWithRegistry] with a custom registry to customize +// unmarshal behavior instead. func (rv RawValue) UnmarshalWithContext(dc *bsoncodec.DecodeContext, val interface{}) error { if dc == nil { return ErrNilContext @@ -268,10 +277,16 @@ func (rv RawValue) Int32OK() (int32, bool) { return convertToCoreValue(rv).Int32 // AsInt32 returns a BSON number as an int32. If the BSON type is not a numeric one, this method // will panic. +// +// Deprecated: Use AsInt64 instead. If an int32 is required, convert the returned value to an int32 +// and perform any required overflow/underflow checking. func (rv RawValue) AsInt32() int32 { return convertToCoreValue(rv).AsInt32() } // AsInt32OK is the same as AsInt32, except that it returns a boolean instead of // panicking. +// +// Deprecated: Use AsInt64OK instead. If an int32 is required, convert the returned value to an +// int32 and perform any required overflow/underflow checking. func (rv RawValue) AsInt32OK() (int32, bool) { return convertToCoreValue(rv).AsInt32OK() } // Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a diff --git a/vendor/go.mongodb.org/mongo-driver/bson/registry.go b/vendor/go.mongodb.org/mongo-driver/bson/registry.go index 16d7573..d6afb28 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/registry.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/registry.go @@ -6,15 +6,31 @@ package bson -import "go.mongodb.org/mongo-driver/bson/bsoncodec" +import ( + "go.mongodb.org/mongo-driver/bson/bsoncodec" +) -// DefaultRegistry is the default bsoncodec.Registry. It contains the default codecs and the -// primitive codecs. -var DefaultRegistry = NewRegistryBuilder().Build() +// DefaultRegistry is the default bsoncodec.Registry. It contains the default +// codecs and the primitive codecs. +// +// Deprecated: Use [NewRegistry] to construct a new default registry. To use a +// custom registry when marshaling or unmarshaling, use the "SetRegistry" method +// on an [Encoder] or [Decoder] instead: +// +// dec, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data)) +// if err != nil { +// panic(err) +// } +// dec.SetRegistry(reg) +// +// See [Encoder] and [Decoder] for more examples. +var DefaultRegistry = NewRegistry() // NewRegistryBuilder creates a new RegistryBuilder configured with the default encoders and // decoders from the bsoncodec.DefaultValueEncoders and bsoncodec.DefaultValueDecoders types and the // PrimitiveCodecs type in this package. +// +// Deprecated: Use [NewRegistry] instead. func NewRegistryBuilder() *bsoncodec.RegistryBuilder { rb := bsoncodec.NewRegistryBuilder() bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) @@ -22,3 +38,10 @@ func NewRegistryBuilder() *bsoncodec.RegistryBuilder { primitiveCodecs.RegisterPrimitiveCodecs(rb) return rb } + +// NewRegistry creates a new Registry configured with the default encoders and decoders from the +// bsoncodec.DefaultValueEncoders and bsoncodec.DefaultValueDecoders types and the PrimitiveCodecs +// type in this package. +func NewRegistry() *bsoncodec.Registry { + return NewRegistryBuilder().Build() +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/types.go b/vendor/go.mongodb.org/mongo-driver/bson/types.go index 13a1c35..ef39812 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/types.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/types.go @@ -10,7 +10,7 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -// These constants uniquely refer to each BSON type. +// BSON element types as described in https://bsonspec.org/spec.html. const ( TypeDouble = bsontype.Double TypeString = bsontype.String @@ -34,3 +34,17 @@ const ( TypeMinKey = bsontype.MinKey TypeMaxKey = bsontype.MaxKey ) + +// BSON binary element subtypes as described in https://bsonspec.org/spec.html. +const ( + TypeBinaryGeneric = bsontype.BinaryGeneric + TypeBinaryFunction = bsontype.BinaryFunction + TypeBinaryBinaryOld = bsontype.BinaryBinaryOld + TypeBinaryUUIDOld = bsontype.BinaryUUIDOld + TypeBinaryUUID = bsontype.BinaryUUID + TypeBinaryMD5 = bsontype.BinaryMD5 + TypeBinaryEncrypted = bsontype.BinaryEncrypted + TypeBinaryColumn = bsontype.BinaryColumn + TypeBinarySensitive = bsontype.BinarySensitive + TypeBinaryUserDefined = bsontype.BinaryUserDefined +) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/unmarshal.go b/vendor/go.mongodb.org/mongo-driver/bson/unmarshal.go index f936ba1..d749ba3 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/unmarshal.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/unmarshal.go @@ -14,18 +14,26 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -// Unmarshaler is an interface implemented by types that can unmarshal a BSON -// document representation of themselves. The BSON bytes can be assumed to be -// valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data -// after returning. +// Unmarshaler is the interface implemented by types that can unmarshal a BSON +// document representation of themselves. The input can be assumed to be a valid +// encoding of a BSON document. UnmarshalBSON must copy the JSON data if it +// wishes to retain the data after returning. +// +// Unmarshaler is only used to unmarshal full BSON documents. To create custom +// BSON unmarshaling behavior for individual values in a BSON document, +// implement the ValueUnmarshaler interface instead. type Unmarshaler interface { UnmarshalBSON([]byte) error } -// ValueUnmarshaler is an interface implemented by types that can unmarshal a -// BSON value representation of themselves. The BSON bytes and type can be -// assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it -// wishes to retain the data after returning. +// ValueUnmarshaler is the interface implemented by types that can unmarshal a +// BSON value representation of themselves. The input can be assumed to be a +// valid encoding of a BSON value. UnmarshalBSONValue must copy the BSON value +// bytes if it wishes to retain the data after returning. +// +// ValueUnmarshaler is only used to unmarshal individual values in a BSON +// document. To create custom BSON unmarshaling behavior for an entire BSON +// document, implement the Unmarshaler interface instead. type ValueUnmarshaler interface { UnmarshalBSONValue(bsontype.Type, []byte) error } @@ -33,6 +41,9 @@ type ValueUnmarshaler interface { // Unmarshal parses the BSON-encoded data and stores the result in the value // pointed to by val. If val is nil or not a pointer, Unmarshal returns // InvalidUnmarshalError. +// +// When unmarshaling BSON, if the BSON value is null and the Go value is a +// pointer, the pointer is set to nil without calling UnmarshalBSONValue. func Unmarshal(data []byte, val interface{}) error { return UnmarshalWithRegistry(DefaultRegistry, data, val) } @@ -40,6 +51,16 @@ func Unmarshal(data []byte, val interface{}) error { // UnmarshalWithRegistry parses the BSON-encoded data using Registry r and // stores the result in the value pointed to by val. If val is nil or not // a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. +// +// Deprecated: Use [NewDecoder] and specify the Registry by calling [Decoder.SetRegistry] instead: +// +// dec, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data)) +// if err != nil { +// panic(err) +// } +// dec.SetRegistry(reg) +// +// See [Decoder] for more examples. func UnmarshalWithRegistry(r *bsoncodec.Registry, data []byte, val interface{}) error { vr := bsonrw.NewBSONDocumentReader(data) return unmarshalFromReader(bsoncodec.DecodeContext{Registry: r}, vr, val) @@ -48,11 +69,40 @@ func UnmarshalWithRegistry(r *bsoncodec.Registry, data []byte, val interface{}) // UnmarshalWithContext parses the BSON-encoded data using DecodeContext dc and // stores the result in the value pointed to by val. If val is nil or not // a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. +// +// Deprecated: Use [NewDecoder] and use the Decoder configuration methods to set the desired unmarshal +// behavior instead: +// +// dec, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data)) +// if err != nil { +// panic(err) +// } +// dec.DefaultDocumentM() +// +// See [Decoder] for more examples. func UnmarshalWithContext(dc bsoncodec.DecodeContext, data []byte, val interface{}) error { vr := bsonrw.NewBSONDocumentReader(data) return unmarshalFromReader(dc, vr, val) } +// UnmarshalValue parses the BSON value of type t with bson.DefaultRegistry and +// stores the result in the value pointed to by val. If val is nil or not a pointer, +// UnmarshalValue returns an error. +func UnmarshalValue(t bsontype.Type, data []byte, val interface{}) error { + return UnmarshalValueWithRegistry(DefaultRegistry, t, data, val) +} + +// UnmarshalValueWithRegistry parses the BSON value of type t with registry r and +// stores the result in the value pointed to by val. If val is nil or not a pointer, +// UnmarshalValue returns an error. +// +// Deprecated: Using a custom registry to unmarshal individual BSON values will not be supported in +// Go Driver 2.0. +func UnmarshalValueWithRegistry(r *bsoncodec.Registry, t bsontype.Type, data []byte, val interface{}) error { + vr := bsonrw.NewBSONValueReader(t, data) + return unmarshalFromReader(bsoncodec.DecodeContext{Registry: r}, vr, val) +} + // UnmarshalExtJSON parses the extended JSON-encoded data and stores the result // in the value pointed to by val. If val is nil or not a pointer, Unmarshal // returns InvalidUnmarshalError. @@ -63,6 +113,20 @@ func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error { // UnmarshalExtJSONWithRegistry parses the extended JSON-encoded data using // Registry r and stores the result in the value pointed to by val. If val is // nil or not a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. +// +// Deprecated: Use [NewDecoder] and specify the Registry by calling [Decoder.SetRegistry] instead: +// +// vr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), true) +// if err != nil { +// panic(err) +// } +// dec, err := bson.NewDecoder(vr) +// if err != nil { +// panic(err) +// } +// dec.SetRegistry(reg) +// +// See [Decoder] for more examples. func UnmarshalExtJSONWithRegistry(r *bsoncodec.Registry, data []byte, canonical bool, val interface{}) error { ejvr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { @@ -75,6 +139,21 @@ func UnmarshalExtJSONWithRegistry(r *bsoncodec.Registry, data []byte, canonical // UnmarshalExtJSONWithContext parses the extended JSON-encoded data using // DecodeContext dc and stores the result in the value pointed to by val. If val is // nil or not a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. +// +// Deprecated: Use [NewDecoder] and use the Decoder configuration methods to set the desired unmarshal +// behavior instead: +// +// vr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), true) +// if err != nil { +// panic(err) +// } +// dec, err := bson.NewDecoder(vr) +// if err != nil { +// panic(err) +// } +// dec.DefaultDocumentM() +// +// See [Decoder] for more examples. func UnmarshalExtJSONWithContext(dc bsoncodec.DecodeContext, data []byte, canonical bool, val interface{}) error { ejvr, err := bsonrw.NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { diff --git a/vendor/go.mongodb.org/mongo-driver/event/doc.go b/vendor/go.mongodb.org/mongo-driver/event/doc.go index 93b5ede..da1da4d 100644 --- a/vendor/go.mongodb.org/mongo-driver/event/doc.go +++ b/vendor/go.mongodb.org/mongo-driver/event/doc.go @@ -14,43 +14,43 @@ // CommandSucceededEvent or CommandFailedEvent through the RequestID field. For // example, the following code collects the names of started events: // -// var commandStarted []string -// cmdMonitor := &event.CommandMonitor{ -// Started: func(_ context.Context, evt *event.CommandStartedEvent) { -// commandStarted = append(commandStarted, evt.CommandName) -// }, -// } -// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetMonitor(cmdMonitor) -// client, err := mongo.Connect(context.Background(), clientOpts) +// var commandStarted []string +// cmdMonitor := &event.CommandMonitor{ +// Started: func(_ context.Context, evt *event.CommandStartedEvent) { +// commandStarted = append(commandStarted, evt.CommandName) +// }, +// } +// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetMonitor(cmdMonitor) +// client, err := mongo.Connect(context.Background(), clientOpts) // // Monitoring the connection pool requires specifying a PoolMonitor when constructing // a mongo.Client. The following code tracks the number of checked out connections: // -// var int connsCheckedOut -// poolMonitor := &event.PoolMonitor{ -// Event: func(evt *event.PoolEvent) { -// switch evt.Type { -// case event.GetSucceeded: -// connsCheckedOut++ -// case event.ConnectionReturned: -// connsCheckedOut-- -// } -// }, -// } -// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetPoolMonitor(poolMonitor) -// client, err := mongo.Connect(context.Background(), clientOpts) +// var int connsCheckedOut +// poolMonitor := &event.PoolMonitor{ +// Event: func(evt *event.PoolEvent) { +// switch evt.Type { +// case event.GetSucceeded: +// connsCheckedOut++ +// case event.ConnectionReturned: +// connsCheckedOut-- +// } +// }, +// } +// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetPoolMonitor(poolMonitor) +// client, err := mongo.Connect(context.Background(), clientOpts) // // Monitoring server changes specifying a ServerMonitor object when constructing // a mongo.Client. Different functions can be set on the ServerMonitor to // monitor different kinds of events. See ServerMonitor for more details. // The following code appends ServerHeartbeatStartedEvents to a slice: // -// var heartbeatStarted []*event.ServerHeartbeatStartedEvent -// svrMonitor := &event.ServerMonitor{ -// ServerHeartbeatStarted: func(e *event.ServerHeartbeatStartedEvent) { -// heartbeatStarted = append(heartbeatStarted, e) -// } -// } -// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetServerMonitor(svrMonitor) -// client, err := mongo.Connect(context.Background(), clientOpts) +// var heartbeatStarted []*event.ServerHeartbeatStartedEvent +// svrMonitor := &event.ServerMonitor{ +// ServerHeartbeatStarted: func(e *event.ServerHeartbeatStartedEvent) { +// heartbeatStarted = append(heartbeatStarted, e) +// } +// } +// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetServerMonitor(svrMonitor) +// client, err := mongo.Connect(context.Background(), clientOpts) package event diff --git a/vendor/go.mongodb.org/mongo-driver/event/monitoring.go b/vendor/go.mongodb.org/mongo-driver/event/monitoring.go index ac05e40..ddc7aba 100644 --- a/vendor/go.mongodb.org/mongo-driver/event/monitoring.go +++ b/vendor/go.mongodb.org/mongo-driver/event/monitoring.go @@ -8,6 +8,7 @@ package event // import "go.mongodb.org/mongo-driver/event" import ( "context" + "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" @@ -23,8 +24,14 @@ type CommandStartedEvent struct { RequestID int64 ConnectionID string // ServerConnectionID contains the connection ID from the server of the operation. If the server does not return - // this value (e.g. on MDB < 4.2), it is unset. + // this value (e.g. on MDB < 4.2), it is unset. If the server connection ID would cause an int32 overflow, then + // then this field will be nil. + // + // Deprecated: Use ServerConnectionID64. ServerConnectionID *int32 + // ServerConnectionID64 contains the connection ID from the server of the operation. If the server does not + // return this value (e.g. on MDB < 4.2), it is unset. + ServerConnectionID64 *int64 // ServiceID contains the ID of the server to which the command was sent if it is running behind a load balancer. // Otherwise, it is unset. ServiceID *primitive.ObjectID @@ -32,13 +39,22 @@ type CommandStartedEvent struct { // CommandFinishedEvent represents a generic command finishing. type CommandFinishedEvent struct { + // Deprecated: Use Duration instead. DurationNanos int64 + Duration time.Duration CommandName string + DatabaseName string RequestID int64 ConnectionID string // ServerConnectionID contains the connection ID from the server of the operation. If the server does not return - // this value (e.g. on MDB < 4.2), it is unset. + // this value (e.g. on MDB < 4.2), it is unset.If the server connection ID would cause an int32 overflow, then + // this field will be nil. + // + // Deprecated: Use ServerConnectionID64. ServerConnectionID *int32 + // ServerConnectionID64 contains the connection ID from the server of the operation. If the server does not + // return this value (e.g. on MDB < 4.2), it is unset. + ServerConnectionID64 *int64 // ServiceID contains the ID of the server to which the command was sent if it is running behind a load balancer. // Otherwise, it is unset. ServiceID *primitive.ObjectID @@ -101,10 +117,13 @@ type PoolEvent struct { Address string `json:"address"` ConnectionID uint64 `json:"connectionId"` PoolOptions *MonitorPoolOptions `json:"options"` + Duration time.Duration `json:"duration"` Reason string `json:"reason"` // ServiceID is only set if the Type is PoolCleared and the server is deployed behind a load balancer. This field // can be used to distinguish between individual servers in a load balanced deployment. - ServiceID *primitive.ObjectID `json:"serviceId"` + ServiceID *primitive.ObjectID `json:"serviceId"` + Interruption bool `json:"interruptInUseConnections"` + Error error `json:"error"` } // PoolMonitor is a function that allows the user to gain access to events occurring in the pool @@ -157,7 +176,9 @@ type ServerHeartbeatStartedEvent struct { // ServerHeartbeatSucceededEvent is an event generated when the heartbeat succeeds. type ServerHeartbeatSucceededEvent struct { + // Deprecated: Use Duration instead. DurationNanos int64 + Duration time.Duration Reply description.Server ConnectionID string // The address this heartbeat was sent to with a unique identifier Awaited bool // If this heartbeat was awaitable @@ -165,7 +186,9 @@ type ServerHeartbeatSucceededEvent struct { // ServerHeartbeatFailedEvent is an event generated when the heartbeat fails. type ServerHeartbeatFailedEvent struct { + // Deprecated: Use Duration instead. DurationNanos int64 + Duration time.Duration Failure error ConnectionID string // The address this heartbeat was sent to with a unique identifier Awaited bool // If this heartbeat was awaitable diff --git a/vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/error.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/error.go new file mode 100644 index 0000000..63d06a1 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/error.go @@ -0,0 +1,60 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/error.go +// See THIRD-PARTY-NOTICES for original license terms + +// Package awserr represents API error interface accessors for the SDK. +package awserr + +// An Error wraps lower level errors with code, message and an original error. +// The underlying concrete error type may also satisfy other interfaces which +// can be to used to obtain more specific information about the error. +type Error interface { + // Satisfy the generic error interface. + error + + // Returns the short phrase depicting the classification of the error. + Code() string + + // Returns the error details message. + Message() string + + // Returns the original error if one was set. Nil is returned if not set. + OrigErr() error +} + +// BatchedErrors is a batch of errors which also wraps lower level errors with +// code, message, and original errors. Calling Error() will include all errors +// that occurred in the batch. +// +// Replaces BatchError +type BatchedErrors interface { + // Satisfy the base Error interface. + Error + + // Returns the original error if one was set. Nil is returned if not set. + OrigErrs() []error +} + +// New returns an Error object described by the code, message, and origErr. +// +// If origErr satisfies the Error interface it will not be wrapped within a new +// Error object and will instead be returned. +func New(code, message string, origErr error) Error { + var errs []error + if origErr != nil { + errs = append(errs, origErr) + } + return newBaseError(code, message, errs) +} + +// NewBatchError returns an BatchedErrors with a collection of errors as an +// array of errors. +func NewBatchError(code, message string, errs []error) BatchedErrors { + return newBaseError(code, message, errs) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/types.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/types.go new file mode 100644 index 0000000..18cb4cd --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/awserr/types.go @@ -0,0 +1,144 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/types.go +// See THIRD-PARTY-NOTICES for original license terms + +package awserr + +import ( + "fmt" +) + +// SprintError returns a string of the formatted error code. +// +// Both extra and origErr are optional. If they are included their lines +// will be added, but if they are not included their lines will be ignored. +func SprintError(code, message, extra string, origErr error) string { + msg := fmt.Sprintf("%s: %s", code, message) + if extra != "" { + msg = fmt.Sprintf("%s\n\t%s", msg, extra) + } + if origErr != nil { + msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error()) + } + return msg +} + +// A baseError wraps the code and message which defines an error. It also +// can be used to wrap an original error object. +// +// Should be used as the root for errors satisfying the awserr.Error. Also +// for any error which does not fit into a specific error wrapper type. +type baseError struct { + // Classification of error + code string + + // Detailed information about error + message string + + // Optional original error this error is based off of. Allows building + // chained errors. + errs []error +} + +// newBaseError returns an error object for the code, message, and errors. +// +// code is a short no whitespace phrase depicting the classification of +// the error that is being created. +// +// message is the free flow string containing detailed information about the +// error. +// +// origErrs is the error objects which will be nested under the new errors to +// be returned. +func newBaseError(code, message string, origErrs []error) *baseError { + b := &baseError{ + code: code, + message: message, + errs: origErrs, + } + + return b +} + +// Error returns the string representation of the error. +// +// See ErrorWithExtra for formatting. +// +// Satisfies the error interface. +func (b baseError) Error() string { + size := len(b.errs) + if size > 0 { + return SprintError(b.code, b.message, "", errorList(b.errs)) + } + + return SprintError(b.code, b.message, "", nil) +} + +// String returns the string representation of the error. +// Alias for Error to satisfy the stringer interface. +func (b baseError) String() string { + return b.Error() +} + +// Code returns the short phrase depicting the classification of the error. +func (b baseError) Code() string { + return b.code +} + +// Message returns the error details message. +func (b baseError) Message() string { + return b.message +} + +// OrigErr returns the original error if one was set. Nil is returned if no +// error was set. This only returns the first element in the list. If the full +// list is needed, use BatchedErrors. +func (b baseError) OrigErr() error { + switch len(b.errs) { + case 0: + return nil + case 1: + return b.errs[0] + default: + if err, ok := b.errs[0].(Error); ok { + return NewBatchError(err.Code(), err.Message(), b.errs[1:]) + } + return NewBatchError("BatchedErrors", + "multiple errors occurred", b.errs) + } +} + +// OrigErrs returns the original errors if one was set. An empty slice is +// returned if no error was set. +func (b baseError) OrigErrs() []error { + return b.errs +} + +// An error list that satisfies the golang interface +type errorList []error + +// Error returns the string representation of the error. +// +// Satisfies the error interface. +func (e errorList) Error() string { + msg := "" + // How do we want to handle the array size being zero + if size := len(e); size > 0 { + for i := 0; i < size; i++ { + msg += e[i].Error() + // We check the next index to see if it is within the slice. + // If it is, then we append a newline. We do this, because unit tests + // could be broken with the additional '\n' + if i+1 < size { + msg += "\n" + } + } + } + return msg +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/chain_provider.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/chain_provider.go new file mode 100644 index 0000000..6843927 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/chain_provider.go @@ -0,0 +1,72 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider.go +// See THIRD-PARTY-NOTICES for original license terms + +package credentials + +import ( + "go.mongodb.org/mongo-driver/internal/aws/awserr" +) + +// A ChainProvider will search for a provider which returns credentials +// and cache that provider until Retrieve is called again. +// +// The ChainProvider provides a way of chaining multiple providers together +// which will pick the first available using priority order of the Providers +// in the list. +// +// If none of the Providers retrieve valid credentials Value, ChainProvider's +// Retrieve() will return the error ErrNoValidProvidersFoundInChain. +// +// If a Provider is found which returns valid credentials Value ChainProvider +// will cache that Provider for all calls to IsExpired(), until Retrieve is +// called again. +type ChainProvider struct { + Providers []Provider + curr Provider +} + +// NewChainCredentials returns a pointer to a new Credentials object +// wrapping a chain of providers. +func NewChainCredentials(providers []Provider) *Credentials { + return NewCredentials(&ChainProvider{ + Providers: append([]Provider{}, providers...), + }) +} + +// Retrieve returns the credentials value or error if no provider returned +// without error. +// +// If a provider is found it will be cached and any calls to IsExpired() +// will return the expired state of the cached provider. +func (c *ChainProvider) Retrieve() (Value, error) { + var errs = make([]error, 0, len(c.Providers)) + for _, p := range c.Providers { + creds, err := p.Retrieve() + if err == nil { + c.curr = p + return creds, nil + } + errs = append(errs, err) + } + c.curr = nil + + var err = awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs) + return Value{}, err +} + +// IsExpired will returned the expired state of the currently cached provider +// if there is one. If there is no current provider, true will be returned. +func (c *ChainProvider) IsExpired() bool { + if c.curr != nil { + return c.curr.IsExpired() + } + + return true +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/credentials.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/credentials.go new file mode 100644 index 0000000..53181aa --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/credentials/credentials.go @@ -0,0 +1,197 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials.go +// See THIRD-PARTY-NOTICES for original license terms + +package credentials + +import ( + "context" + "sync" + "time" + + "go.mongodb.org/mongo-driver/internal/aws/awserr" + "golang.org/x/sync/singleflight" +) + +// A Value is the AWS credentials value for individual credential fields. +// +// A Value is also used to represent Azure credentials. +// Azure credentials only consist of an access token, which is stored in the `SessionToken` field. +type Value struct { + // AWS Access key ID + AccessKeyID string + + // AWS Secret Access Key + SecretAccessKey string + + // AWS Session Token + SessionToken string + + // Provider used to get credentials + ProviderName string +} + +// HasKeys returns if the credentials Value has both AccessKeyID and +// SecretAccessKey value set. +func (v Value) HasKeys() bool { + return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0 +} + +// A Provider is the interface for any component which will provide credentials +// Value. A provider is required to manage its own Expired state, and what to +// be expired means. +// +// The Provider should not need to implement its own mutexes, because +// that will be managed by Credentials. +type Provider interface { + // Retrieve returns nil if it successfully retrieved the value. + // Error is returned if the value were not obtainable, or empty. + Retrieve() (Value, error) + + // IsExpired returns if the credentials are no longer valid, and need + // to be retrieved. + IsExpired() bool +} + +// ProviderWithContext is a Provider that can retrieve credentials with a Context +type ProviderWithContext interface { + Provider + + RetrieveWithContext(context.Context) (Value, error) +} + +// A Credentials provides concurrency safe retrieval of AWS credentials Value. +// +// A Credentials is also used to fetch Azure credentials Value. +// +// Credentials will cache the credentials value until they expire. Once the value +// expires the next Get will attempt to retrieve valid credentials. +// +// Credentials is safe to use across multiple goroutines and will manage the +// synchronous state so the Providers do not need to implement their own +// synchronization. +// +// The first Credentials.Get() will always call Provider.Retrieve() to get the +// first instance of the credentials Value. All calls to Get() after that +// will return the cached credentials Value until IsExpired() returns true. +type Credentials struct { + sf singleflight.Group + + m sync.RWMutex + creds Value + provider Provider +} + +// NewCredentials returns a pointer to a new Credentials with the provider set. +func NewCredentials(provider Provider) *Credentials { + c := &Credentials{ + provider: provider, + } + return c +} + +// GetWithContext returns the credentials value, or error if the credentials +// Value failed to be retrieved. Will return early if the passed in context is +// canceled. +// +// Will return the cached credentials Value if it has not expired. If the +// credentials Value has expired the Provider's Retrieve() will be called +// to refresh the credentials. +// +// If Credentials.Expire() was called the credentials Value will be force +// expired, and the next call to Get() will cause them to be refreshed. +func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) { + // Check if credentials are cached, and not expired. + select { + case curCreds, ok := <-c.asyncIsExpired(): + // ok will only be true, of the credentials were not expired. ok will + // be false and have no value if the credentials are expired. + if ok { + return curCreds, nil + } + case <-ctx.Done(): + return Value{}, awserr.New("RequestCanceled", + "request context canceled", ctx.Err()) + } + + // Cannot pass context down to the actual retrieve, because the first + // context would cancel the whole group when there is not direct + // association of items in the group. + resCh := c.sf.DoChan("", func() (interface{}, error) { + return c.singleRetrieve(&suppressedContext{ctx}) + }) + select { + case res := <-resCh: + return res.Val.(Value), res.Err + case <-ctx.Done(): + return Value{}, awserr.New("RequestCanceled", + "request context canceled", ctx.Err()) + } +} + +func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) { + c.m.Lock() + defer c.m.Unlock() + + if curCreds := c.creds; !c.isExpiredLocked(curCreds) { + return curCreds, nil + } + + var creds Value + var err error + if p, ok := c.provider.(ProviderWithContext); ok { + creds, err = p.RetrieveWithContext(ctx) + } else { + creds, err = c.provider.Retrieve() + } + if err == nil { + c.creds = creds + } + + return creds, err +} + +// asyncIsExpired returns a channel of credentials Value. If the channel is +// closed the credentials are expired and credentials value are not empty. +func (c *Credentials) asyncIsExpired() <-chan Value { + ch := make(chan Value, 1) + go func() { + c.m.RLock() + defer c.m.RUnlock() + + if curCreds := c.creds; !c.isExpiredLocked(curCreds) { + ch <- curCreds + } + + close(ch) + }() + + return ch +} + +// isExpiredLocked helper method wrapping the definition of expired credentials. +func (c *Credentials) isExpiredLocked(creds interface{}) bool { + return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired() +} + +type suppressedContext struct { + context.Context +} + +func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) { + return time.Time{}, false +} + +func (s *suppressedContext) Done() <-chan struct{} { + return nil +} + +func (s *suppressedContext) Err() error { + return nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/header_rules.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/header_rules.go new file mode 100644 index 0000000..a372646 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/header_rules.go @@ -0,0 +1,51 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/header_rules.go +// See THIRD-PARTY-NOTICES for original license terms + +package v4 + +// validator houses a set of rule needed for validation of a +// string value +type rules []rule + +// rule interface allows for more flexible rules and just simply +// checks whether or not a value adheres to that rule +type rule interface { + IsValid(value string) bool +} + +// IsValid will iterate through all rules and see if any rules +// apply to the value and supports nested rules +func (r rules) IsValid(value string) bool { + for _, rule := range r { + if rule.IsValid(value) { + return true + } + } + return false +} + +// mapRule generic rule for maps +type mapRule map[string]struct{} + +// IsValid for the map rule satisfies whether it exists in the map +func (m mapRule) IsValid(value string) bool { + _, ok := m[value] + return ok +} + +// excludeList is a generic rule for exclude listing +type excludeList struct { + rule +} + +// IsValid for exclude list checks if the value is within the exclude list +func (b excludeList) IsValid(value string) bool { + return !b.rule.IsValid(value) +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/request.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/request.go similarity index 96% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/request.go rename to vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/request.go index 014ee08..7a43bb3 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/request.go +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/request.go @@ -5,10 +5,10 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 // // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: -// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/request/request.go // See THIRD-PARTY-NOTICES for original license terms -package awsv4 +package v4 import ( "net/http" diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/rest.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/uri_path.go similarity index 72% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/rest.go rename to vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/uri_path.go index b1f86a0..69b6005 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/rest.go +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/uri_path.go @@ -5,14 +5,17 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 // // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: -// - github.com/aws/aws-sdk-go/blob/v1.34.28/private/protocol/rest/build.go +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/uri_path.go +// - github.com/aws/aws-sdk-go/blob/v1.44.225/private/protocol/rest/build.go // See THIRD-PARTY-NOTICES for original license terms -package awsv4 +package v4 import ( "bytes" "fmt" + "net/url" + "strings" ) // Whether the byte value can be sent without escaping in AWS URLs @@ -31,6 +34,22 @@ func init() { } } +func getURIPath(u *url.URL) string { + var uri string + + if len(u.Opaque) > 0 { + uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/") + } else { + uri = u.EscapedPath() + } + + if len(uri) == 0 { + uri = "/" + } + + return uri +} + // EscapePath escapes part of a URL path in Amazon style func EscapePath(path string, encodeSep bool) string { var buf bytes.Buffer diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/signer.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/v4.go similarity index 80% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/signer.go rename to vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/v4.go index 23508c1..6cf4586 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4/signer.go +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/signer/v4/v4.go @@ -5,13 +5,10 @@ // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 // // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: -// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go -// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/v4.go -// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/uri_path.go -// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/types.go +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4.go // See THIRD-PARTY-NOTICES for original license terms -package awsv4 +package v4 import ( "crypto/hmac" @@ -25,6 +22,9 @@ import ( "sort" "strings" "time" + + "go.mongodb.org/mongo-driver/internal/aws" + "go.mongodb.org/mongo-driver/internal/aws/credentials" ) const ( @@ -41,7 +41,7 @@ const ( ) var ignoredHeaders = rules{ - denylist{ + excludeList{ mapRule{ authorizationHeader: struct{}{}, "User-Agent": struct{}{}, @@ -53,13 +53,13 @@ var ignoredHeaders = rules{ // Signer applies AWS v4 signing to given request. Use this to sign requests // that need to be signed with AWS V4 Signatures. type Signer struct { - Credentials *StaticProvider + // The authentication credentials the request will be signed against. + // This value must be set to sign requests. + Credentials *credentials.Credentials } -// NewSigner returns a Signer pointer configured with the credentials and optional -// option values provided. If not options are provided the Signer will use its -// default configuration. -func NewSigner(credentials *StaticProvider) *Signer { +// NewSigner returns a Signer pointer configured with the credentials provided. +func NewSigner(credentials *credentials.Credentials) *Signer { v4 := &Signer{ Credentials: credentials, } @@ -76,7 +76,7 @@ type signingCtx struct { Time time.Time SignedHeaderVals http.Header - credValues Value + credValues credentials.Value bodyDigest string signedHeaders string @@ -85,7 +85,6 @@ type signingCtx struct { credentialString string stringToSign string signature string - authorization string } // Sign signs AWS v4 requests with the provided body, service name, region the @@ -136,7 +135,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi } var err error - ctx.credValues, err = v4.Credentials.Retrieve() + ctx.credValues, err = v4.Credentials.GetWithContext(r.Context()) if err != nil { return http.Header{}, err } @@ -200,31 +199,6 @@ func (ctx *signingCtx) build() error { return nil } -// GetSignedRequestSignature attempts to extract the signature of the request. -// Returning an error if the request is unsigned, or unable to extract the -// signature. -func GetSignedRequestSignature(r *http.Request) ([]byte, error) { - - if auth := r.Header.Get(authorizationHeader); len(auth) != 0 { - ps := strings.Split(auth, ", ") - for _, p := range ps { - if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 { - sig := p[len(authHeaderSignatureElem):] - if len(sig) == 0 { - return nil, fmt.Errorf("invalid request signature authorization header") - } - return hex.DecodeString(sig) - } - } - } - - if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 { - return hex.DecodeString(sig) - } - - return nil, fmt.Errorf("request not signed") -} - func (ctx *signingCtx) buildTime() { ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time)) } @@ -234,7 +208,7 @@ func (ctx *signingCtx) buildCredentialString() { } func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) { - headers := make([]string, 0, len(header)) + headers := make([]string, 0, len(header)+1) headers = append(headers, "host") for k, v := range header { if !r.IsValid(k) { @@ -258,37 +232,25 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) { ctx.signedHeaders = strings.Join(headers, ";") - headerValues := make([]string, len(headers)) + headerItems := make([]string, len(headers)) for i, k := range headers { if k == "host" { if ctx.Request.Host != "" { - headerValues[i] = "host:" + ctx.Request.Host + headerItems[i] = "host:" + ctx.Request.Host } else { - headerValues[i] = "host:" + ctx.Request.URL.Host + headerItems[i] = "host:" + ctx.Request.URL.Host } } else { - headerValues[i] = k + ":" + - strings.Join(ctx.SignedHeaderVals[k], ",") + headerValues := make([]string, len(ctx.SignedHeaderVals[k])) + for i, v := range ctx.SignedHeaderVals[k] { + headerValues[i] = strings.TrimSpace(v) + } + headerItems[i] = k + ":" + + strings.Join(headerValues, ",") } } - stripExcessSpaces(headerValues) - ctx.canonicalHeaders = strings.Join(headerValues, "\n") -} - -func getURIPath(u *url.URL) string { - var uri string - - if len(u.Opaque) > 0 { - uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/") - } else { - uri = u.EscapedPath() - } - - if len(uri) == 0 { - uri = "/" - } - - return uri + stripExcessSpaces(headerItems) + ctx.canonicalHeaders = strings.Join(headerItems, "\n") } func (ctx *signingCtx) buildCanonicalString() { @@ -329,6 +291,9 @@ func (ctx *signingCtx) buildBodyDigest() error { if ctx.Body == nil { hash = emptyStringSHA256 } else { + if !aws.IsReaderSeekable(ctx.Body) { + return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body) + } hashBytes, err := makeSha256Reader(ctx.Body) if err != nil { return err @@ -358,27 +323,6 @@ func hashSHA256(data []byte) []byte { return hash.Sum(nil) } -// seekerLen attempts to get the number of bytes remaining at the seeker's -// current position. Returns the number of bytes remaining or error. -func seekerLen(s io.Seeker) (int64, error) { - curOffset, err := s.Seek(0, io.SeekCurrent) - if err != nil { - return 0, err - } - - endOffset, err := s.Seek(0, io.SeekEnd) - if err != nil { - return 0, err - } - - _, err = s.Seek(curOffset, io.SeekStart) - if err != nil { - return 0, err - } - - return endOffset - curOffset, nil -} - func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) { hash := sha256.New() start, err := reader.Seek(0, io.SeekCurrent) @@ -392,7 +336,7 @@ func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) { // Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies // smaller than 32KB. Fall back to io.Copy if we fail to determine the size. - size, err := seekerLen(reader) + size, err := aws.SeekerLen(reader) if err != nil { _, _ = io.Copy(hash, reader) } else { @@ -409,6 +353,8 @@ const doubleSpace = " " func stripExcessSpaces(vals []string) { var j, k, l, m, spaces int for i, str := range vals { + // revive:disable:empty-block + // Trim trailing spaces for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- { } @@ -416,6 +362,9 @@ func stripExcessSpaces(vals []string) { // Trim leading spaces for k = 0; k < j && str[k] == ' '; k++ { } + + // revive:enable:empty-block + str = str[k : j+1] // Strip multiple spaces. diff --git a/vendor/go.mongodb.org/mongo-driver/internal/aws/types.go b/vendor/go.mongodb.org/mongo-driver/internal/aws/types.go new file mode 100644 index 0000000..52aecda --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/aws/types.go @@ -0,0 +1,153 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// +// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: +// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/types.go +// See THIRD-PARTY-NOTICES for original license terms + +package aws + +import ( + "io" +) + +// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the +// SDK to accept an io.Reader that is not also an io.Seeker for unsigned +// streaming payload API operations. +// +// A ReadSeekCloser wrapping an nonseekable io.Reader used in an API +// operation's input will prevent that operation being retried in the case of +// network errors, and cause operation requests to fail if the operation +// requires payload signing. +// +// Note: If using With S3 PutObject to stream an object upload The SDK's S3 +// Upload manager (s3manager.Uploader) provides support for streaming with the +// ability to retry network errors. +func ReadSeekCloser(r io.Reader) ReaderSeekerCloser { + return ReaderSeekerCloser{r} +} + +// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and +// io.Closer interfaces to the underlying object if they are available. +type ReaderSeekerCloser struct { + r io.Reader +} + +// IsReaderSeekable returns if the underlying reader type can be seeked. A +// io.Reader might not actually be seekable if it is the ReaderSeekerCloser +// type. +func IsReaderSeekable(r io.Reader) bool { + switch v := r.(type) { + case ReaderSeekerCloser: + return v.IsSeeker() + case *ReaderSeekerCloser: + return v.IsSeeker() + case io.ReadSeeker: + return true + default: + return false + } +} + +// Read reads from the reader up to size of p. The number of bytes read, and +// error if it occurred will be returned. +// +// If the reader is not an io.Reader zero bytes read, and nil error will be +// returned. +// +// Performs the same functionality as io.Reader Read +func (r ReaderSeekerCloser) Read(p []byte) (int, error) { + switch t := r.r.(type) { + case io.Reader: + return t.Read(p) + } + return 0, nil +} + +// Seek sets the offset for the next Read to offset, interpreted according to +// whence: 0 means relative to the origin of the file, 1 means relative to the +// current offset, and 2 means relative to the end. Seek returns the new offset +// and an error, if any. +// +// If the ReaderSeekerCloser is not an io.Seeker nothing will be done. +func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) { + switch t := r.r.(type) { + case io.Seeker: + return t.Seek(offset, whence) + } + return int64(0), nil +} + +// IsSeeker returns if the underlying reader is also a seeker. +func (r ReaderSeekerCloser) IsSeeker() bool { + _, ok := r.r.(io.Seeker) + return ok +} + +// HasLen returns the length of the underlying reader if the value implements +// the Len() int method. +func (r ReaderSeekerCloser) HasLen() (int, bool) { + type lenner interface { + Len() int + } + + if lr, ok := r.r.(lenner); ok { + return lr.Len(), true + } + + return 0, false +} + +// GetLen returns the length of the bytes remaining in the underlying reader. +// Checks first for Len(), then io.Seeker to determine the size of the +// underlying reader. +// +// Will return -1 if the length cannot be determined. +func (r ReaderSeekerCloser) GetLen() (int64, error) { + if l, ok := r.HasLen(); ok { + return int64(l), nil + } + + if s, ok := r.r.(io.Seeker); ok { + return seekerLen(s) + } + + return -1, nil +} + +// SeekerLen attempts to get the number of bytes remaining at the seeker's +// current position. Returns the number of bytes remaining or error. +func SeekerLen(s io.Seeker) (int64, error) { + // Determine if the seeker is actually seekable. ReaderSeekerCloser + // hides the fact that a io.Readers might not actually be seekable. + switch v := s.(type) { + case ReaderSeekerCloser: + return v.GetLen() + case *ReaderSeekerCloser: + return v.GetLen() + } + + return seekerLen(s) +} + +func seekerLen(s io.Seeker) (int64, error) { + curOffset, err := s.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + + endOffset, err := s.Seek(0, io.SeekEnd) + if err != nil { + return 0, err + } + + _, err = s.Seek(curOffset, io.SeekStart) + if err != nil { + return 0, err + } + + return endOffset - curOffset, nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/string_util.go b/vendor/go.mongodb.org/mongo-driver/internal/bsonutil/bsonutil.go similarity index 64% rename from vendor/go.mongodb.org/mongo-driver/internal/string_util.go rename to vendor/go.mongodb.org/mongo-driver/internal/bsonutil/bsonutil.go index 6cafa79..eebb328 100644 --- a/vendor/go.mongodb.org/mongo-driver/internal/string_util.go +++ b/vendor/go.mongodb.org/mongo-driver/internal/bsonutil/bsonutil.go @@ -4,7 +4,7 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package internal +package bsonutil import ( "fmt" @@ -12,13 +12,6 @@ import ( "go.mongodb.org/mongo-driver/bson" ) -// StringSliceFromRawElement decodes the provided BSON element into a []string. This internally calls -// StringSliceFromRawValue on the element's value. The error conditions outlined in that function's documentation -// apply for this function as well. -func StringSliceFromRawElement(element bson.RawElement) ([]string, error) { - return StringSliceFromRawValue(element.Key(), element.Value()) -} - // StringSliceFromRawValue decodes the provided BSON value into a []string. This function returns an error if the value // is not an array or any of the elements in the array are not strings. The name parameter is used to add context to // error messages. @@ -43,3 +36,27 @@ func StringSliceFromRawValue(name string, val bson.RawValue) ([]string, error) { } return strs, nil } + +// RawToDocuments converts a bson.Raw that is internally an array of documents to []bson.Raw. +func RawToDocuments(doc bson.Raw) []bson.Raw { + values, err := doc.Values() + if err != nil { + panic(fmt.Sprintf("error converting BSON document to values: %v", err)) + } + + out := make([]bson.Raw, len(values)) + for i := range values { + out[i] = values[i].Document() + } + + return out +} + +// RawToInterfaces takes one or many bson.Raw documents and returns them as a []interface{}. +func RawToInterfaces(docs ...bson.Raw) []interface{} { + out := make([]interface{}, len(docs)) + for i := range docs { + out[i] = docs[i] + } + return out +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go b/vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go deleted file mode 100644 index a7fa163..0000000 --- a/vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package internal - -import "context" - -// CancellationListener listens for context cancellation in a loop until the context expires or the listener is aborted. -type CancellationListener struct { - aborted bool - done chan struct{} -} - -// NewCancellationListener constructs a CancellationListener. -func NewCancellationListener() *CancellationListener { - return &CancellationListener{ - done: make(chan struct{}), - } -} - -// Listen blocks until the provided context is cancelled or listening is aborted via the StopListening function. If this -// detects that the context has been cancelled (i.e. ctx.Err() == context.Canceled), the provided callback is called to -// abort in-progress work. Even if the context expires, this function will block until StopListening is called. -func (c *CancellationListener) Listen(ctx context.Context, abortFn func()) { - c.aborted = false - - select { - case <-ctx.Done(): - if ctx.Err() == context.Canceled { - c.aborted = true - abortFn() - } - - <-c.done - case <-c.done: - } -} - -// StopListening stops the in-progress Listen call. This blocks if there is no in-progress Listen call. This function -// will return true if the provided abort callback was called when listening for cancellation on the previous context. -func (c *CancellationListener) StopListening() bool { - c.done <- struct{}{} - return c.aborted -} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/codecutil/encoding.go b/vendor/go.mongodb.org/mongo-driver/internal/codecutil/encoding.go new file mode 100644 index 0000000..2aaf8f2 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/codecutil/encoding.go @@ -0,0 +1,65 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package codecutil + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +var ErrNilValue = errors.New("value is nil") + +// MarshalError is returned when attempting to transform a value into a document +// results in an error. +type MarshalError struct { + Value interface{} + Err error +} + +// Error implements the error interface. +func (e MarshalError) Error() string { + return fmt.Sprintf("cannot transform type %s to a BSON Document: %v", + reflect.TypeOf(e.Value), e.Err) +} + +// EncoderFn is used to functionally construct an encoder for marshaling values. +type EncoderFn func(io.Writer) (*bson.Encoder, error) + +// MarshalValue will attempt to encode the value with the encoder returned by +// the encoder function. +func MarshalValue(val interface{}, encFn EncoderFn) (bsoncore.Value, error) { + // If the val is already a bsoncore.Value, then do nothing. + if bval, ok := val.(bsoncore.Value); ok { + return bval, nil + } + + if val == nil { + return bsoncore.Value{}, ErrNilValue + } + + buf := new(bytes.Buffer) + + enc, err := encFn(buf) + if err != nil { + return bsoncore.Value{}, err + } + + // Encode the value in a single-element document with an empty key. Use + // bsoncore to extract the first element and return the BSON value. + err = enc.Encode(bson.D{{Key: "", Value: val}}) + if err != nil { + return bsoncore.Value{}, MarshalError{Value: val, Err: err} + } + + return bsoncore.Document(buf.Bytes()).Index(0).Value(), nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/credproviders/assume_role_provider.go b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/assume_role_provider.go new file mode 100644 index 0000000..3a95cf4 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/assume_role_provider.go @@ -0,0 +1,148 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package credproviders + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "time" + + "go.mongodb.org/mongo-driver/internal/aws/credentials" + "go.mongodb.org/mongo-driver/internal/uuid" +) + +const ( + // assumeRoleProviderName provides a name of assume role provider + assumeRoleProviderName = "AssumeRoleProvider" + + stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15` +) + +// An AssumeRoleProvider retrieves credentials for assume role with web identity. +type AssumeRoleProvider struct { + AwsRoleArnEnv EnvVar + AwsWebIdentityTokenFileEnv EnvVar + AwsRoleSessionNameEnv EnvVar + + httpClient *http.Client + expiration time.Time + + // expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring. + // This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions. + // + // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true + // 10 seconds before the credentials are actually expired. + expiryWindow time.Duration +} + +// NewAssumeRoleProvider returns a pointer to an assume role provider. +func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider { + return &AssumeRoleProvider{ + // AwsRoleArnEnv is the environment variable for AWS_ROLE_ARN + AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"), + // AwsWebIdentityTokenFileEnv is the environment variable for AWS_WEB_IDENTITY_TOKEN_FILE + AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"), + // AwsRoleSessionNameEnv is the environment variable for AWS_ROLE_SESSION_NAME + AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"), + httpClient: httpClient, + expiryWindow: expiryWindow, + } +} + +// RetrieveWithContext retrieves the keys from the AWS service. +func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + const defaultHTTPTimeout = 10 * time.Second + + v := credentials.Value{ProviderName: assumeRoleProviderName} + + roleArn := a.AwsRoleArnEnv.Get() + tokenFile := a.AwsWebIdentityTokenFileEnv.Get() + if tokenFile == "" && roleArn == "" { + return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing") + } + if tokenFile != "" && roleArn == "" { + return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing") + } + if tokenFile == "" && roleArn != "" { + return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing") + } + token, err := ioutil.ReadFile(tokenFile) + if err != nil { + return v, err + } + + sessionName := a.AwsRoleSessionNameEnv.Get() + if sessionName == "" { + // Use a UUID if the RoleSessionName is not given. + id, err := uuid.New() + if err != nil { + return v, err + } + sessionName = id.String() + } + + fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token)) + + req, err := http.NewRequest(http.MethodPost, fullURI, nil) + if err != nil { + return v, err + } + req.Header.Set("Accept", "application/json") + + ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout) + defer cancel() + resp, err := a.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return v, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return v, fmt.Errorf("response failure: %s", resp.Status) + } + + var stsResp struct { + Response struct { + Result struct { + Credentials struct { + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string `json:"SecretAccessKey"` + Token string `json:"SessionToken"` + Expiration float64 `json:"Expiration"` + } `json:"Credentials"` + } `json:"AssumeRoleWithWebIdentityResult"` + } `json:"AssumeRoleWithWebIdentityResponse"` + } + + err = json.NewDecoder(resp.Body).Decode(&stsResp) + if err != nil { + return v, err + } + v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID + v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey + v.SessionToken = stsResp.Response.Result.Credentials.Token + if !v.HasKeys() { + return v, errors.New("failed to retrieve web identity keys") + } + sec := int64(stsResp.Response.Result.Credentials.Expiration) + a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow) + + return v, nil +} + +// Retrieve retrieves the keys from the AWS service. +func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) { + return a.RetrieveWithContext(context.Background()) +} + +// IsExpired returns true if the credentials are expired. +func (a *AssumeRoleProvider) IsExpired() bool { + return a.expiration.Before(time.Now()) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/credproviders/ec2_provider.go b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/ec2_provider.go new file mode 100644 index 0000000..771bfca --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/ec2_provider.go @@ -0,0 +1,183 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package credproviders + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "time" + + "go.mongodb.org/mongo-driver/internal/aws/credentials" +) + +const ( + // ec2ProviderName provides a name of EC2 provider + ec2ProviderName = "EC2Provider" + + awsEC2URI = "http://169.254.169.254/" + awsEC2RolePath = "latest/meta-data/iam/security-credentials/" + awsEC2TokenPath = "latest/api/token" + + defaultHTTPTimeout = 10 * time.Second +) + +// An EC2Provider retrieves credentials from EC2 metadata. +type EC2Provider struct { + httpClient *http.Client + expiration time.Time + + // expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring. + // This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions. + // + // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true + // 10 seconds before the credentials are actually expired. + expiryWindow time.Duration +} + +// NewEC2Provider returns a pointer to an EC2 credential provider. +func NewEC2Provider(httpClient *http.Client, expiryWindow time.Duration) *EC2Provider { + return &EC2Provider{ + httpClient: httpClient, + expiryWindow: expiryWindow, + } +} + +func (e *EC2Provider) getToken(ctx context.Context) (string, error) { + req, err := http.NewRequest(http.MethodPut, awsEC2URI+awsEC2TokenPath, nil) + if err != nil { + return "", err + } + const defaultEC2TTLSeconds = "30" + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", defaultEC2TTLSeconds) + + ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout) + defer cancel() + resp, err := e.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status) + } + + token, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + if len(token) == 0 { + return "", errors.New("unable to retrieve token from EC2 metadata") + } + return string(token), nil +} + +func (e *EC2Provider) getRoleName(ctx context.Context, token string) (string, error) { + req, err := http.NewRequest(http.MethodGet, awsEC2URI+awsEC2RolePath, nil) + if err != nil { + return "", err + } + req.Header.Set("X-aws-ec2-metadata-token", token) + + ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout) + defer cancel() + resp, err := e.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status) + } + + role, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + if len(role) == 0 { + return "", errors.New("unable to retrieve role_name from EC2 metadata") + } + return string(role), nil +} + +func (e *EC2Provider) getCredentials(ctx context.Context, token string, role string) (credentials.Value, time.Time, error) { + v := credentials.Value{ProviderName: ec2ProviderName} + + pathWithRole := awsEC2URI + awsEC2RolePath + role + req, err := http.NewRequest(http.MethodGet, pathWithRole, nil) + if err != nil { + return v, time.Time{}, err + } + req.Header.Set("X-aws-ec2-metadata-token", token) + ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout) + defer cancel() + resp, err := e.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return v, time.Time{}, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return v, time.Time{}, fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status) + } + + var ec2Resp struct { + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string `json:"SecretAccessKey"` + Token string `json:"Token"` + Expiration time.Time `json:"Expiration"` + } + + err = json.NewDecoder(resp.Body).Decode(&ec2Resp) + if err != nil { + return v, time.Time{}, err + } + + v.AccessKeyID = ec2Resp.AccessKeyID + v.SecretAccessKey = ec2Resp.SecretAccessKey + v.SessionToken = ec2Resp.Token + + return v, ec2Resp.Expiration, nil +} + +// RetrieveWithContext retrieves the keys from the AWS service. +func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + v := credentials.Value{ProviderName: ec2ProviderName} + + token, err := e.getToken(ctx) + if err != nil { + return v, err + } + + role, err := e.getRoleName(ctx, token) + if err != nil { + return v, err + } + + v, exp, err := e.getCredentials(ctx, token, role) + if err != nil { + return v, err + } + if !v.HasKeys() { + return v, errors.New("failed to retrieve EC2 keys") + } + e.expiration = exp.Add(-e.expiryWindow) + + return v, nil +} + +// Retrieve retrieves the keys from the AWS service. +func (e *EC2Provider) Retrieve() (credentials.Value, error) { + return e.RetrieveWithContext(context.Background()) +} + +// IsExpired returns true if the credentials are expired. +func (e *EC2Provider) IsExpired() bool { + return e.expiration.Before(time.Now()) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/credproviders/ecs_provider.go b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/ecs_provider.go new file mode 100644 index 0000000..0c3a27e --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/ecs_provider.go @@ -0,0 +1,112 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package credproviders + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "go.mongodb.org/mongo-driver/internal/aws/credentials" +) + +const ( + // ecsProviderName provides a name of ECS provider + ecsProviderName = "ECSProvider" + + awsRelativeURI = "http://169.254.170.2/" +) + +// An ECSProvider retrieves credentials from ECS metadata. +type ECSProvider struct { + AwsContainerCredentialsRelativeURIEnv EnvVar + + httpClient *http.Client + expiration time.Time + + // expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring. + // This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions. + // + // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true + // 10 seconds before the credentials are actually expired. + expiryWindow time.Duration +} + +// NewECSProvider returns a pointer to an ECS credential provider. +func NewECSProvider(httpClient *http.Client, expiryWindow time.Duration) *ECSProvider { + return &ECSProvider{ + // AwsContainerCredentialsRelativeURIEnv is the environment variable for AWS_CONTAINER_CREDENTIALS_RELATIVE_URI + AwsContainerCredentialsRelativeURIEnv: EnvVar("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"), + httpClient: httpClient, + expiryWindow: expiryWindow, + } +} + +// RetrieveWithContext retrieves the keys from the AWS service. +func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + const defaultHTTPTimeout = 10 * time.Second + + v := credentials.Value{ProviderName: ecsProviderName} + + relativeEcsURI := e.AwsContainerCredentialsRelativeURIEnv.Get() + if len(relativeEcsURI) == 0 { + return v, errors.New("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is missing") + } + fullURI := awsRelativeURI + relativeEcsURI + + req, err := http.NewRequest(http.MethodGet, fullURI, nil) + if err != nil { + return v, err + } + req.Header.Set("Accept", "application/json") + + ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout) + defer cancel() + resp, err := e.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return v, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return v, fmt.Errorf("response failure: %s", resp.Status) + } + + var ecsResp struct { + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string `json:"SecretAccessKey"` + Token string `json:"Token"` + Expiration time.Time `json:"Expiration"` + } + + err = json.NewDecoder(resp.Body).Decode(&ecsResp) + if err != nil { + return v, err + } + + v.AccessKeyID = ecsResp.AccessKeyID + v.SecretAccessKey = ecsResp.SecretAccessKey + v.SessionToken = ecsResp.Token + if !v.HasKeys() { + return v, errors.New("failed to retrieve ECS keys") + } + e.expiration = ecsResp.Expiration.Add(-e.expiryWindow) + + return v, nil +} + +// Retrieve retrieves the keys from the AWS service. +func (e *ECSProvider) Retrieve() (credentials.Value, error) { + return e.RetrieveWithContext(context.Background()) +} + +// IsExpired returns true if the credentials are expired. +func (e *ECSProvider) IsExpired() bool { + return e.expiration.Before(time.Now()) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/credproviders/env_provider.go b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/env_provider.go new file mode 100644 index 0000000..59ca633 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/env_provider.go @@ -0,0 +1,69 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package credproviders + +import ( + "os" + + "go.mongodb.org/mongo-driver/internal/aws/credentials" +) + +// envProviderName provides a name of Env provider +const envProviderName = "EnvProvider" + +// EnvVar is an environment variable +type EnvVar string + +// Get retrieves the environment variable +func (ev EnvVar) Get() string { + return os.Getenv(string(ev)) +} + +// A EnvProvider retrieves credentials from the environment variables of the +// running process. Environment credentials never expire. +type EnvProvider struct { + AwsAccessKeyIDEnv EnvVar + AwsSecretAccessKeyEnv EnvVar + AwsSessionTokenEnv EnvVar + + retrieved bool +} + +// NewEnvProvider returns a pointer to an ECS credential provider. +func NewEnvProvider() *EnvProvider { + return &EnvProvider{ + // AwsAccessKeyIDEnv is the environment variable for AWS_ACCESS_KEY_ID + AwsAccessKeyIDEnv: EnvVar("AWS_ACCESS_KEY_ID"), + // AwsSecretAccessKeyEnv is the environment variable for AWS_SECRET_ACCESS_KEY + AwsSecretAccessKeyEnv: EnvVar("AWS_SECRET_ACCESS_KEY"), + // AwsSessionTokenEnv is the environment variable for AWS_SESSION_TOKEN + AwsSessionTokenEnv: EnvVar("AWS_SESSION_TOKEN"), + } +} + +// Retrieve retrieves the keys from the environment. +func (e *EnvProvider) Retrieve() (credentials.Value, error) { + e.retrieved = false + + v := credentials.Value{ + AccessKeyID: e.AwsAccessKeyIDEnv.Get(), + SecretAccessKey: e.AwsSecretAccessKeyEnv.Get(), + SessionToken: e.AwsSessionTokenEnv.Get(), + ProviderName: envProviderName, + } + err := verify(v) + if err == nil { + e.retrieved = true + } + + return v, err +} + +// IsExpired returns true if the credentials have not been retrieved. +func (e *EnvProvider) IsExpired() bool { + return !e.retrieved +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/credproviders/imds_provider.go b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/imds_provider.go new file mode 100644 index 0000000..96dad1a --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/imds_provider.go @@ -0,0 +1,103 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package credproviders + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "go.mongodb.org/mongo-driver/internal/aws/credentials" +) + +const ( + // AzureProviderName provides a name of Azure provider + AzureProviderName = "AzureProvider" + + azureURI = "http://169.254.169.254/metadata/identity/oauth2/token" +) + +// An AzureProvider retrieves credentials from Azure IMDS. +type AzureProvider struct { + httpClient *http.Client + expiration time.Time + expiryWindow time.Duration +} + +// NewAzureProvider returns a pointer to an Azure credential provider. +func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *AzureProvider { + return &AzureProvider{ + httpClient: httpClient, + expiration: time.Time{}, + expiryWindow: expiryWindow, + } +} + +// RetrieveWithContext retrieves the keys from the Azure service. +func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + v := credentials.Value{ProviderName: AzureProviderName} + req, err := http.NewRequest(http.MethodGet, azureURI, nil) + if err != nil { + return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err) + } + q := make(url.Values) + q.Set("api-version", "2018-02-01") + q.Set("resource", "https://vault.azure.net") + req.URL.RawQuery = q.Encode() + req.Header.Set("Metadata", "true") + req.Header.Set("Accept", "application/json") + + resp, err := a.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return v, fmt.Errorf("unable to retrieve Azure credentials: error reading response body: %w", err) + } + if resp.StatusCode != http.StatusOK { + return v, fmt.Errorf("unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body) + } + var tokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn string `json:"expires_in"` + } + // Attempt to read body as JSON + err = json.Unmarshal(body, &tokenResponse) + if err != nil { + return v, fmt.Errorf("unable to retrieve Azure credentials: error reading body JSON: %w (response body: %s)", err, body) + } + if tokenResponse.AccessToken == "" { + return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body) + } + v.SessionToken = tokenResponse.AccessToken + + expiresIn, err := time.ParseDuration(tokenResponse.ExpiresIn + "s") + if err != nil { + return v, err + } + if expiration := expiresIn - a.expiryWindow; expiration > 0 { + a.expiration = time.Now().Add(expiration) + } + + return v, err +} + +// Retrieve retrieves the keys from the Azure service. +func (a *AzureProvider) Retrieve() (credentials.Value, error) { + return a.RetrieveWithContext(context.Background()) +} + +// IsExpired returns if the credentials have been retrieved. +func (a *AzureProvider) IsExpired() bool { + return a.expiration.Before(time.Now()) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/credproviders/static_provider.go b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/static_provider.go new file mode 100644 index 0000000..6b49613 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/credproviders/static_provider.go @@ -0,0 +1,59 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package credproviders + +import ( + "errors" + + "go.mongodb.org/mongo-driver/internal/aws/credentials" +) + +// staticProviderName provides a name of Static provider +const staticProviderName = "StaticProvider" + +// A StaticProvider is a set of credentials which are set programmatically, +// and will never expire. +type StaticProvider struct { + credentials.Value + + verified bool + err error +} + +func verify(v credentials.Value) error { + if !v.HasKeys() { + return errors.New("failed to retrieve ACCESS_KEY_ID and SECRET_ACCESS_KEY") + } + if v.AccessKeyID != "" && v.SecretAccessKey == "" { + return errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing") + } + if v.AccessKeyID == "" && v.SecretAccessKey != "" { + return errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing") + } + if v.AccessKeyID == "" && v.SecretAccessKey == "" && v.SessionToken != "" { + return errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing") + } + return nil + +} + +// Retrieve returns the credentials or error if the credentials are invalid. +func (s *StaticProvider) Retrieve() (credentials.Value, error) { + if !s.verified { + s.err = verify(s.Value) + s.Value.ProviderName = staticProviderName + s.verified = true + } + return s.Value, s.err +} + +// IsExpired returns if the credentials are expired. +// +// For StaticProvider, the credentials never expired. +func (s *StaticProvider) IsExpired() bool { + return false +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/csfle/csfle.go b/vendor/go.mongodb.org/mongo-driver/internal/csfle/csfle.go new file mode 100644 index 0000000..20a6d43 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/csfle/csfle.go @@ -0,0 +1,40 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package csfle + +import ( + "errors" + "fmt" + + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +const ( + EncryptedCacheCollection = "ecc" + EncryptedStateCollection = "esc" + EncryptedCompactionCollection = "ecoc" +) + +// GetEncryptedStateCollectionName returns the encrypted state collection name associated with dataCollectionName. +func GetEncryptedStateCollectionName(efBSON bsoncore.Document, dataCollectionName string, stateCollection string) (string, error) { + fieldName := stateCollection + "Collection" + val, err := efBSON.LookupErr(fieldName) + if err != nil { + if !errors.Is(err, bsoncore.ErrElementNotFound) { + return "", err + } + // Return default name. + defaultName := "enxcol_." + dataCollectionName + "." + stateCollection + return defaultName, nil + } + + stateCollectionName, ok := val.StringValueOK() + if !ok { + return "", fmt.Errorf("expected string for '%v', got: %v", fieldName, val.Type) + } + return stateCollectionName, nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/csot/csot.go b/vendor/go.mongodb.org/mongo-driver/internal/csot/csot.go new file mode 100644 index 0000000..43801a5 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/csot/csot.go @@ -0,0 +1,60 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package csot + +import ( + "context" + "time" +) + +type timeoutKey struct{} + +// MakeTimeoutContext returns a new context with Client-Side Operation Timeout (CSOT) feature-gated behavior +// and a Timeout set to the passed in Duration. Setting a Timeout on a single operation is not supported in +// public API. +// +// TODO(GODRIVER-2348) We may be able to remove this function once CSOT feature-gated behavior becomes the +// TODO default behavior. +func MakeTimeoutContext(ctx context.Context, to time.Duration) (context.Context, context.CancelFunc) { + // Only use the passed in Duration as a timeout on the Context if it + // is non-zero and if the Context doesn't already have a timeout. + cancelFunc := func() {} + if _, deadlineSet := ctx.Deadline(); to != 0 && !deadlineSet { + ctx, cancelFunc = context.WithTimeout(ctx, to) + } + + // Add timeoutKey either way to indicate CSOT is enabled. + return context.WithValue(ctx, timeoutKey{}, true), cancelFunc +} + +func IsTimeoutContext(ctx context.Context) bool { + return ctx.Value(timeoutKey{}) != nil +} + +// ZeroRTTMonitor implements the RTTMonitor interface and is used internally for testing. It returns 0 for all +// RTT calculations and an empty string for RTT statistics. +type ZeroRTTMonitor struct{} + +// EWMA implements the RTT monitor interface. +func (zrm *ZeroRTTMonitor) EWMA() time.Duration { + return 0 +} + +// Min implements the RTT monitor interface. +func (zrm *ZeroRTTMonitor) Min() time.Duration { + return 0 +} + +// P90 implements the RTT monitor interface. +func (zrm *ZeroRTTMonitor) P90() time.Duration { + return 0 +} + +// Stats implements the RTT monitor interface. +func (zrm *ZeroRTTMonitor) Stats() string { + return "" +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/driverutil/hello.go b/vendor/go.mongodb.org/mongo-driver/internal/driverutil/hello.go new file mode 100644 index 0000000..18a70f0 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/driverutil/hello.go @@ -0,0 +1,128 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +import ( + "os" + "strings" +) + +const AwsLambdaPrefix = "AWS_Lambda_" + +const ( + // FaaS environment variable names + + // EnvVarAWSExecutionEnv is the AWS Execution environment variable. + EnvVarAWSExecutionEnv = "AWS_EXECUTION_ENV" + // EnvVarAWSLambdaRuntimeAPI is the AWS Lambda runtime API variable. + EnvVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" + // EnvVarFunctionsWorkerRuntime is the functions worker runtime variable. + EnvVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" + // EnvVarKService is the K Service variable. + EnvVarKService = "K_SERVICE" + // EnvVarFunctionName is the function name variable. + EnvVarFunctionName = "FUNCTION_NAME" + // EnvVarVercel is the Vercel variable. + EnvVarVercel = "VERCEL" + // EnvVarK8s is the K8s variable. + EnvVarK8s = "KUBERNETES_SERVICE_HOST" +) + +const ( + // FaaS environment variable names + + // EnvVarAWSRegion is the AWS region variable. + EnvVarAWSRegion = "AWS_REGION" + // EnvVarAWSLambdaFunctionMemorySize is the AWS Lambda function memory size variable. + EnvVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" + // EnvVarFunctionMemoryMB is the function memory in megabytes variable. + EnvVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" + // EnvVarFunctionTimeoutSec is the function timeout in seconds variable. + EnvVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" + // EnvVarFunctionRegion is the function region variable. + EnvVarFunctionRegion = "FUNCTION_REGION" + // EnvVarVercelRegion is the Vercel region variable. + EnvVarVercelRegion = "VERCEL_REGION" +) + +const ( + // FaaS environment names used by the client + + // EnvNameAWSLambda is the AWS Lambda environment name. + EnvNameAWSLambda = "aws.lambda" + // EnvNameAzureFunc is the Azure Function environment name. + EnvNameAzureFunc = "azure.func" + // EnvNameGCPFunc is the Google Cloud Function environment name. + EnvNameGCPFunc = "gcp.func" + // EnvNameVercel is the Vercel environment name. + EnvNameVercel = "vercel" +) + +// GetFaasEnvName parses the FaaS environment variable name and returns the +// corresponding name used by the client. If none of the variables or variables +// for multiple names are populated the client.env value MUST be entirely +// omitted. When variables for multiple "client.env.name" values are present, +// "vercel" takes precedence over "aws.lambda"; any other combination MUST cause +// "client.env" to be entirely omitted. +func GetFaasEnvName() string { + envVars := []string{ + EnvVarAWSExecutionEnv, + EnvVarAWSLambdaRuntimeAPI, + EnvVarFunctionsWorkerRuntime, + EnvVarKService, + EnvVarFunctionName, + EnvVarVercel, + } + + // If none of the variables are populated the client.env value MUST be + // entirely omitted. + names := make(map[string]struct{}) + + for _, envVar := range envVars { + val := os.Getenv(envVar) + if val == "" { + continue + } + + var name string + + switch envVar { + case EnvVarAWSExecutionEnv: + if !strings.HasPrefix(val, AwsLambdaPrefix) { + continue + } + + name = EnvNameAWSLambda + case EnvVarAWSLambdaRuntimeAPI: + name = EnvNameAWSLambda + case EnvVarFunctionsWorkerRuntime: + name = EnvNameAzureFunc + case EnvVarKService, EnvVarFunctionName: + name = EnvNameGCPFunc + case EnvVarVercel: + // "vercel" takes precedence over "aws.lambda". + delete(names, EnvNameAWSLambda) + + name = EnvNameVercel + } + + names[name] = struct{}{} + if len(names) > 1 { + // If multiple names are populated the client.env value + // MUST be entirely omitted. + names = nil + + break + } + } + + for name := range names { + return name + } + + return "" +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/driverutil/operation.go b/vendor/go.mongodb.org/mongo-driver/internal/driverutil/operation.go new file mode 100644 index 0000000..3270431 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/driverutil/operation.go @@ -0,0 +1,31 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +// Operation Names should be sourced from the command reference documentation: +// https://www.mongodb.com/docs/manual/reference/command/ +const ( + AbortTransactionOp = "abortTransaction" // AbortTransactionOp is the name for aborting a transaction + AggregateOp = "aggregate" // AggregateOp is the name for aggregating + CommitTransactionOp = "commitTransaction" // CommitTransactionOp is the name for committing a transaction + CountOp = "count" // CountOp is the name for counting + CreateOp = "create" // CreateOp is the name for creating + CreateIndexesOp = "createIndexes" // CreateIndexesOp is the name for creating indexes + DeleteOp = "delete" // DeleteOp is the name for deleting + DistinctOp = "distinct" // DistinctOp is the name for distinct + DropOp = "drop" // DropOp is the name for dropping + DropDatabaseOp = "dropDatabase" // DropDatabaseOp is the name for dropping a database + DropIndexesOp = "dropIndexes" // DropIndexesOp is the name for dropping indexes + EndSessionsOp = "endSessions" // EndSessionsOp is the name for ending sessions + FindAndModifyOp = "findAndModify" // FindAndModifyOp is the name for finding and modifying + FindOp = "find" // FindOp is the name for finding + InsertOp = "insert" // InsertOp is the name for inserting + ListCollectionsOp = "listCollections" // ListCollectionsOp is the name for listing collections + ListIndexesOp = "listIndexes" // ListIndexesOp is the name for listing indexes + ListDatabasesOp = "listDatabases" // ListDatabasesOp is the name for listing databases + UpdateOp = "update" // UpdateOp is the name for updating +) diff --git a/vendor/go.mongodb.org/mongo-driver/internal/error.go b/vendor/go.mongodb.org/mongo-driver/internal/error.go deleted file mode 100644 index 6a105af..0000000 --- a/vendor/go.mongodb.org/mongo-driver/internal/error.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package internal - -import ( - "fmt" -) - -// WrappedError represents an error that contains another error. -type WrappedError interface { - // Message gets the basic message of the error. - Message() string - // Inner gets the inner error if one exists. - Inner() error -} - -// RolledUpErrorMessage gets a flattened error message. -func RolledUpErrorMessage(err error) string { - if wrappedErr, ok := err.(WrappedError); ok { - inner := wrappedErr.Inner() - if inner != nil { - return fmt.Sprintf("%s: %s", wrappedErr.Message(), RolledUpErrorMessage(inner)) - } - - return wrappedErr.Message() - } - - return err.Error() -} - -//UnwrapError attempts to unwrap the error down to its root cause. -func UnwrapError(err error) error { - - switch tErr := err.(type) { - case WrappedError: - return UnwrapError(tErr.Inner()) - case *multiError: - return UnwrapError(tErr.errors[0]) - } - - return err -} - -// WrapError wraps an error with a message. -func WrapError(inner error, message string) error { - return &wrappedError{message, inner} -} - -// WrapErrorf wraps an error with a message. -func WrapErrorf(inner error, format string, args ...interface{}) error { - return &wrappedError{fmt.Sprintf(format, args...), inner} -} - -// MultiError combines multiple errors into a single error. If there are no errors, -// nil is returned. If there is 1 error, it is returned. Otherwise, they are combined. -func MultiError(errors ...error) error { - - // remove nils from the error list - var nonNils []error - for _, e := range errors { - if e != nil { - nonNils = append(nonNils, e) - } - } - - switch len(nonNils) { - case 0: - return nil - case 1: - return nonNils[0] - default: - return &multiError{ - message: "multiple errors encountered", - errors: nonNils, - } - } -} - -type multiError struct { - message string - errors []error -} - -func (e *multiError) Message() string { - return e.message -} - -func (e *multiError) Error() string { - result := e.message - for _, e := range e.errors { - result += fmt.Sprintf("\n %s", e) - } - return result -} - -func (e *multiError) Errors() []error { - return e.errors -} - -type wrappedError struct { - message string - inner error -} - -func (e *wrappedError) Message() string { - return e.message -} - -func (e *wrappedError) Error() string { - return RolledUpErrorMessage(e) -} - -func (e *wrappedError) Inner() error { - return e.inner -} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/const.go b/vendor/go.mongodb.org/mongo-driver/internal/handshake/handshake.go similarity index 64% rename from vendor/go.mongodb.org/mongo-driver/internal/const.go rename to vendor/go.mongodb.org/mongo-driver/internal/handshake/handshake.go index a7ef69d..c9537d3 100644 --- a/vendor/go.mongodb.org/mongo-driver/internal/const.go +++ b/vendor/go.mongodb.org/mongo-driver/internal/handshake/handshake.go @@ -4,16 +4,10 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package internal // import "go.mongodb.org/mongo-driver/internal" - -// Version is the current version of the driver. -var Version = "local build" +package handshake // LegacyHello is the legacy version of the hello command. var LegacyHello = "isMaster" // LegacyHelloLowercase is the lowercase, legacy version of the hello command. var LegacyHelloLowercase = "ismaster" - -// LegacyNotPrimary is the legacy version of the "not primary" server error message. -var LegacyNotPrimary = "not master" diff --git a/vendor/go.mongodb.org/mongo-driver/internal/httputil/httputil.go b/vendor/go.mongodb.org/mongo-driver/internal/httputil/httputil.go new file mode 100644 index 0000000..db0dd5f --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/httputil/httputil.go @@ -0,0 +1,30 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package httputil + +import ( + "net/http" +) + +// DefaultHTTPClient is the default HTTP client used across the driver. +var DefaultHTTPClient = &http.Client{ + Transport: http.DefaultTransport.(*http.Transport).Clone(), +} + +// CloseIdleHTTPConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in a "keep-alive" +// state. It does not interrupt any connections currently in use. +// +// Borrowed from the Go standard library. +func CloseIdleHTTPConnections(client *http.Client) { + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := client.Transport.(closeIdler); ok { + tr.CloseIdleConnections() + } +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/logger/component.go b/vendor/go.mongodb.org/mongo-driver/internal/logger/component.go new file mode 100644 index 0000000..0a3d553 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/logger/component.go @@ -0,0 +1,314 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package logger + +import ( + "os" + "strconv" + + "go.mongodb.org/mongo-driver/bson/primitive" +) + +const ( + CommandFailed = "Command failed" + CommandStarted = "Command started" + CommandSucceeded = "Command succeeded" + ConnectionPoolCreated = "Connection pool created" + ConnectionPoolReady = "Connection pool ready" + ConnectionPoolCleared = "Connection pool cleared" + ConnectionPoolClosed = "Connection pool closed" + ConnectionCreated = "Connection created" + ConnectionReady = "Connection ready" + ConnectionClosed = "Connection closed" + ConnectionCheckoutStarted = "Connection checkout started" + ConnectionCheckoutFailed = "Connection checkout failed" + ConnectionCheckedOut = "Connection checked out" + ConnectionCheckedIn = "Connection checked in" + ServerSelectionFailed = "Server selection failed" + ServerSelectionStarted = "Server selection started" + ServerSelectionSucceeded = "Server selection succeeded" + ServerSelectionWaiting = "Waiting for suitable server to become available" + TopologyClosed = "Stopped topology monitoring" + TopologyDescriptionChanged = "Topology description changed" + TopologyOpening = "Starting topology monitoring" + TopologyServerClosed = "Stopped server monitoring" + TopologyServerHeartbeatFailed = "Server heartbeat failed" + TopologyServerHeartbeatStarted = "Server heartbeat started" + TopologyServerHeartbeatSucceeded = "Server heartbeat succeeded" + TopologyServerOpening = "Starting server monitoring" +) + +const ( + KeyAwaited = "awaited" + KeyCommand = "command" + KeyCommandName = "commandName" + KeyDatabaseName = "databaseName" + KeyDriverConnectionID = "driverConnectionId" + KeyDurationMS = "durationMS" + KeyError = "error" + KeyFailure = "failure" + KeyMaxConnecting = "maxConnecting" + KeyMaxIdleTimeMS = "maxIdleTimeMS" + KeyMaxPoolSize = "maxPoolSize" + KeyMessage = "message" + KeyMinPoolSize = "minPoolSize" + KeyNewDescription = "newDescription" + KeyOperation = "operation" + KeyOperationID = "operationId" + KeyPreviousDescription = "previousDescription" + KeyRemainingTimeMS = "remainingTimeMS" + KeyReason = "reason" + KeyReply = "reply" + KeyRequestID = "requestId" + KeySelector = "selector" + KeyServerConnectionID = "serverConnectionId" + KeyServerHost = "serverHost" + KeyServerPort = "serverPort" + KeyServiceID = "serviceId" + KeyTimestamp = "timestamp" + KeyTopologyDescription = "topologyDescription" + KeyTopologyID = "topologyId" +) + +// KeyValues is a list of key-value pairs. +type KeyValues []interface{} + +// Add adds a key-value pair to an instance of a KeyValues list. +func (kvs *KeyValues) Add(key string, value interface{}) { + *kvs = append(*kvs, key, value) +} + +const ( + ReasonConnClosedStale = "Connection became stale because the pool was cleared" + ReasonConnClosedIdle = "Connection has been available but unused for longer than the configured max idle time" + ReasonConnClosedError = "An error occurred while using the connection" + ReasonConnClosedPoolClosed = "Connection pool was closed" + ReasonConnCheckoutFailedTimout = "Wait queue timeout elapsed without a connection becoming available" + ReasonConnCheckoutFailedError = "An error occurred while trying to establish a new connection" + ReasonConnCheckoutFailedPoolClosed = "Connection pool was closed" +) + +// Component is an enumeration representing the "components" which can be +// logged against. A LogLevel can be configured on a per-component basis. +type Component int + +const ( + // ComponentAll enables logging for all components. + ComponentAll Component = iota + + // ComponentCommand enables command monitor logging. + ComponentCommand + + // ComponentTopology enables topology logging. + ComponentTopology + + // ComponentServerSelection enables server selection logging. + ComponentServerSelection + + // ComponentConnection enables connection services logging. + ComponentConnection +) + +const ( + mongoDBLogAllEnvVar = "MONGODB_LOG_ALL" + mongoDBLogCommandEnvVar = "MONGODB_LOG_COMMAND" + mongoDBLogTopologyEnvVar = "MONGODB_LOG_TOPOLOGY" + mongoDBLogServerSelectionEnvVar = "MONGODB_LOG_SERVER_SELECTION" + mongoDBLogConnectionEnvVar = "MONGODB_LOG_CONNECTION" +) + +var componentEnvVarMap = map[string]Component{ + mongoDBLogAllEnvVar: ComponentAll, + mongoDBLogCommandEnvVar: ComponentCommand, + mongoDBLogTopologyEnvVar: ComponentTopology, + mongoDBLogServerSelectionEnvVar: ComponentServerSelection, + mongoDBLogConnectionEnvVar: ComponentConnection, +} + +// EnvHasComponentVariables returns true if the environment contains any of the +// component environment variables. +func EnvHasComponentVariables() bool { + for envVar := range componentEnvVarMap { + if os.Getenv(envVar) != "" { + return true + } + } + + return false +} + +// Command is a struct defining common fields that must be included in all +// commands. +type Command struct { + // TODO(GODRIVER-2824): change the DriverConnectionID type to int64. + DriverConnectionID uint64 // Driver's ID for the connection + Name string // Command name + DatabaseName string // Database name + Message string // Message associated with the command + OperationID int32 // Driver-generated operation ID + RequestID int64 // Driver-generated request ID + ServerConnectionID *int64 // Server's ID for the connection used for the command + ServerHost string // Hostname or IP address for the server + ServerPort string // Port for the server + ServiceID *primitive.ObjectID // ID for the command in load balancer mode +} + +// SerializeCommand takes a command and a variable number of key-value pairs and +// returns a slice of interface{} that can be passed to the logger for +// structured logging. +func SerializeCommand(cmd Command, extraKeysAndValues ...interface{}) KeyValues { + // Initialize the boilerplate keys and values. + keysAndValues := KeyValues{ + KeyCommandName, cmd.Name, + KeyDatabaseName, cmd.DatabaseName, + KeyDriverConnectionID, cmd.DriverConnectionID, + KeyMessage, cmd.Message, + KeyOperationID, cmd.OperationID, + KeyRequestID, cmd.RequestID, + KeyServerHost, cmd.ServerHost, + } + + // Add the extra keys and values. + for i := 0; i < len(extraKeysAndValues); i += 2 { + keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1]) + } + + port, err := strconv.ParseInt(cmd.ServerPort, 10, 32) + if err == nil { + keysAndValues.Add(KeyServerPort, port) + } + + // Add the "serverConnectionId" if it is not nil. + if cmd.ServerConnectionID != nil { + keysAndValues.Add(KeyServerConnectionID, *cmd.ServerConnectionID) + } + + // Add the "serviceId" if it is not nil. + if cmd.ServiceID != nil { + keysAndValues.Add(KeyServiceID, cmd.ServiceID.Hex()) + } + + return keysAndValues +} + +// Connection contains data that all connection log messages MUST contain. +type Connection struct { + Message string // Message associated with the connection + ServerHost string // Hostname or IP address for the server + ServerPort string // Port for the server +} + +// SerializeConnection serializes a Connection message into a slice of keys and +// values that can be passed to a logger. +func SerializeConnection(conn Connection, extraKeysAndValues ...interface{}) KeyValues { + // Initialize the boilerplate keys and values. + keysAndValues := KeyValues{ + KeyMessage, conn.Message, + KeyServerHost, conn.ServerHost, + } + + // Add the optional keys and values. + for i := 0; i < len(extraKeysAndValues); i += 2 { + keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1]) + } + + port, err := strconv.ParseInt(conn.ServerPort, 10, 32) + if err == nil { + keysAndValues.Add(KeyServerPort, port) + } + + return keysAndValues +} + +// Server contains data that all server messages MAY contain. +type Server struct { + DriverConnectionID uint64 // Driver's ID for the connection + TopologyID primitive.ObjectID // Driver's unique ID for this topology + Message string // Message associated with the topology + ServerConnectionID *int64 // Server's ID for the connection + ServerHost string // Hostname or IP address for the server + ServerPort string // Port for the server +} + +// SerializeServer serializes a Server message into a slice of keys and +// values that can be passed to a logger. +func SerializeServer(srv Server, extraKV ...interface{}) KeyValues { + // Initialize the boilerplate keys and values. + keysAndValues := KeyValues{ + KeyDriverConnectionID, srv.DriverConnectionID, + KeyMessage, srv.Message, + KeyServerHost, srv.ServerHost, + KeyTopologyID, srv.TopologyID.Hex(), + } + + if connID := srv.ServerConnectionID; connID != nil { + keysAndValues.Add(KeyServerConnectionID, *connID) + } + + port, err := strconv.ParseInt(srv.ServerPort, 10, 32) + if err == nil { + keysAndValues.Add(KeyServerPort, port) + } + + // Add the optional keys and values. + for i := 0; i < len(extraKV); i += 2 { + keysAndValues.Add(extraKV[i].(string), extraKV[i+1]) + } + + return keysAndValues +} + +// ServerSelection contains data that all server selection messages MUST +// contain. +type ServerSelection struct { + Selector string + OperationID *int32 + Operation string + TopologyDescription string +} + +// SerializeServerSelection serializes a Topology message into a slice of keys +// and values that can be passed to a logger. +func SerializeServerSelection(srvSelection ServerSelection, extraKV ...interface{}) KeyValues { + keysAndValues := KeyValues{ + KeySelector, srvSelection.Selector, + KeyOperation, srvSelection.Operation, + KeyTopologyDescription, srvSelection.TopologyDescription, + } + + if srvSelection.OperationID != nil { + keysAndValues.Add(KeyOperationID, *srvSelection.OperationID) + } + + // Add the optional keys and values. + for i := 0; i < len(extraKV); i += 2 { + keysAndValues.Add(extraKV[i].(string), extraKV[i+1]) + } + + return keysAndValues +} + +// Topology contains data that all topology messages MAY contain. +type Topology struct { + ID primitive.ObjectID // Driver's unique ID for this topology + Message string // Message associated with the topology +} + +// SerializeTopology serializes a Topology message into a slice of keys and +// values that can be passed to a logger. +func SerializeTopology(topo Topology, extraKV ...interface{}) KeyValues { + keysAndValues := KeyValues{ + KeyTopologyID, topo.ID.Hex(), + } + + // Add the optional keys and values. + for i := 0; i < len(extraKV); i += 2 { + keysAndValues.Add(extraKV[i].(string), extraKV[i+1]) + } + + return keysAndValues +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/logger/context.go b/vendor/go.mongodb.org/mongo-driver/internal/logger/context.go new file mode 100644 index 0000000..785f141 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/logger/context.go @@ -0,0 +1,48 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package logger + +import "context" + +// contextKey is a custom type used to prevent key collisions when using the +// context package. +type contextKey string + +const ( + contextKeyOperation contextKey = "operation" + contextKeyOperationID contextKey = "operationID" +) + +// WithOperationName adds the operation name to the context. +func WithOperationName(ctx context.Context, operation string) context.Context { + return context.WithValue(ctx, contextKeyOperation, operation) +} + +// WithOperationID adds the operation ID to the context. +func WithOperationID(ctx context.Context, operationID int32) context.Context { + return context.WithValue(ctx, contextKeyOperationID, operationID) +} + +// OperationName returns the operation name from the context. +func OperationName(ctx context.Context) (string, bool) { + operationName := ctx.Value(contextKeyOperation) + if operationName == nil { + return "", false + } + + return operationName.(string), true +} + +// OperationID returns the operation ID from the context. +func OperationID(ctx context.Context) (int32, bool) { + operationID := ctx.Value(contextKeyOperationID) + if operationID == nil { + return 0, false + } + + return operationID.(int32), true +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/logger/io_sink.go b/vendor/go.mongodb.org/mongo-driver/internal/logger/io_sink.go new file mode 100644 index 0000000..0a6c1bd --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/logger/io_sink.go @@ -0,0 +1,63 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package logger + +import ( + "encoding/json" + "io" + "math" + "sync" + "time" +) + +// IOSink writes a JSON-encoded message to the io.Writer. +type IOSink struct { + enc *json.Encoder + + // encMu protects the encoder from concurrent writes. While the logger + // itself does not concurrently write to the sink, the sink may be used + // concurrently within the driver. + encMu sync.Mutex +} + +// Compile-time check to ensure IOSink implements the LogSink interface. +var _ LogSink = &IOSink{} + +// NewIOSink will create an IOSink object that writes JSON messages to the +// provided io.Writer. +func NewIOSink(out io.Writer) *IOSink { + return &IOSink{ + enc: json.NewEncoder(out), + } +} + +// Info will write a JSON-encoded message to the io.Writer. +func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) { + mapSize := len(keysAndValues) / 2 + if math.MaxInt-mapSize >= 2 { + mapSize += 2 + } + kvMap := make(map[string]interface{}, mapSize) + + kvMap[KeyTimestamp] = time.Now().UnixNano() + kvMap[KeyMessage] = msg + + for i := 0; i < len(keysAndValues); i += 2 { + kvMap[keysAndValues[i].(string)] = keysAndValues[i+1] + } + + sink.encMu.Lock() + defer sink.encMu.Unlock() + + _ = sink.enc.Encode(kvMap) +} + +// Error will write a JSON-encoded error message to the io.Writer. +func (sink *IOSink) Error(err error, msg string, kv ...interface{}) { + kv = append(kv, KeyError, err.Error()) + sink.Info(0, msg, kv...) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/logger/level.go b/vendor/go.mongodb.org/mongo-driver/internal/logger/level.go new file mode 100644 index 0000000..07f85b3 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/logger/level.go @@ -0,0 +1,74 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package logger + +import "strings" + +// DiffToInfo is the number of levels in the Go Driver that come before the +// "Info" level. This should ensure that "Info" is the 0th level passed to the +// sink. +const DiffToInfo = 1 + +// Level is an enumeration representing the log severity levels supported by +// the driver. The order of the logging levels is important. The driver expects +// that a user will likely use the "logr" package to create a LogSink, which +// defaults InfoLevel as 0. Any additions to the Level enumeration before the +// InfoLevel will need to also update the "diffToInfo" constant. +type Level int + +const ( + // LevelOff suppresses logging. + LevelOff Level = iota + + // LevelInfo enables logging of informational messages. These logs are + // high-level information about normal driver behavior. + LevelInfo + + // LevelDebug enables logging of debug messages. These logs can be + // voluminous and are intended for detailed information that may be + // helpful when debugging an application. + LevelDebug +) + +const ( + levelLiteralOff = "off" + levelLiteralEmergency = "emergency" + levelLiteralAlert = "alert" + levelLiteralCritical = "critical" + levelLiteralError = "error" + levelLiteralWarning = "warning" + levelLiteralNotice = "notice" + levelLiteralInfo = "info" + levelLiteralDebug = "debug" + levelLiteralTrace = "trace" +) + +var LevelLiteralMap = map[string]Level{ + levelLiteralOff: LevelOff, + levelLiteralEmergency: LevelInfo, + levelLiteralAlert: LevelInfo, + levelLiteralCritical: LevelInfo, + levelLiteralError: LevelInfo, + levelLiteralWarning: LevelInfo, + levelLiteralNotice: LevelInfo, + levelLiteralInfo: LevelInfo, + levelLiteralDebug: LevelDebug, + levelLiteralTrace: LevelDebug, +} + +// ParseLevel will check if the given string is a valid environment variable +// for a logging severity level. If it is, then it will return the associated +// driver's Level. The default Level is “LevelOff”. +func ParseLevel(str string) Level { + for literal, level := range LevelLiteralMap { + if strings.EqualFold(literal, str) { + return level + } + } + + return LevelOff +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/logger/logger.go b/vendor/go.mongodb.org/mongo-driver/internal/logger/logger.go new file mode 100644 index 0000000..2250286 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/logger/logger.go @@ -0,0 +1,275 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package logger provides the internal logging solution for the MongoDB Go +// Driver. +package logger + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +// DefaultMaxDocumentLength is the default maximum number of bytes that can be +// logged for a stringified BSON document. +const DefaultMaxDocumentLength = 1000 + +// TruncationSuffix are trailing ellipsis "..." appended to a message to +// indicate to the user that truncation occurred. This constant does not count +// toward the max document length. +const TruncationSuffix = "..." + +const logSinkPathEnvVar = "MONGODB_LOG_PATH" +const maxDocumentLengthEnvVar = "MONGODB_LOG_MAX_DOCUMENT_LENGTH" + +// LogSink represents a logging implementation, this interface should be 1-1 +// with the exported "LogSink" interface in the mongo/options package. +type LogSink interface { + // Info logs a non-error message with the given key/value pairs. The + // level argument is provided for optional logging. + Info(level int, msg string, keysAndValues ...interface{}) + + // Error logs an error, with the given message and key/value pairs. + Error(err error, msg string, keysAndValues ...interface{}) +} + +// Logger represents the configuration for the internal logger. +type Logger struct { + ComponentLevels map[Component]Level // Log levels for each component. + Sink LogSink // LogSink for log printing. + MaxDocumentLength uint // Command truncation width. + logFile *os.File // File to write logs to. +} + +// New will construct a new logger. If any of the given options are the +// zero-value of the argument type, then the constructor will attempt to +// source the data from the environment. If the environment has not been set, +// then the constructor will the respective default values. +func New(sink LogSink, maxDocLen uint, compLevels map[Component]Level) (*Logger, error) { + logger := &Logger{ + ComponentLevels: selectComponentLevels(compLevels), + MaxDocumentLength: selectMaxDocumentLength(maxDocLen), + } + + sink, logFile, err := selectLogSink(sink) + if err != nil { + return nil, err + } + + logger.Sink = sink + logger.logFile = logFile + + return logger, nil +} + +// Close will close the logger's log file, if it exists. +func (logger *Logger) Close() error { + if logger.logFile != nil { + return logger.logFile.Close() + } + + return nil +} + +// LevelComponentEnabled will return true if the given LogLevel is enabled for +// the given LogComponent. If the ComponentLevels on the logger are enabled for +// "ComponentAll", then this function will return true for any level bound by +// the level assigned to "ComponentAll". +// +// If the level is not enabled (i.e. LevelOff), then false is returned. This is +// to avoid false positives, such as returning "true" for a component that is +// not enabled. For example, without this condition, an empty LevelComponent +// would be considered "enabled" for "LevelOff". +func (logger *Logger) LevelComponentEnabled(level Level, component Component) bool { + if level == LevelOff { + return false + } + + if logger.ComponentLevels == nil { + return false + } + + return logger.ComponentLevels[component] >= level || + logger.ComponentLevels[ComponentAll] >= level +} + +// Print will synchronously print the given message to the configured LogSink. +// If the LogSink is nil, then this method will do nothing. Future work could be done to make +// this method asynchronous, see buffer management in libraries such as log4j. +// +// It's worth noting that many structured logs defined by DBX-wide +// specifications include a "message" field, which is often shared with the +// message arguments passed to this print function. The "Info" method used by +// this function is implemented based on the go-logr/logr LogSink interface, +// which is why "Print" has a message parameter. Any duplication in code is +// intentional to adhere to the logr pattern. +func (logger *Logger) Print(level Level, component Component, msg string, keysAndValues ...interface{}) { + // If the level is not enabled for the component, then + // skip the message. + if !logger.LevelComponentEnabled(level, component) { + return + } + + // If the sink is nil, then skip the message. + if logger.Sink == nil { + return + } + + logger.Sink.Info(int(level)-DiffToInfo, msg, keysAndValues...) +} + +// Error logs an error, with the given message and key/value pairs. +// It functions similarly to Print, but may have unique behavior, and should be +// preferred for logging errors. +func (logger *Logger) Error(err error, msg string, keysAndValues ...interface{}) { + if logger.Sink == nil { + return + } + + logger.Sink.Error(err, msg, keysAndValues...) +} + +// selectMaxDocumentLength will return the integer value of the first non-zero +// function, with the user-defined function taking priority over the environment +// variables. For the environment, the function will attempt to get the value of +// "MONGODB_LOG_MAX_DOCUMENT_LENGTH" and parse it as an unsigned integer. If the +// environment variable is not set or is not an unsigned integer, then this +// function will return the default max document length. +func selectMaxDocumentLength(maxDocLen uint) uint { + if maxDocLen != 0 { + return maxDocLen + } + + maxDocLenEnv := os.Getenv(maxDocumentLengthEnvVar) + if maxDocLenEnv != "" { + maxDocLenEnvInt, err := strconv.ParseUint(maxDocLenEnv, 10, 32) + if err == nil { + return uint(maxDocLenEnvInt) + } + } + + return DefaultMaxDocumentLength +} + +const ( + logSinkPathStdout = "stdout" + logSinkPathStderr = "stderr" +) + +// selectLogSink will return the first non-nil LogSink, with the user-defined +// LogSink taking precedence over the environment-defined LogSink. If no LogSink +// is defined, then this function will return a LogSink that writes to stderr. +func selectLogSink(sink LogSink) (LogSink, *os.File, error) { + if sink != nil { + return sink, nil, nil + } + + path := os.Getenv(logSinkPathEnvVar) + lowerPath := strings.ToLower(path) + + if lowerPath == string(logSinkPathStderr) { + return NewIOSink(os.Stderr), nil, nil + } + + if lowerPath == string(logSinkPathStdout) { + return NewIOSink(os.Stdout), nil, nil + } + + if path != "" { + logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666) + if err != nil { + return nil, nil, fmt.Errorf("unable to open log file: %w", err) + } + + return NewIOSink(logFile), logFile, nil + } + + return NewIOSink(os.Stderr), nil, nil +} + +// selectComponentLevels returns a new map of LogComponents to LogLevels that is +// the result of merging the user-defined data with the environment, with the +// user-defined data taking priority. +func selectComponentLevels(componentLevels map[Component]Level) map[Component]Level { + selected := make(map[Component]Level) + + // Determine if the "MONGODB_LOG_ALL" environment variable is set. + var globalEnvLevel *Level + if all := os.Getenv(mongoDBLogAllEnvVar); all != "" { + level := ParseLevel(all) + globalEnvLevel = &level + } + + for envVar, component := range componentEnvVarMap { + // If the component already has a level, then skip it. + if _, ok := componentLevels[component]; ok { + selected[component] = componentLevels[component] + + continue + } + + // If the "MONGODB_LOG_ALL" environment variable is set, then + // set the level for the component to the value of the + // environment variable. + if globalEnvLevel != nil { + selected[component] = *globalEnvLevel + + continue + } + + // Otherwise, set the level for the component to the value of + // the environment variable. + selected[component] = ParseLevel(os.Getenv(envVar)) + } + + return selected +} + +// truncate will truncate a string to the given width, appending "..." to the +// end of the string if it is truncated. This routine is safe for multi-byte +// characters. +func truncate(str string, width uint) string { + if width == 0 { + return "" + } + + if len(str) <= int(width) { + return str + } + + // Truncate the byte slice of the string to the given width. + newStr := str[:width] + + // Check if the last byte is at the beginning of a multi-byte character. + // If it is, then remove the last byte. + if newStr[len(newStr)-1]&0xC0 == 0xC0 { + return newStr[:len(newStr)-1] + TruncationSuffix + } + + // Check if the last byte is in the middle of a multi-byte character. If + // it is, then step back until we find the beginning of the character. + if newStr[len(newStr)-1]&0xC0 == 0x80 { + for i := len(newStr) - 1; i >= 0; i-- { + if newStr[i]&0xC0 == 0xC0 { + return newStr[:i] + TruncationSuffix + } + } + } + + return newStr + TruncationSuffix +} + +// FormatMessage formats a BSON document for logging. The document is truncated +// to the given width. +func FormatMessage(msg string, width uint) string { + if len(msg) == 0 { + return "{}" + } + + return truncate(msg, width) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/ptrutil/int64.go b/vendor/go.mongodb.org/mongo-driver/internal/ptrutil/int64.go new file mode 100644 index 0000000..1c3ab57 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/ptrutil/int64.go @@ -0,0 +1,39 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package ptrutil + +// CompareInt64 is a piecewise function with the following return conditions: +// +// (1) 2, ptr1 != nil AND ptr2 == nil +// (2) 1, *ptr1 > *ptr2 +// (3) 0, ptr1 == ptr2 or *ptr1 == *ptr2 +// (4) -1, *ptr1 < *ptr2 +// (5) -2, ptr1 == nil AND ptr2 != nil +func CompareInt64(ptr1, ptr2 *int64) int { + if ptr1 == ptr2 { + // This will catch the double nil or same-pointer cases. + return 0 + } + + if ptr1 == nil && ptr2 != nil { + return -2 + } + + if ptr1 != nil && ptr2 == nil { + return 2 + } + + if *ptr1 > *ptr2 { + return 1 + } + + if *ptr1 < *ptr2 { + return -1 + } + + return 0 +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/rand/bits.go b/vendor/go.mongodb.org/mongo-driver/internal/rand/bits.go new file mode 100644 index 0000000..4479009 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/rand/bits.go @@ -0,0 +1,38 @@ +// Copied from https://cs.opensource.google/go/go/+/946b4baaf6521d521928500b2b57429c149854e7:src/math/bits.go + +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rand + +// Add64 returns the sum with carry of x, y and carry: sum = x + y + carry. +// The carry input must be 0 or 1; otherwise the behavior is undefined. +// The carryOut output is guaranteed to be 0 or 1. +func Add64(x, y, carry uint64) (sum, carryOut uint64) { + yc := y + carry + sum = x + yc + if sum < x || yc < y { + carryOut = 1 + } + return +} + +// Mul64 returns the 128-bit product of x and y: (hi, lo) = x * y +// with the product bits' upper half returned in hi and the lower +// half returned in lo. +func Mul64(x, y uint64) (hi, lo uint64) { + const mask32 = 1<<32 - 1 + x0 := x & mask32 + x1 := x >> 32 + y0 := y & mask32 + y1 := y >> 32 + w0 := x0 * y0 + t := x1*y0 + w0>>32 + w1 := t & mask32 + w2 := t >> 32 + w1 += x0 * y1 + hi = x1*y1 + w2 + w1>>32 + lo = x * y + return +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/rand/exp.go b/vendor/go.mongodb.org/mongo-driver/internal/rand/exp.go new file mode 100644 index 0000000..859e4e0 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/rand/exp.go @@ -0,0 +1,223 @@ +// Copied from https://cs.opensource.google/go/x/exp/+/24438e51023af3bfc1db8aed43c1342817e8cfcd:rand/exp.go + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rand + +import ( + "math" +) + +/* + * Exponential distribution + * + * See "The Ziggurat Method for Generating Random Variables" + * (Marsaglia & Tsang, 2000) + * http://www.jstatsoft.org/v05/i08/paper [pdf] + */ + +const ( + re = 7.69711747013104972 +) + +// ExpFloat64 returns an exponentially distributed float64 in the range +// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter +// (lambda) is 1 and whose mean is 1/lambda (1). +// To produce a distribution with a different rate parameter, +// callers can adjust the output using: +// +// sample = ExpFloat64() / desiredRateParameter +func (r *Rand) ExpFloat64() float64 { + for { + j := r.Uint32() + i := j & 0xFF + x := float64(j) * float64(we[i]) + if j < ke[i] { + return x + } + if i == 0 { + return re - math.Log(r.Float64()) + } + if fe[i]+float32(r.Float64())*(fe[i-1]-fe[i]) < float32(math.Exp(-x)) { + return x + } + } +} + +var ke = [256]uint32{ + 0xe290a139, 0x0, 0x9beadebc, 0xc377ac71, 0xd4ddb990, + 0xde893fb8, 0xe4a8e87c, 0xe8dff16a, 0xebf2deab, 0xee49a6e8, + 0xf0204efd, 0xf19bdb8e, 0xf2d458bb, 0xf3da104b, 0xf4b86d78, + 0xf577ad8a, 0xf61de83d, 0xf6afb784, 0xf730a573, 0xf7a37651, + 0xf80a5bb6, 0xf867189d, 0xf8bb1b4f, 0xf9079062, 0xf94d70ca, + 0xf98d8c7d, 0xf9c8928a, 0xf9ff175b, 0xfa319996, 0xfa6085f8, + 0xfa8c3a62, 0xfab5084e, 0xfadb36c8, 0xfaff0410, 0xfb20a6ea, + 0xfb404fb4, 0xfb5e2951, 0xfb7a59e9, 0xfb95038c, 0xfbae44ba, + 0xfbc638d8, 0xfbdcf892, 0xfbf29a30, 0xfc0731df, 0xfc1ad1ed, + 0xfc2d8b02, 0xfc3f6c4d, 0xfc5083ac, 0xfc60ddd1, 0xfc708662, + 0xfc7f8810, 0xfc8decb4, 0xfc9bbd62, 0xfca9027c, 0xfcb5c3c3, + 0xfcc20864, 0xfccdd70a, 0xfcd935e3, 0xfce42ab0, 0xfceebace, + 0xfcf8eb3b, 0xfd02c0a0, 0xfd0c3f59, 0xfd156b7b, 0xfd1e48d6, + 0xfd26daff, 0xfd2f2552, 0xfd372af7, 0xfd3eeee5, 0xfd4673e7, + 0xfd4dbc9e, 0xfd54cb85, 0xfd5ba2f2, 0xfd62451b, 0xfd68b415, + 0xfd6ef1da, 0xfd750047, 0xfd7ae120, 0xfd809612, 0xfd8620b4, + 0xfd8b8285, 0xfd90bcf5, 0xfd95d15e, 0xfd9ac10b, 0xfd9f8d36, + 0xfda43708, 0xfda8bf9e, 0xfdad2806, 0xfdb17141, 0xfdb59c46, + 0xfdb9a9fd, 0xfdbd9b46, 0xfdc170f6, 0xfdc52bd8, 0xfdc8ccac, + 0xfdcc542d, 0xfdcfc30b, 0xfdd319ef, 0xfdd6597a, 0xfdd98245, + 0xfddc94e5, 0xfddf91e6, 0xfde279ce, 0xfde54d1f, 0xfde80c52, + 0xfdeab7de, 0xfded5034, 0xfdefd5be, 0xfdf248e3, 0xfdf4aa06, + 0xfdf6f984, 0xfdf937b6, 0xfdfb64f4, 0xfdfd818d, 0xfdff8dd0, + 0xfe018a08, 0xfe03767a, 0xfe05536c, 0xfe07211c, 0xfe08dfc9, + 0xfe0a8fab, 0xfe0c30fb, 0xfe0dc3ec, 0xfe0f48b1, 0xfe10bf76, + 0xfe122869, 0xfe1383b4, 0xfe14d17c, 0xfe1611e7, 0xfe174516, + 0xfe186b2a, 0xfe19843e, 0xfe1a9070, 0xfe1b8fd6, 0xfe1c8289, + 0xfe1d689b, 0xfe1e4220, 0xfe1f0f26, 0xfe1fcfbc, 0xfe2083ed, + 0xfe212bc3, 0xfe21c745, 0xfe225678, 0xfe22d95f, 0xfe234ffb, + 0xfe23ba4a, 0xfe241849, 0xfe2469f2, 0xfe24af3c, 0xfe24e81e, + 0xfe25148b, 0xfe253474, 0xfe2547c7, 0xfe254e70, 0xfe25485a, + 0xfe25356a, 0xfe251586, 0xfe24e88f, 0xfe24ae64, 0xfe2466e1, + 0xfe2411df, 0xfe23af34, 0xfe233eb4, 0xfe22c02c, 0xfe22336b, + 0xfe219838, 0xfe20ee58, 0xfe20358c, 0xfe1f6d92, 0xfe1e9621, + 0xfe1daef0, 0xfe1cb7ac, 0xfe1bb002, 0xfe1a9798, 0xfe196e0d, + 0xfe1832fd, 0xfe16e5fe, 0xfe15869d, 0xfe141464, 0xfe128ed3, + 0xfe10f565, 0xfe0f478c, 0xfe0d84b1, 0xfe0bac36, 0xfe09bd73, + 0xfe07b7b5, 0xfe059a40, 0xfe03644c, 0xfe011504, 0xfdfeab88, + 0xfdfc26e9, 0xfdf98629, 0xfdf6c83b, 0xfdf3ec01, 0xfdf0f04a, + 0xfdedd3d1, 0xfdea953d, 0xfde7331e, 0xfde3abe9, 0xfddffdfb, + 0xfddc2791, 0xfdd826cd, 0xfdd3f9a8, 0xfdcf9dfc, 0xfdcb1176, + 0xfdc65198, 0xfdc15bb3, 0xfdbc2ce2, 0xfdb6c206, 0xfdb117be, + 0xfdab2a63, 0xfda4f5fd, 0xfd9e7640, 0xfd97a67a, 0xfd908192, + 0xfd8901f2, 0xfd812182, 0xfd78d98e, 0xfd7022bb, 0xfd66f4ed, + 0xfd5d4732, 0xfd530f9c, 0xfd48432b, 0xfd3cd59a, 0xfd30b936, + 0xfd23dea4, 0xfd16349e, 0xfd07a7a3, 0xfcf8219b, 0xfce7895b, + 0xfcd5c220, 0xfcc2aadb, 0xfcae1d5e, 0xfc97ed4e, 0xfc7fe6d4, + 0xfc65ccf3, 0xfc495762, 0xfc2a2fc8, 0xfc07ee19, 0xfbe213c1, + 0xfbb8051a, 0xfb890078, 0xfb5411a5, 0xfb180005, 0xfad33482, + 0xfa839276, 0xfa263b32, 0xf9b72d1c, 0xf930a1a2, 0xf889f023, + 0xf7b577d2, 0xf69c650c, 0xf51530f0, 0xf2cb0e3c, 0xeeefb15d, + 0xe6da6ecf, +} +var we = [256]float32{ + 2.0249555e-09, 1.486674e-11, 2.4409617e-11, 3.1968806e-11, + 3.844677e-11, 4.4228204e-11, 4.9516443e-11, 5.443359e-11, + 5.905944e-11, 6.344942e-11, 6.7643814e-11, 7.1672945e-11, + 7.556032e-11, 7.932458e-11, 8.298079e-11, 8.654132e-11, + 9.0016515e-11, 9.3415074e-11, 9.674443e-11, 1.0001099e-10, + 1.03220314e-10, 1.06377254e-10, 1.09486115e-10, 1.1255068e-10, + 1.1557435e-10, 1.1856015e-10, 1.2151083e-10, 1.2442886e-10, + 1.2731648e-10, 1.3017575e-10, 1.3300853e-10, 1.3581657e-10, + 1.3860142e-10, 1.4136457e-10, 1.4410738e-10, 1.4683108e-10, + 1.4953687e-10, 1.5222583e-10, 1.54899e-10, 1.5755733e-10, + 1.6020171e-10, 1.6283301e-10, 1.6545203e-10, 1.6805951e-10, + 1.7065617e-10, 1.732427e-10, 1.7581973e-10, 1.7838787e-10, + 1.8094774e-10, 1.8349985e-10, 1.8604476e-10, 1.8858298e-10, + 1.9111498e-10, 1.9364126e-10, 1.9616223e-10, 1.9867835e-10, + 2.0119004e-10, 2.0369768e-10, 2.0620168e-10, 2.087024e-10, + 2.1120022e-10, 2.136955e-10, 2.1618855e-10, 2.1867974e-10, + 2.2116936e-10, 2.2365775e-10, 2.261452e-10, 2.2863202e-10, + 2.311185e-10, 2.3360494e-10, 2.360916e-10, 2.3857874e-10, + 2.4106667e-10, 2.4355562e-10, 2.4604588e-10, 2.485377e-10, + 2.5103128e-10, 2.5352695e-10, 2.560249e-10, 2.585254e-10, + 2.6102867e-10, 2.6353494e-10, 2.6604446e-10, 2.6855745e-10, + 2.7107416e-10, 2.7359479e-10, 2.761196e-10, 2.7864877e-10, + 2.8118255e-10, 2.8372119e-10, 2.8626485e-10, 2.888138e-10, + 2.9136826e-10, 2.939284e-10, 2.9649452e-10, 2.9906677e-10, + 3.016454e-10, 3.0423064e-10, 3.0682268e-10, 3.0942177e-10, + 3.1202813e-10, 3.1464195e-10, 3.1726352e-10, 3.19893e-10, + 3.2253064e-10, 3.251767e-10, 3.2783135e-10, 3.3049485e-10, + 3.3316744e-10, 3.3584938e-10, 3.3854083e-10, 3.4124212e-10, + 3.4395342e-10, 3.46675e-10, 3.4940711e-10, 3.5215003e-10, + 3.5490397e-10, 3.5766917e-10, 3.6044595e-10, 3.6323455e-10, + 3.660352e-10, 3.6884823e-10, 3.7167386e-10, 3.745124e-10, + 3.773641e-10, 3.802293e-10, 3.8310827e-10, 3.860013e-10, + 3.8890866e-10, 3.918307e-10, 3.9476775e-10, 3.9772008e-10, + 4.0068804e-10, 4.0367196e-10, 4.0667217e-10, 4.09689e-10, + 4.1272286e-10, 4.1577405e-10, 4.1884296e-10, 4.2192994e-10, + 4.250354e-10, 4.281597e-10, 4.313033e-10, 4.3446652e-10, + 4.3764986e-10, 4.408537e-10, 4.4407847e-10, 4.4732465e-10, + 4.5059267e-10, 4.5388301e-10, 4.571962e-10, 4.6053267e-10, + 4.6389292e-10, 4.6727755e-10, 4.70687e-10, 4.741219e-10, + 4.7758275e-10, 4.810702e-10, 4.845848e-10, 4.8812715e-10, + 4.9169796e-10, 4.9529775e-10, 4.989273e-10, 5.0258725e-10, + 5.0627835e-10, 5.100013e-10, 5.1375687e-10, 5.1754584e-10, + 5.21369e-10, 5.2522725e-10, 5.2912136e-10, 5.330522e-10, + 5.370208e-10, 5.4102806e-10, 5.45075e-10, 5.491625e-10, + 5.532918e-10, 5.5746385e-10, 5.616799e-10, 5.6594107e-10, + 5.7024857e-10, 5.746037e-10, 5.7900773e-10, 5.834621e-10, + 5.8796823e-10, 5.925276e-10, 5.971417e-10, 6.018122e-10, + 6.065408e-10, 6.113292e-10, 6.1617933e-10, 6.2109295e-10, + 6.260722e-10, 6.3111916e-10, 6.3623595e-10, 6.4142497e-10, + 6.4668854e-10, 6.5202926e-10, 6.5744976e-10, 6.6295286e-10, + 6.6854156e-10, 6.742188e-10, 6.79988e-10, 6.858526e-10, + 6.9181616e-10, 6.978826e-10, 7.04056e-10, 7.103407e-10, + 7.167412e-10, 7.2326256e-10, 7.2990985e-10, 7.366886e-10, + 7.4360473e-10, 7.5066453e-10, 7.5787476e-10, 7.6524265e-10, + 7.7277595e-10, 7.80483e-10, 7.883728e-10, 7.9645507e-10, + 8.047402e-10, 8.1323964e-10, 8.219657e-10, 8.309319e-10, + 8.401528e-10, 8.496445e-10, 8.594247e-10, 8.6951274e-10, + 8.799301e-10, 8.9070046e-10, 9.018503e-10, 9.134092e-10, + 9.254101e-10, 9.378904e-10, 9.508923e-10, 9.644638e-10, + 9.786603e-10, 9.935448e-10, 1.0091913e-09, 1.025686e-09, + 1.0431306e-09, 1.0616465e-09, 1.08138e-09, 1.1025096e-09, + 1.1252564e-09, 1.1498986e-09, 1.1767932e-09, 1.206409e-09, + 1.2393786e-09, 1.276585e-09, 1.3193139e-09, 1.3695435e-09, + 1.4305498e-09, 1.508365e-09, 1.6160854e-09, 1.7921248e-09, +} +var fe = [256]float32{ + 1, 0.9381437, 0.90046996, 0.87170434, 0.8477855, 0.8269933, + 0.8084217, 0.7915276, 0.77595687, 0.7614634, 0.7478686, + 0.7350381, 0.72286767, 0.71127474, 0.70019263, 0.6895665, + 0.67935055, 0.6695063, 0.66000086, 0.65080583, 0.6418967, + 0.63325197, 0.6248527, 0.6166822, 0.60872537, 0.60096896, + 0.5934009, 0.58601034, 0.5787874, 0.57172304, 0.5648092, + 0.5580383, 0.5514034, 0.5448982, 0.5385169, 0.53225386, + 0.5261042, 0.52006316, 0.5141264, 0.50828975, 0.5025495, + 0.496902, 0.49134386, 0.485872, 0.48048335, 0.4751752, + 0.46994483, 0.46478975, 0.45970762, 0.45469615, 0.44975325, + 0.44487688, 0.44006512, 0.43531612, 0.43062815, 0.42599955, + 0.42142874, 0.4169142, 0.41245446, 0.40804818, 0.403694, + 0.3993907, 0.39513698, 0.39093173, 0.38677382, 0.38266218, + 0.37859577, 0.37457356, 0.37059465, 0.3666581, 0.362763, + 0.35890847, 0.35509375, 0.351318, 0.3475805, 0.34388044, + 0.34021714, 0.3365899, 0.33299807, 0.32944095, 0.32591796, + 0.3224285, 0.3189719, 0.31554767, 0.31215525, 0.30879408, + 0.3054636, 0.3021634, 0.29889292, 0.2956517, 0.29243928, + 0.28925523, 0.28609908, 0.28297043, 0.27986884, 0.27679393, + 0.2737453, 0.2707226, 0.2677254, 0.26475343, 0.26180625, + 0.25888354, 0.25598502, 0.2531103, 0.25025907, 0.24743107, + 0.24462597, 0.24184346, 0.23908329, 0.23634516, 0.23362878, + 0.23093392, 0.2282603, 0.22560766, 0.22297576, 0.22036438, + 0.21777324, 0.21520215, 0.21265087, 0.21011916, 0.20760682, + 0.20511365, 0.20263945, 0.20018397, 0.19774707, 0.19532852, + 0.19292815, 0.19054577, 0.1881812, 0.18583426, 0.18350479, + 0.1811926, 0.17889754, 0.17661946, 0.17435817, 0.17211354, + 0.1698854, 0.16767362, 0.16547804, 0.16329853, 0.16113494, + 0.15898713, 0.15685499, 0.15473837, 0.15263714, 0.15055119, + 0.14848037, 0.14642459, 0.14438373, 0.14235765, 0.14034624, + 0.13834943, 0.13636707, 0.13439907, 0.13244532, 0.13050574, + 0.1285802, 0.12666863, 0.12477092, 0.12288698, 0.12101672, + 0.119160056, 0.1173169, 0.115487166, 0.11367077, 0.11186763, + 0.11007768, 0.10830083, 0.10653701, 0.10478614, 0.10304816, + 0.101323, 0.09961058, 0.09791085, 0.09622374, 0.09454919, + 0.09288713, 0.091237515, 0.08960028, 0.087975375, 0.08636274, + 0.08476233, 0.083174095, 0.081597984, 0.08003395, 0.07848195, + 0.076941945, 0.07541389, 0.07389775, 0.072393484, 0.07090106, + 0.069420435, 0.06795159, 0.066494495, 0.06504912, 0.063615434, + 0.062193416, 0.060783047, 0.059384305, 0.057997175, + 0.05662164, 0.05525769, 0.053905312, 0.052564494, 0.051235236, + 0.049917534, 0.048611384, 0.047316793, 0.046033762, 0.0447623, + 0.043502413, 0.042254124, 0.041017443, 0.039792392, + 0.038578995, 0.037377283, 0.036187284, 0.035009038, + 0.033842582, 0.032687962, 0.031545233, 0.030414443, 0.02929566, + 0.02818895, 0.027094385, 0.026012046, 0.024942026, 0.023884421, + 0.022839336, 0.021806888, 0.020787204, 0.019780423, 0.0187867, + 0.0178062, 0.016839107, 0.015885621, 0.014945968, 0.014020392, + 0.013109165, 0.012212592, 0.011331013, 0.01046481, 0.009614414, + 0.008780315, 0.007963077, 0.0071633533, 0.006381906, + 0.0056196423, 0.0048776558, 0.004157295, 0.0034602648, + 0.0027887989, 0.0021459677, 0.0015362998, 0.0009672693, + 0.00045413437, +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/rand/normal.go b/vendor/go.mongodb.org/mongo-driver/internal/rand/normal.go new file mode 100644 index 0000000..8c74a35 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/rand/normal.go @@ -0,0 +1,158 @@ +// Copied from https://cs.opensource.google/go/x/exp/+/24438e51023af3bfc1db8aed43c1342817e8cfcd:rand/normal.go + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rand + +import ( + "math" +) + +/* + * Normal distribution + * + * See "The Ziggurat Method for Generating Random Variables" + * (Marsaglia & Tsang, 2000) + * http://www.jstatsoft.org/v05/i08/paper [pdf] + */ + +const ( + rn = 3.442619855899 +) + +func absInt32(i int32) uint32 { + if i < 0 { + return uint32(-i) + } + return uint32(i) +} + +// NormFloat64 returns a normally distributed float64 in the range +// [-math.MaxFloat64, +math.MaxFloat64] with +// standard normal distribution (mean = 0, stddev = 1). +// To produce a different normal distribution, callers can +// adjust the output using: +// +// sample = NormFloat64() * desiredStdDev + desiredMean +func (r *Rand) NormFloat64() float64 { + for { + j := int32(r.Uint32()) // Possibly negative + i := j & 0x7F + x := float64(j) * float64(wn[i]) + if absInt32(j) < kn[i] { + // This case should be hit better than 99% of the time. + return x + } + + if i == 0 { + // This extra work is only required for the base strip. + for { + x = -math.Log(r.Float64()) * (1.0 / rn) + y := -math.Log(r.Float64()) + if y+y >= x*x { + break + } + } + if j > 0 { + return rn + x + } + return -rn - x + } + if fn[i]+float32(r.Float64())*(fn[i-1]-fn[i]) < float32(math.Exp(-.5*x*x)) { + return x + } + } +} + +var kn = [128]uint32{ + 0x76ad2212, 0x0, 0x600f1b53, 0x6ce447a6, 0x725b46a2, + 0x7560051d, 0x774921eb, 0x789a25bd, 0x799045c3, 0x7a4bce5d, + 0x7adf629f, 0x7b5682a6, 0x7bb8a8c6, 0x7c0ae722, 0x7c50cce7, + 0x7c8cec5b, 0x7cc12cd6, 0x7ceefed2, 0x7d177e0b, 0x7d3b8883, + 0x7d5bce6c, 0x7d78dd64, 0x7d932886, 0x7dab0e57, 0x7dc0dd30, + 0x7dd4d688, 0x7de73185, 0x7df81cea, 0x7e07c0a3, 0x7e163efa, + 0x7e23b587, 0x7e303dfd, 0x7e3beec2, 0x7e46db77, 0x7e51155d, + 0x7e5aabb3, 0x7e63abf7, 0x7e6c222c, 0x7e741906, 0x7e7b9a18, + 0x7e82adfa, 0x7e895c63, 0x7e8fac4b, 0x7e95a3fb, 0x7e9b4924, + 0x7ea0a0ef, 0x7ea5b00d, 0x7eaa7ac3, 0x7eaf04f3, 0x7eb3522a, + 0x7eb765a5, 0x7ebb4259, 0x7ebeeafd, 0x7ec2620a, 0x7ec5a9c4, + 0x7ec8c441, 0x7ecbb365, 0x7ece78ed, 0x7ed11671, 0x7ed38d62, + 0x7ed5df12, 0x7ed80cb4, 0x7eda175c, 0x7edc0005, 0x7eddc78e, + 0x7edf6ebf, 0x7ee0f647, 0x7ee25ebe, 0x7ee3a8a9, 0x7ee4d473, + 0x7ee5e276, 0x7ee6d2f5, 0x7ee7a620, 0x7ee85c10, 0x7ee8f4cd, + 0x7ee97047, 0x7ee9ce59, 0x7eea0eca, 0x7eea3147, 0x7eea3568, + 0x7eea1aab, 0x7ee9e071, 0x7ee98602, 0x7ee90a88, 0x7ee86d08, + 0x7ee7ac6a, 0x7ee6c769, 0x7ee5bc9c, 0x7ee48a67, 0x7ee32efc, + 0x7ee1a857, 0x7edff42f, 0x7ede0ffa, 0x7edbf8d9, 0x7ed9ab94, + 0x7ed7248d, 0x7ed45fae, 0x7ed1585c, 0x7ece095f, 0x7eca6ccb, + 0x7ec67be2, 0x7ec22eee, 0x7ebd7d1a, 0x7eb85c35, 0x7eb2c075, + 0x7eac9c20, 0x7ea5df27, 0x7e9e769f, 0x7e964c16, 0x7e8d44ba, + 0x7e834033, 0x7e781728, 0x7e6b9933, 0x7e5d8a1a, 0x7e4d9ded, + 0x7e3b737a, 0x7e268c2f, 0x7e0e3ff5, 0x7df1aa5d, 0x7dcf8c72, + 0x7da61a1e, 0x7d72a0fb, 0x7d30e097, 0x7cd9b4ab, 0x7c600f1a, + 0x7ba90bdc, 0x7a722176, 0x77d664e5, +} +var wn = [128]float32{ + 1.7290405e-09, 1.2680929e-10, 1.6897518e-10, 1.9862688e-10, + 2.2232431e-10, 2.4244937e-10, 2.601613e-10, 2.7611988e-10, + 2.9073963e-10, 3.042997e-10, 3.1699796e-10, 3.289802e-10, + 3.4035738e-10, 3.5121603e-10, 3.616251e-10, 3.7164058e-10, + 3.8130857e-10, 3.9066758e-10, 3.9975012e-10, 4.08584e-10, + 4.1719309e-10, 4.2559822e-10, 4.338176e-10, 4.418672e-10, + 4.497613e-10, 4.5751258e-10, 4.651324e-10, 4.7263105e-10, + 4.8001775e-10, 4.87301e-10, 4.944885e-10, 5.015873e-10, + 5.0860405e-10, 5.155446e-10, 5.2241467e-10, 5.2921934e-10, + 5.359635e-10, 5.426517e-10, 5.4928817e-10, 5.5587696e-10, + 5.624219e-10, 5.6892646e-10, 5.753941e-10, 5.818282e-10, + 5.882317e-10, 5.946077e-10, 6.00959e-10, 6.072884e-10, + 6.135985e-10, 6.19892e-10, 6.2617134e-10, 6.3243905e-10, + 6.386974e-10, 6.449488e-10, 6.511956e-10, 6.5744005e-10, + 6.6368433e-10, 6.699307e-10, 6.7618144e-10, 6.824387e-10, + 6.8870465e-10, 6.949815e-10, 7.012715e-10, 7.075768e-10, + 7.1389966e-10, 7.202424e-10, 7.266073e-10, 7.329966e-10, + 7.394128e-10, 7.4585826e-10, 7.5233547e-10, 7.58847e-10, + 7.653954e-10, 7.719835e-10, 7.7861395e-10, 7.852897e-10, + 7.920138e-10, 7.987892e-10, 8.0561924e-10, 8.125073e-10, + 8.194569e-10, 8.2647167e-10, 8.3355556e-10, 8.407127e-10, + 8.479473e-10, 8.55264e-10, 8.6266755e-10, 8.7016316e-10, + 8.777562e-10, 8.8545243e-10, 8.932582e-10, 9.0117996e-10, + 9.09225e-10, 9.174008e-10, 9.2571584e-10, 9.341788e-10, + 9.427997e-10, 9.515889e-10, 9.605579e-10, 9.697193e-10, + 9.790869e-10, 9.88676e-10, 9.985036e-10, 1.0085882e-09, + 1.0189509e-09, 1.0296151e-09, 1.0406069e-09, 1.0519566e-09, + 1.063698e-09, 1.0758702e-09, 1.0885183e-09, 1.1016947e-09, + 1.1154611e-09, 1.1298902e-09, 1.1450696e-09, 1.1611052e-09, + 1.1781276e-09, 1.1962995e-09, 1.2158287e-09, 1.2369856e-09, + 1.2601323e-09, 1.2857697e-09, 1.3146202e-09, 1.347784e-09, + 1.3870636e-09, 1.4357403e-09, 1.5008659e-09, 1.6030948e-09, +} +var fn = [128]float32{ + 1, 0.9635997, 0.9362827, 0.9130436, 0.89228165, 0.87324303, + 0.8555006, 0.8387836, 0.8229072, 0.8077383, 0.793177, + 0.7791461, 0.7655842, 0.7524416, 0.73967725, 0.7272569, + 0.7151515, 0.7033361, 0.69178915, 0.68049186, 0.6694277, + 0.658582, 0.6479418, 0.63749546, 0.6272325, 0.6171434, + 0.6072195, 0.5974532, 0.58783704, 0.5783647, 0.56903, + 0.5598274, 0.5507518, 0.54179835, 0.5329627, 0.52424055, + 0.5156282, 0.50712204, 0.49871865, 0.49041483, 0.48220766, + 0.4740943, 0.46607214, 0.4581387, 0.45029163, 0.44252872, + 0.43484783, 0.427247, 0.41972435, 0.41227803, 0.40490642, + 0.39760786, 0.3903808, 0.3832238, 0.37613547, 0.36911446, + 0.3621595, 0.35526937, 0.34844297, 0.34167916, 0.33497685, + 0.3283351, 0.3217529, 0.3152294, 0.30876362, 0.30235484, + 0.29600215, 0.28970486, 0.2834622, 0.2772735, 0.27113807, + 0.2650553, 0.25902456, 0.2530453, 0.24711695, 0.241239, + 0.23541094, 0.22963232, 0.2239027, 0.21822165, 0.21258877, + 0.20700371, 0.20146611, 0.19597565, 0.19053204, 0.18513499, + 0.17978427, 0.17447963, 0.1692209, 0.16400786, 0.15884037, + 0.15371831, 0.14864157, 0.14361008, 0.13862377, 0.13368265, + 0.12878671, 0.12393598, 0.119130544, 0.11437051, 0.10965602, + 0.104987256, 0.10036444, 0.095787846, 0.0912578, 0.08677467, + 0.0823389, 0.077950984, 0.073611505, 0.06932112, 0.06508058, + 0.06089077, 0.056752663, 0.0526674, 0.048636295, 0.044660863, + 0.040742867, 0.03688439, 0.033087887, 0.029356318, + 0.025693292, 0.022103304, 0.018592102, 0.015167298, + 0.011839478, 0.008624485, 0.005548995, 0.0026696292, +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/rand/rand.go b/vendor/go.mongodb.org/mongo-driver/internal/rand/rand.go new file mode 100644 index 0000000..4c3d3e6 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/rand/rand.go @@ -0,0 +1,374 @@ +// Copied from https://cs.opensource.google/go/x/exp/+/24438e51023af3bfc1db8aed43c1342817e8cfcd:rand/rand.go + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rand implements pseudo-random number generators. +// +// Random numbers are generated by a Source. Top-level functions, such as +// Float64 and Int, use a default shared Source that produces a deterministic +// sequence of values each time a program is run. Use the Seed function to +// initialize the default Source if different behavior is required for each run. +// The default Source, a LockedSource, is safe for concurrent use by multiple +// goroutines, but Sources created by NewSource are not. However, Sources are small +// and it is reasonable to have a separate Source for each goroutine, seeded +// differently, to avoid locking. +// +// For random numbers suitable for security-sensitive work, see the crypto/rand +// package. +package rand + +import "sync" + +// A Source represents a source of uniformly-distributed +// pseudo-random int64 values in the range [0, 1<<64). +type Source interface { + Uint64() uint64 + Seed(seed uint64) +} + +// NewSource returns a new pseudo-random Source seeded with the given value. +func NewSource(seed uint64) Source { + var rng PCGSource + rng.Seed(seed) + return &rng +} + +// A Rand is a source of random numbers. +type Rand struct { + src Source + + // readVal contains remainder of 64-bit integer used for bytes + // generation during most recent Read call. + // It is saved so next Read call can start where the previous + // one finished. + readVal uint64 + // readPos indicates the number of low-order bytes of readVal + // that are still valid. + readPos int8 +} + +// New returns a new Rand that uses random values from src +// to generate other random values. +func New(src Source) *Rand { + return &Rand{src: src} +} + +// Seed uses the provided seed value to initialize the generator to a deterministic state. +// Seed should not be called concurrently with any other Rand method. +func (r *Rand) Seed(seed uint64) { + if lk, ok := r.src.(*LockedSource); ok { + lk.seedPos(seed, &r.readPos) + return + } + + r.src.Seed(seed) + r.readPos = 0 +} + +// Uint64 returns a pseudo-random 64-bit integer as a uint64. +func (r *Rand) Uint64() uint64 { return r.src.Uint64() } + +// Int63 returns a non-negative pseudo-random 63-bit integer as an int64. +func (r *Rand) Int63() int64 { return int64(r.src.Uint64() &^ (1 << 63)) } + +// Uint32 returns a pseudo-random 32-bit value as a uint32. +func (r *Rand) Uint32() uint32 { return uint32(r.Uint64() >> 32) } + +// Int31 returns a non-negative pseudo-random 31-bit integer as an int32. +func (r *Rand) Int31() int32 { return int32(r.Uint64() >> 33) } + +// Int returns a non-negative pseudo-random int. +func (r *Rand) Int() int { + u := uint(r.Uint64()) + return int(u << 1 >> 1) // clear sign bit. +} + +const maxUint64 = (1 << 64) - 1 + +// Uint64n returns, as a uint64, a pseudo-random number in [0,n). +// It is guaranteed more uniform than taking a Source value mod n +// for any n that is not a power of 2. +func (r *Rand) Uint64n(n uint64) uint64 { + if n&(n-1) == 0 { // n is power of two, can mask + if n == 0 { + panic("invalid argument to Uint64n") + } + return r.Uint64() & (n - 1) + } + // If n does not divide v, to avoid bias we must not use + // a v that is within maxUint64%n of the top of the range. + v := r.Uint64() + if v > maxUint64-n { // Fast check. + ceiling := maxUint64 - maxUint64%n + for v >= ceiling { + v = r.Uint64() + } + } + + return v % n +} + +// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n). +// It panics if n <= 0. +func (r *Rand) Int63n(n int64) int64 { + if n <= 0 { + panic("invalid argument to Int63n") + } + return int64(r.Uint64n(uint64(n))) +} + +// Int31n returns, as an int32, a non-negative pseudo-random number in [0,n). +// It panics if n <= 0. +func (r *Rand) Int31n(n int32) int32 { + if n <= 0 { + panic("invalid argument to Int31n") + } + // TODO: Avoid some 64-bit ops to make it more efficient on 32-bit machines. + return int32(r.Uint64n(uint64(n))) +} + +// Intn returns, as an int, a non-negative pseudo-random number in [0,n). +// It panics if n <= 0. +func (r *Rand) Intn(n int) int { + if n <= 0 { + panic("invalid argument to Intn") + } + // TODO: Avoid some 64-bit ops to make it more efficient on 32-bit machines. + return int(r.Uint64n(uint64(n))) +} + +// Float64 returns, as a float64, a pseudo-random number in [0.0,1.0). +func (r *Rand) Float64() float64 { + // There is one bug in the value stream: r.Int63() may be so close + // to 1<<63 that the division rounds up to 1.0, and we've guaranteed + // that the result is always less than 1.0. + // + // We tried to fix this by mapping 1.0 back to 0.0, but since float64 + // values near 0 are much denser than near 1, mapping 1 to 0 caused + // a theoretically significant overshoot in the probability of returning 0. + // Instead of that, if we round up to 1, just try again. + // Getting 1 only happens 1/2⁵³ of the time, so most clients + // will not observe it anyway. +again: + f := float64(r.Uint64n(1<<53)) / (1 << 53) + if f == 1.0 { + goto again // resample; this branch is taken O(never) + } + return f +} + +// Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). +func (r *Rand) Float32() float32 { + // We do not want to return 1.0. + // This only happens 1/2²⁴ of the time (plus the 1/2⁵³ of the time in Float64). +again: + f := float32(r.Float64()) + if f == 1 { + goto again // resample; this branch is taken O(very rarely) + } + return f +} + +// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n). +func (r *Rand) Perm(n int) []int { + m := make([]int, n) + // In the following loop, the iteration when i=0 always swaps m[0] with m[0]. + // A change to remove this useless iteration is to assign 1 to i in the init + // statement. But Perm also effects r. Making this change will affect + // the final state of r. So this change can't be made for compatibility + // reasons for Go 1. + for i := 0; i < n; i++ { + j := r.Intn(i + 1) + m[i] = m[j] + m[j] = i + } + return m +} + +// Shuffle pseudo-randomizes the order of elements. +// n is the number of elements. Shuffle panics if n < 0. +// swap swaps the elements with indexes i and j. +func (r *Rand) Shuffle(n int, swap func(i, j int)) { + if n < 0 { + panic("invalid argument to Shuffle") + } + + // Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle + // Shuffle really ought not be called with n that doesn't fit in 32 bits. + // Not only will it take a very long time, but with 2³¹! possible permutations, + // there's no way that any PRNG can have a big enough internal state to + // generate even a minuscule percentage of the possible permutations. + // Nevertheless, the right API signature accepts an int n, so handle it as best we can. + i := n - 1 + for ; i > 1<<31-1-1; i-- { + j := int(r.Int63n(int64(i + 1))) + swap(i, j) + } + for ; i > 0; i-- { + j := int(r.Int31n(int32(i + 1))) + swap(i, j) + } +} + +// Read generates len(p) random bytes and writes them into p. It +// always returns len(p) and a nil error. +// Read should not be called concurrently with any other Rand method unless +// the underlying source is a LockedSource. +func (r *Rand) Read(p []byte) (n int, err error) { + if lk, ok := r.src.(*LockedSource); ok { + return lk.Read(p, &r.readVal, &r.readPos) + } + return read(p, r.src, &r.readVal, &r.readPos) +} + +func read(p []byte, src Source, readVal *uint64, readPos *int8) (n int, err error) { + pos := *readPos + val := *readVal + rng, _ := src.(*PCGSource) + for n = 0; n < len(p); n++ { + if pos == 0 { + if rng != nil { + val = rng.Uint64() + } else { + val = src.Uint64() + } + pos = 8 + } + p[n] = byte(val) + val >>= 8 + pos-- + } + *readPos = pos + *readVal = val + return +} + +/* + * Top-level convenience functions + */ + +var globalRand = New(&LockedSource{src: *NewSource(1).(*PCGSource)}) + +// Type assert that globalRand's source is a LockedSource whose src is a PCGSource. +var _ PCGSource = globalRand.src.(*LockedSource).src + +// Seed uses the provided seed value to initialize the default Source to a +// deterministic state. If Seed is not called, the generator behaves as +// if seeded by Seed(1). +// Seed, unlike the Rand.Seed method, is safe for concurrent use. +func Seed(seed uint64) { globalRand.Seed(seed) } + +// Int63 returns a non-negative pseudo-random 63-bit integer as an int64 +// from the default Source. +func Int63() int64 { return globalRand.Int63() } + +// Uint32 returns a pseudo-random 32-bit value as a uint32 +// from the default Source. +func Uint32() uint32 { return globalRand.Uint32() } + +// Uint64 returns a pseudo-random 64-bit value as a uint64 +// from the default Source. +func Uint64() uint64 { return globalRand.Uint64() } + +// Int31 returns a non-negative pseudo-random 31-bit integer as an int32 +// from the default Source. +func Int31() int32 { return globalRand.Int31() } + +// Int returns a non-negative pseudo-random int from the default Source. +func Int() int { return globalRand.Int() } + +// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n) +// from the default Source. +// It panics if n <= 0. +func Int63n(n int64) int64 { return globalRand.Int63n(n) } + +// Int31n returns, as an int32, a non-negative pseudo-random number in [0,n) +// from the default Source. +// It panics if n <= 0. +func Int31n(n int32) int32 { return globalRand.Int31n(n) } + +// Intn returns, as an int, a non-negative pseudo-random number in [0,n) +// from the default Source. +// It panics if n <= 0. +func Intn(n int) int { return globalRand.Intn(n) } + +// Float64 returns, as a float64, a pseudo-random number in [0.0,1.0) +// from the default Source. +func Float64() float64 { return globalRand.Float64() } + +// Float32 returns, as a float32, a pseudo-random number in [0.0,1.0) +// from the default Source. +func Float32() float32 { return globalRand.Float32() } + +// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n) +// from the default Source. +func Perm(n int) []int { return globalRand.Perm(n) } + +// Shuffle pseudo-randomizes the order of elements using the default Source. +// n is the number of elements. Shuffle panics if n < 0. +// swap swaps the elements with indexes i and j. +func Shuffle(n int, swap func(i, j int)) { globalRand.Shuffle(n, swap) } + +// Read generates len(p) random bytes from the default Source and +// writes them into p. It always returns len(p) and a nil error. +// Read, unlike the Rand.Read method, is safe for concurrent use. +func Read(p []byte) (n int, err error) { return globalRand.Read(p) } + +// NormFloat64 returns a normally distributed float64 in the range +// [-math.MaxFloat64, +math.MaxFloat64] with +// standard normal distribution (mean = 0, stddev = 1) +// from the default Source. +// To produce a different normal distribution, callers can +// adjust the output using: +// +// sample = NormFloat64() * desiredStdDev + desiredMean +func NormFloat64() float64 { return globalRand.NormFloat64() } + +// ExpFloat64 returns an exponentially distributed float64 in the range +// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter +// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source. +// To produce a distribution with a different rate parameter, +// callers can adjust the output using: +// +// sample = ExpFloat64() / desiredRateParameter +func ExpFloat64() float64 { return globalRand.ExpFloat64() } + +// LockedSource is an implementation of Source that is concurrency-safe. +// A Rand using a LockedSource is safe for concurrent use. +// +// The zero value of LockedSource is valid, but should be seeded before use. +type LockedSource struct { + lk sync.Mutex + src PCGSource +} + +func (s *LockedSource) Uint64() (n uint64) { + s.lk.Lock() + n = s.src.Uint64() + s.lk.Unlock() + return +} + +func (s *LockedSource) Seed(seed uint64) { + s.lk.Lock() + s.src.Seed(seed) + s.lk.Unlock() +} + +// seedPos implements Seed for a LockedSource without a race condition. +func (s *LockedSource) seedPos(seed uint64, readPos *int8) { + s.lk.Lock() + s.src.Seed(seed) + *readPos = 0 + s.lk.Unlock() +} + +// Read implements Read for a LockedSource. +func (s *LockedSource) Read(p []byte, readVal *uint64, readPos *int8) (n int, err error) { + s.lk.Lock() + n, err = read(p, &s.src, readVal, readPos) + s.lk.Unlock() + return +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/rand/rng.go b/vendor/go.mongodb.org/mongo-driver/internal/rand/rng.go new file mode 100644 index 0000000..f04f987 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/rand/rng.go @@ -0,0 +1,93 @@ +// Copied from https://cs.opensource.google/go/x/exp/+/24438e51023af3bfc1db8aed43c1342817e8cfcd:rand/rng.go + +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rand + +import ( + "encoding/binary" + "io" + "math/bits" +) + +// PCGSource is an implementation of a 64-bit permuted congruential +// generator as defined in +// +// PCG: A Family of Simple Fast Space-Efficient Statistically Good +// Algorithms for Random Number Generation +// Melissa E. O’Neill, Harvey Mudd College +// http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf +// +// The generator here is the congruential generator PCG XSL RR 128/64 (LCG) +// as found in the software available at http://www.pcg-random.org/. +// It has period 2^128 with 128 bits of state, producing 64-bit values. +// Is state is represented by two uint64 words. +type PCGSource struct { + low uint64 + high uint64 +} + +const ( + maxUint32 = (1 << 32) - 1 + + multiplier = 47026247687942121848144207491837523525 + mulHigh = multiplier >> 64 + mulLow = multiplier & maxUint64 + + increment = 117397592171526113268558934119004209487 + incHigh = increment >> 64 + incLow = increment & maxUint64 + + // TODO: Use these? + initializer = 245720598905631564143578724636268694099 + initHigh = initializer >> 64 + initLow = initializer & maxUint64 +) + +// Seed uses the provided seed value to initialize the generator to a deterministic state. +func (pcg *PCGSource) Seed(seed uint64) { + pcg.low = seed + pcg.high = seed // TODO: What is right? +} + +// Uint64 returns a pseudo-random 64-bit unsigned integer as a uint64. +func (pcg *PCGSource) Uint64() uint64 { + pcg.multiply() + pcg.add() + // XOR high and low 64 bits together and rotate right by high 6 bits of state. + return bits.RotateLeft64(pcg.high^pcg.low, -int(pcg.high>>58)) +} + +func (pcg *PCGSource) add() { + var carry uint64 + pcg.low, carry = Add64(pcg.low, incLow, 0) + pcg.high, _ = Add64(pcg.high, incHigh, carry) +} + +func (pcg *PCGSource) multiply() { + hi, lo := Mul64(pcg.low, mulLow) + hi += pcg.high * mulLow + hi += pcg.low * mulHigh + pcg.low = lo + pcg.high = hi +} + +// MarshalBinary returns the binary representation of the current state of the generator. +func (pcg *PCGSource) MarshalBinary() ([]byte, error) { + var buf [16]byte + binary.BigEndian.PutUint64(buf[:8], pcg.high) + binary.BigEndian.PutUint64(buf[8:], pcg.low) + return buf[:], nil +} + +// UnmarshalBinary sets the state of the generator to the state represented in data. +func (pcg *PCGSource) UnmarshalBinary(data []byte) error { + if len(data) < 16 { + return io.ErrUnexpectedEOF + } + pcg.low = binary.BigEndian.Uint64(data[8:]) + pcg.high = binary.BigEndian.Uint64(data[:8]) + return nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/randutil/randutil.go b/vendor/go.mongodb.org/mongo-driver/internal/randutil/randutil.go index d7b753b..dd8c6d6 100644 --- a/vendor/go.mongodb.org/mongo-driver/internal/randutil/randutil.go +++ b/vendor/go.mongodb.org/mongo-driver/internal/randutil/randutil.go @@ -11,67 +11,29 @@ import ( crand "crypto/rand" "fmt" "io" - "math/rand" - "sync" -) - -// A LockedRand wraps a "math/rand".Rand and is safe to use from multiple goroutines. -type LockedRand struct { - mu sync.Mutex - r *rand.Rand -} - -// NewLockedRand returns a new LockedRand that uses random values from src to generate other random -// values. It is safe to use from multiple goroutines. -func NewLockedRand(src rand.Source) *LockedRand { - return &LockedRand{ - // Ignore gosec warning "Use of weak random number generator (math/rand instead of - // crypto/rand)". We intentionally use a pseudo-random number generator. - /* #nosec G404 */ - r: rand.New(src), - } -} -// Read generates len(p) random bytes and writes them into p. It always returns len(p) and a nil -// error. -func (lr *LockedRand) Read(p []byte) (int, error) { - lr.mu.Lock() - n, err := lr.r.Read(p) - lr.mu.Unlock() - return n, err -} - -// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n). It -// panics if n <= 0. -func (lr *LockedRand) Intn(n int) int { - lr.mu.Lock() - x := lr.r.Intn(n) - lr.mu.Unlock() - return x -} + xrand "go.mongodb.org/mongo-driver/internal/rand" +) -// Shuffle pseudo-randomizes the order of elements. n is the number of elements. Shuffle panics if -// n < 0. swap swaps the elements with indexes i and j. -// -// Note that Shuffle locks the LockedRand, so shuffling large collections may adversely affect other -// concurrent calls. If many concurrent Shuffle and random value calls are required, consider using -// the global "math/rand".Shuffle instead because it uses much more granular locking. -func (lr *LockedRand) Shuffle(n int, swap func(i, j int)) { - lr.mu.Lock() - lr.r.Shuffle(n, swap) - lr.mu.Unlock() +// NewLockedRand returns a new "x/exp/rand" pseudo-random number generator seeded with a +// cryptographically-secure random number. +// It is safe to use from multiple goroutines. +func NewLockedRand() *xrand.Rand { + var randSrc = new(xrand.LockedSource) + randSrc.Seed(cryptoSeed()) + return xrand.New(randSrc) } -// CryptoSeed returns a random int64 read from the "crypto/rand" random number generator. It is +// cryptoSeed returns a random uint64 read from the "crypto/rand" random number generator. It is // intended to be used to seed pseudorandom number generators at package initialization. It panics // if it encounters any errors. -func CryptoSeed() int64 { +func cryptoSeed() uint64 { var b [8]byte _, err := io.ReadFull(crand.Reader, b[:]) if err != nil { panic(fmt.Errorf("failed to read 8 bytes from a \"crypto/rand\".Reader: %v", err)) } - return (int64(b[0]) << 0) | (int64(b[1]) << 8) | (int64(b[2]) << 16) | (int64(b[3]) << 24) | - (int64(b[4]) << 32) | (int64(b[5]) << 40) | (int64(b[6]) << 48) | (int64(b[7]) << 56) + return (uint64(b[0]) << 0) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | + (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56) } diff --git a/vendor/go.mongodb.org/mongo-driver/internal/uri_validation_errors.go b/vendor/go.mongodb.org/mongo-driver/internal/uri_validation_errors.go deleted file mode 100644 index 21e7300..0000000 --- a/vendor/go.mongodb.org/mongo-driver/internal/uri_validation_errors.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package internal - -import "errors" - -var ( - // ErrLoadBalancedWithMultipleHosts is returned when loadBalanced=true is specified in a URI with multiple hosts. - ErrLoadBalancedWithMultipleHosts = errors.New("loadBalanced cannot be set to true if multiple hosts are specified") - // ErrLoadBalancedWithReplicaSet is returned when loadBalanced=true is specified in a URI with the replicaSet option. - ErrLoadBalancedWithReplicaSet = errors.New("loadBalanced cannot be set to true if a replica set name is specified") - // ErrLoadBalancedWithDirectConnection is returned when loadBalanced=true is specified in a URI with the directConnection option. - ErrLoadBalancedWithDirectConnection = errors.New("loadBalanced cannot be set to true if the direct connection option is specified") - // ErrSRVMaxHostsWithReplicaSet is returned when srvMaxHosts > 0 is specified in a URI with the replicaSet option. - ErrSRVMaxHostsWithReplicaSet = errors.New("srvMaxHosts cannot be a positive value if a replica set name is specified") - // ErrSRVMaxHostsWithLoadBalanced is returned when srvMaxHosts > 0 is specified in a URI with loadBalanced=true. - ErrSRVMaxHostsWithLoadBalanced = errors.New("srvMaxHosts cannot be a positive value if loadBalanced is set to true") -) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/uuid/uuid.go b/vendor/go.mongodb.org/mongo-driver/internal/uuid/uuid.go similarity index 54% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/uuid/uuid.go rename to vendor/go.mongodb.org/mongo-driver/internal/uuid/uuid.go index 0978387..86c2a33 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/uuid/uuid.go +++ b/vendor/go.mongodb.org/mongo-driver/internal/uuid/uuid.go @@ -4,11 +4,11 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package uuid // import "go.mongodb.org/mongo-driver/x/mongo/driver/uuid" +package uuid import ( + "encoding/hex" "io" - "math/rand" "go.mongodb.org/mongo-driver/internal/randutil" ) @@ -16,47 +16,53 @@ import ( // UUID represents a UUID. type UUID [16]byte -// A source is a UUID generator that reads random values from a randutil.LockedRand. -// It is safe to use from multiple goroutines. +// A source is a UUID generator that reads random values from a io.Reader. +// It should be safe to use from multiple goroutines. type source struct { - random *randutil.LockedRand + random io.Reader } // new returns a random UUIDv4 with bytes read from the source's random number generator. func (s *source) new() (UUID, error) { - var uuid [16]byte - + var uuid UUID _, err := io.ReadFull(s.random, uuid[:]) if err != nil { - return [16]byte{}, err + return UUID{}, err } uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4 uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10 - return uuid, nil } -// newGlobalSource returns a source that uses a "math/rand" pseudo-random number generator seeded -// with a cryptographically-secure random number. It is intended to be used to initialize the -// package-global UUID generator. -func newGlobalSource() *source { +// newSource returns a source that uses a pseudo-random number generator in reandutil package. +// It is intended to be used to initialize the package-global UUID generator. +func newSource() *source { return &source{ - random: randutil.NewLockedRand(rand.NewSource(randutil.CryptoSeed())), + random: randutil.NewLockedRand(), } } // globalSource is a package-global pseudo-random UUID generator. -var globalSource = newGlobalSource() +var globalSource = newSource() -// New returns a random UUIDv4. It uses a "math/rand" pseudo-random number generator seeded with a -// cryptographically-secure random number at package initialization. +// New returns a random UUIDv4. It uses a global pseudo-random number generator in randutil +// at package initialization. // // New should not be used to generate cryptographically-secure random UUIDs. func New() (UUID, error) { return globalSource.new() } -// Equal returns true if two UUIDs are equal. -func Equal(a, b UUID) bool { - return a == b +func (uuid UUID) String() string { + var str [36]byte + hex.Encode(str[:], uuid[:4]) + str[8] = '-' + hex.Encode(str[9:13], uuid[4:6]) + str[13] = '-' + hex.Encode(str[14:18], uuid[6:8]) + str[18] = '-' + hex.Encode(str[19:23], uuid[8:10]) + str[23] = '-' + hex.Encode(str[24:], uuid[10:]) + return string(str[:]) } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/address/addr.go b/vendor/go.mongodb.org/mongo-driver/mongo/address/addr.go index 5655b34..fb6abbc 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/address/addr.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/address/addr.go @@ -4,6 +4,7 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// Package address provides structured representations of network addresses. package address // import "go.mongodb.org/mongo-driver/mongo/address" import ( diff --git a/vendor/go.mongodb.org/mongo-driver/internal/background_context.go b/vendor/go.mongodb.org/mongo-driver/mongo/background_context.go similarity index 87% rename from vendor/go.mongodb.org/mongo-driver/internal/background_context.go rename to vendor/go.mongodb.org/mongo-driver/mongo/background_context.go index 6f190ed..e4146e8 100644 --- a/vendor/go.mongodb.org/mongo-driver/internal/background_context.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/background_context.go @@ -4,7 +4,7 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package internal +package mongo import "context" @@ -16,9 +16,9 @@ type backgroundContext struct { childValuesCtx context.Context } -// NewBackgroundContext creates a new Context whose behavior matches that of context.Background(), but Value calls are +// newBackgroundContext creates a new Context whose behavior matches that of context.Background(), but Value calls are // forwarded to the provided ctx parameter. If ctx is nil, context.Background() is returned. -func NewBackgroundContext(ctx context.Context) context.Context { +func newBackgroundContext(ctx context.Context) context.Context { if ctx == nil { return context.Background() } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/batch_cursor.go b/vendor/go.mongodb.org/mongo-driver/mongo/batch_cursor.go index 0b7432f..51d59d0 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/batch_cursor.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/batch_cursor.go @@ -1,7 +1,14 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + package mongo import ( "context" + "time" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -28,6 +35,22 @@ type batchCursor interface { // Close closes the cursor. Close(context.Context) error + + // SetBatchSize is a modifier function used to adjust the batch size of + // the cursor that implements it. + SetBatchSize(int32) + + // SetMaxTime will set the maximum amount of time the server will allow + // the operations to execute. The server will error if this field is set + // but the cursor is not configured with awaitData=true. + // + // The time.Duration value passed by this setter will be converted and + // rounded down to the nearest millisecond. + SetMaxTime(time.Duration) + + // SetComment will set a user-configurable comment that can be used to + // identify the operation in server logs. + SetComment(interface{}) } // changeStreamCursor is the interface implemented by batch cursors that also provide the functionality for retrieving diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go index 0446e7f..81dfbb1 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go @@ -8,8 +8,10 @@ package mongo import ( "context" + "errors" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -25,8 +27,9 @@ type bulkWriteBatch struct { indexes []int } -// bulkWrite perfoms a bulkwrite operation +// bulkWrite performs a bulkwrite operation type bulkWrite struct { + comment interface{} ordered *bool bypassDocumentValidation *bool models []WriteModel @@ -36,6 +39,7 @@ type bulkWrite struct { writeConcern *writeconcern.WriteConcern result BulkWriteResult let interface{} + bypassEmptyTsReplacement *bool } func (bw *bulkWrite) execute(ctx context.Context) error { @@ -69,7 +73,7 @@ func (bw *bulkWrite) execute(ctx context.Context) error { bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...) - commandErrorOccurred := err != nil && err != driver.ErrUnacknowledgedWrite + commandErrorOccurred := err != nil && !errors.Is(err, driver.ErrUnacknowledgedWrite) writeErrorOccurred := len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil if !continueOnError && (commandErrorOccurred || writeErrorOccurred) { if err != nil { @@ -106,40 +110,40 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr case *InsertOneModel: res, err := bw.runInsert(ctx, batch) if err != nil { - writeErr, ok := err.(driver.WriteCommandError) - if !ok { + var writeErr driver.WriteCommandError + if !errors.As(err, &writeErr) { return BulkWriteResult{}, batchErr, err } writeErrors = writeErr.WriteErrors batchErr.Labels = writeErr.Labels batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) } - batchRes.InsertedCount = int64(res.N) + batchRes.InsertedCount = res.N case *DeleteOneModel, *DeleteManyModel: res, err := bw.runDelete(ctx, batch) if err != nil { - writeErr, ok := err.(driver.WriteCommandError) - if !ok { + var writeErr driver.WriteCommandError + if !errors.As(err, &writeErr) { return BulkWriteResult{}, batchErr, err } writeErrors = writeErr.WriteErrors batchErr.Labels = writeErr.Labels batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) } - batchRes.DeletedCount = int64(res.N) + batchRes.DeletedCount = res.N case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel: res, err := bw.runUpdate(ctx, batch) if err != nil { - writeErr, ok := err.(driver.WriteCommandError) - if !ok { + var writeErr driver.WriteCommandError + if !errors.As(err, &writeErr) { return BulkWriteResult{}, batchErr, err } writeErrors = writeErr.WriteErrors batchErr.Labels = writeErr.Labels batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError) } - batchRes.MatchedCount = int64(res.N) - batchRes.ModifiedCount = int64(res.NModified) + batchRes.MatchedCount = res.N + batchRes.ModifiedCount = res.NModified batchRes.UpsertedCount = int64(len(res.Upserted)) for _, upsert := range res.Upserted { batchRes.UpsertedIDs[int64(batch.indexes[upsert.Index])] = upsert.ID @@ -164,7 +168,11 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera var i int for _, model := range batch.models { converted := model.(*InsertOneModel) - doc, _, err := transformAndEnsureID(bw.collection.registry, converted.Document) + doc, err := marshal(converted.Document, bw.collection.bsonOpts, bw.collection.registry) + if err != nil { + return operation.InsertResult{}, err + } + doc, _, err = ensureID(doc, primitive.NilObjectID, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return operation.InsertResult{}, err } @@ -178,7 +186,15 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE). - ServerAPI(bw.collection.client.serverAPI) + ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) + if bw.comment != nil { + comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) + if err != nil { + return op.Result(), err + } + op.Comment(comment) + } if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation { op = op.BypassDocumentValidation(*bw.bypassDocumentValidation) } @@ -192,6 +208,10 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera } op = op.Retry(retry) + if bw.bypassEmptyTsReplacement != nil { + op.BypassEmptyTsReplacement(*bw.bypassEmptyTsReplacement) + } + err := op.Execute(ctx) return op.Result(), err @@ -208,10 +228,22 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera switch converted := model.(type) { case *DeleteOneModel: - doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, true, bw.collection.registry) + doc, err = createDeleteDoc( + converted.Filter, + converted.Collation, + converted.Hint, + true, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) case *DeleteManyModel: - doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, false, bw.collection.registry) + doc, err = createDeleteDoc( + converted.Filter, + converted.Collation, + converted.Hint, + false, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) } @@ -228,9 +260,17 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). - ServerAPI(bw.collection.client.serverAPI) + ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) + if bw.comment != nil { + comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) + if err != nil { + return op.Result(), err + } + op.Comment(comment) + } if bw.let != nil { - let, err := transformBsoncoreDocument(bw.collection.registry, bw.let, true, "let") + let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return operation.DeleteResult{}, err } @@ -250,10 +290,15 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera return op.Result(), err } -func createDeleteDoc(filter interface{}, collation *options.Collation, hint interface{}, deleteOne bool, - registry *bsoncodec.Registry) (bsoncore.Document, error) { - - f, err := transformBsoncoreDocument(registry, filter, true, "filter") +func createDeleteDoc( + filter interface{}, + collation *options.Collation, + hint interface{}, + deleteOne bool, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Document, error) { + f, err := marshal(filter, bsonOpts, registry) if err != nil { return nil, err } @@ -269,7 +314,10 @@ func createDeleteDoc(filter interface{}, collation *options.Collation, hint inte doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument()) } if hint != nil { - hintVal, err := transformValue(registry, hint, false, "hint") + if isUnorderedMap(hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(hint, bsonOpts, registry) if err != nil { return nil, err } @@ -290,17 +338,44 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera switch converted := model.(type) { case *ReplaceOneModel: - doc, err = createUpdateDoc(converted.Filter, converted.Replacement, converted.Hint, nil, converted.Collation, converted.Upsert, false, - false, bw.collection.registry) + doc, err = createUpdateDoc( + converted.Filter, + converted.Replacement, + converted.Hint, + nil, + converted.Collation, + converted.Upsert, + false, + false, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) case *UpdateOneModel: - doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, false, - true, bw.collection.registry) + doc, err = createUpdateDoc( + converted.Filter, + converted.Update, + converted.Hint, + converted.ArrayFilters, + converted.Collation, + converted.Upsert, + false, + true, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil) case *UpdateManyModel: - doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, true, - true, bw.collection.registry) + doc, err = createUpdateDoc( + converted.Filter, + converted.Update, + converted.Hint, + converted.ArrayFilters, + converted.Collation, + converted.Upsert, + true, + true, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil) } @@ -316,9 +391,18 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). - ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI) + ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI). + Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger). + Authenticator(bw.collection.client.authenticator) + if bw.comment != nil { + comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) + if err != nil { + return op.Result(), err + } + op.Comment(comment) + } if bw.let != nil { - let, err := transformBsoncoreDocument(bw.collection.registry, bw.let, true, "let") + let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return operation.UpdateResult{}, err } @@ -336,10 +420,15 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera } op = op.Retry(retry) + if bw.bypassEmptyTsReplacement != nil { + op.BypassEmptyTsReplacement(*bw.bypassEmptyTsReplacement) + } + err := op.Execute(ctx) return op.Result(), err } + func createUpdateDoc( filter interface{}, update interface{}, @@ -349,9 +438,10 @@ func createUpdateDoc( upsert *bool, multi bool, checkDollarKey bool, + bsonOpts *options.BSONOptions, registry *bsoncodec.Registry, ) (bsoncore.Document, error) { - f, err := transformBsoncoreDocument(registry, filter, true, "filter") + f, err := marshal(filter, bsonOpts, registry) if err != nil { return nil, err } @@ -359,7 +449,7 @@ func createUpdateDoc( uidx, updateDoc := bsoncore.AppendDocumentStart(nil) updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f) - u, err := transformUpdateValue(registry, update, checkDollarKey) + u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey) if err != nil { return nil, err } @@ -371,11 +461,15 @@ func createUpdateDoc( } if arrayFilters != nil { - arr, err := arrayFilters.ToArrayDocument() + reg := registry + if arrayFilters.Registry != nil { + reg = arrayFilters.Registry + } + arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg) if err != nil { return nil, err } - updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr) + updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr.Data) } if collation != nil { @@ -387,7 +481,10 @@ func createUpdateDoc( } if hint != nil { - hintVal, err := transformValue(registry, hint, false, "hint") + if isUnorderedMap(hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(hint, bsonOpts, registry) if err != nil { return nil, err } @@ -395,7 +492,6 @@ func createUpdateDoc( } updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx) - return updateDoc, nil } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go index b4b8e3e..64f4589 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go @@ -152,7 +152,7 @@ func (rom *ReplaceOneModel) SetFilter(filter interface{}) *ReplaceOneModel { } // SetReplacement specifies a document that will be used to replace the selected document. It cannot be nil and cannot -// contain any update operators (https://docs.mongodb.com/manual/reference/operator/update/). +// contain any update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). func (rom *ReplaceOneModel) SetReplacement(rep interface{}) *ReplaceOneModel { rom.Replacement = rep return rom @@ -210,7 +210,7 @@ func (uom *UpdateOneModel) SetFilter(filter interface{}) *UpdateOneModel { } // SetUpdate specifies the modifications to be made to the selected document. The value must be a document containing -// update operators (https://docs.mongodb.com/manual/reference/operator/update/). It cannot be nil or empty. +// update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). It cannot be nil or empty. func (uom *UpdateOneModel) SetUpdate(update interface{}) *UpdateOneModel { uom.Update = update return uom @@ -274,7 +274,7 @@ func (umm *UpdateManyModel) SetFilter(filter interface{}) *UpdateManyModel { } // SetUpdate specifies the modifications to be made to the selected documents. The value must be a document containing -// update operators (https://docs.mongodb.com/manual/reference/operator/update/). It cannot be nil or empty. +// update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). It cannot be nil or empty. func (umm *UpdateManyModel) SetUpdate(update interface{}) *UpdateManyModel { umm.Update = update return umm diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go index a76eb7c..3ea8baf 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go @@ -17,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -63,7 +64,7 @@ var ( // ChangeStream is used to iterate over a stream of events. Each event can be decoded into a Go type via the Decode // method or accessed as raw BSON via the Current field. This type is not goroutine safe and must not be used // concurrently by multiple goroutines. For more information about change streams, see -// https://docs.mongodb.com/manual/changeStreams/. +// https://www.mongodb.com/docs/manual/changeStreams/. type ChangeStream struct { // Current is the BSON bytes of the current event. This property is only valid until the next call to Next or // TryNext. If continued access is required, a copy must be made. @@ -79,6 +80,7 @@ type ChangeStream struct { err error sess *session.Client client *Client + bsonOpts *options.BSONOptions registry *bsoncodec.Registry streamType StreamType options *options.ChangeStreamOptions @@ -91,6 +93,7 @@ type changeStreamConfig struct { readConcern *readconcern.ReadConcern readPreference *readpref.ReadPref client *Client + bsonOpts *options.BSONOptions registry *bsoncodec.Registry streamType StreamType collectionName string @@ -104,8 +107,13 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in ctx = context.Background() } + cursorOpts := config.client.createBaseCursorOptions() + + cursorOpts.MarshalValueEncoderFn = newEncoderFn(config.bsonOpts, config.registry) + cs := &ChangeStream{ client: config.client, + bsonOpts: config.bsonOpts, registry: config.registry, streamType: config.streamType, options: options.MergeChangeStreamOptions(opts...), @@ -113,15 +121,12 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in description.ReadPrefSelector(config.readPreference), description.LatencySelector(config.client.localThreshold), }), - cursorOptions: config.client.createBaseCursorOptions(), + cursorOptions: cursorOpts, } cs.sess = sessionFromContext(ctx) if cs.sess == nil && cs.client.sessionPool != nil { - cs.sess, cs.err = session.NewClientSession(cs.client.sessionPool, cs.client.id, session.Implicit) - if cs.err != nil { - return nil, cs.Err() - } + cs.sess = session.NewImplicitClientSession(cs.client.sessionPool, cs.client.id) } if cs.err = cs.client.validSession(cs.sess); cs.err != nil { closeImplicitSession(cs.sess) @@ -132,11 +137,21 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in ReadPreference(config.readPreference).ReadConcern(config.readConcern). Deployment(cs.client.deployment).ClusterClock(cs.client.clock). CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone). - ServerAPI(cs.client.serverAPI).Crypt(config.crypt) + ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout). + Authenticator(cs.client.authenticator) if cs.options.Collation != nil { cs.aggregate.Collation(bsoncore.Document(cs.options.Collation.ToDocument())) } + if comment := cs.options.Comment; comment != nil { + cs.aggregate.Comment(*comment) + + commentVal, err := marshalValue(comment, cs.bsonOpts, cs.registry) + if err != nil { + return nil, err + } + cs.cursorOptions.Comment = commentVal + } if cs.options.BatchSize != nil { cs.aggregate.BatchSize(*cs.options.BatchSize) cs.cursorOptions.BatchSize = *cs.options.BatchSize @@ -230,7 +245,6 @@ func (cs *ChangeStream) createOperationDeployment(server driver.Server, connecti func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) error { var server driver.Server var conn driver.Connection - var err error if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil { return cs.Err() @@ -246,7 +260,10 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err if resuming { cs.replaceOptions(cs.wireVersion) - csOptDoc := cs.createPipelineOptionsDoc() + csOptDoc, err := cs.createPipelineOptionsDoc() + if err != nil { + return err + } pipIdx, pipDoc := bsoncore.AppendDocumentStart(nil) pipDoc = bsoncore.AppendDocumentElement(pipDoc, "$changeStream", csOptDoc) if pipDoc, cs.err = bsoncore.AppendDocumentEnd(pipDoc, pipIdx); cs.err != nil { @@ -261,48 +278,72 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err cs.aggregate.Pipeline(plArr) } - if original := cs.aggregate.Execute(ctx); original != nil { - retryableRead := cs.client.retryReads && cs.wireVersion != nil && cs.wireVersion.Max >= 6 - if !retryableRead { - cs.err = replaceErrors(original) - return cs.err + // If cs.client.timeout is set and context is not already a Timeout context, + // honor cs.client.timeout in new Timeout context for change stream + // operation execution and potential retry. + if cs.client.timeout != nil && !csot.IsTimeoutContext(ctx) { + newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *cs.client.timeout) + // Redefine ctx to be the new timeout-derived context. + ctx = newCtx + // Cancel the timeout-derived context at the end of executeOperation to avoid a context leak. + defer cancelFunc() + } + + // Execute the aggregate, retrying on retryable errors once (1) if retryable reads are enabled and + // infinitely (-1) if context is a Timeout context. + var retries int + if cs.client.retryReads { + retries = 1 + } + if csot.IsTimeoutContext(ctx) { + retries = -1 + } + + var err error +AggregateExecuteLoop: + for { + err = cs.aggregate.Execute(ctx) + // If no error or no retries remain, do not retry. + if err == nil || retries == 0 { + break AggregateExecuteLoop } - cs.err = original - switch tt := original.(type) { + switch tt := err.(type) { case driver.Error: + // If error is not retryable, do not retry. if !tt.RetryableRead() { - break + break AggregateExecuteLoop } + // If error is retryable: subtract 1 from retries, redo server selection, checkout + // a connection, and restart loop. + retries-- server, err = cs.client.deployment.SelectServer(ctx, cs.selector) if err != nil { - break + break AggregateExecuteLoop } conn.Close() conn, err = server.Connection(ctx) if err != nil { - break + break AggregateExecuteLoop } defer conn.Close() - cs.wireVersion = conn.Description().WireVersion - if cs.wireVersion == nil || cs.wireVersion.Max < 6 { - break - } + // Update the wire version with data from the new connection. + cs.wireVersion = conn.Description().WireVersion + // Reset deployment. cs.aggregate.Deployment(cs.createOperationDeployment(server, conn)) - cs.err = cs.aggregate.Execute(ctx) + default: + // Do not retry if error is not a driver error. + break AggregateExecuteLoop } - - if cs.err != nil { - cs.err = replaceErrors(cs.err) - return cs.Err() - } - } - cs.err = nil + if err != nil { + cs.err = replaceErrors(err) + return cs.err + } cr := cs.aggregate.ResultCursorResponse() cr.Server = server @@ -356,16 +397,17 @@ func (cs *ChangeStream) storeResumeToken() error { func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error { val := reflect.ValueOf(pipeline) if !val.IsValid() || !(val.Kind() == reflect.Slice) { - cs.err = errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid") + cs.err = errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") return cs.err } cs.pipelineSlice = make([]bsoncore.Document, 0, val.Len()+1) csIdx, csDoc := bsoncore.AppendDocumentStart(nil) - csDocTemp := cs.createPipelineOptionsDoc() - if cs.err != nil { - return cs.err + + csDocTemp, err := cs.createPipelineOptionsDoc() + if err != nil { + return err } csDoc = bsoncore.AppendDocumentElement(csDoc, "$changeStream", csDocTemp) csDoc, cs.err = bsoncore.AppendDocumentEnd(csDoc, csIdx) @@ -376,7 +418,7 @@ func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error { for i := 0; i < val.Len(); i++ { var elem []byte - elem, cs.err = transformBsoncoreDocument(cs.registry, val.Index(i).Interface(), true, fmt.Sprintf("pipeline stage :%v", i)) + elem, cs.err = marshal(val.Index(i).Interface(), cs.bsonOpts, cs.registry) if cs.err != nil { return cs.err } @@ -387,32 +429,40 @@ func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error { return cs.err } -func (cs *ChangeStream) createPipelineOptionsDoc() bsoncore.Document { +func (cs *ChangeStream) createPipelineOptionsDoc() (bsoncore.Document, error) { plDocIdx, plDoc := bsoncore.AppendDocumentStart(nil) if cs.streamType == ClientStream { plDoc = bsoncore.AppendBooleanElement(plDoc, "allChangesForCluster", true) } - if cs.options.FullDocument != nil { + if cs.options.FullDocument != nil && *cs.options.FullDocument != options.Default { plDoc = bsoncore.AppendStringElement(plDoc, "fullDocument", string(*cs.options.FullDocument)) } + if cs.options.FullDocumentBeforeChange != nil { + plDoc = bsoncore.AppendStringElement(plDoc, "fullDocumentBeforeChange", string(*cs.options.FullDocumentBeforeChange)) + } + if cs.options.ResumeAfter != nil { var raDoc bsoncore.Document - raDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.ResumeAfter, true, "resumeAfter") + raDoc, cs.err = marshal(cs.options.ResumeAfter, cs.bsonOpts, cs.registry) if cs.err != nil { - return nil + return nil, cs.err } plDoc = bsoncore.AppendDocumentElement(plDoc, "resumeAfter", raDoc) } + if cs.options.ShowExpandedEvents != nil { + plDoc = bsoncore.AppendBooleanElement(plDoc, "showExpandedEvents", *cs.options.ShowExpandedEvents) + } + if cs.options.StartAfter != nil { var saDoc bsoncore.Document - saDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.StartAfter, true, "startAfter") + saDoc, cs.err = marshal(cs.options.StartAfter, cs.bsonOpts, cs.registry) if cs.err != nil { - return nil + return nil, cs.err } plDoc = bsoncore.AppendDocumentElement(plDoc, "startAfter", saDoc) @@ -428,10 +478,10 @@ func (cs *ChangeStream) createPipelineOptionsDoc() bsoncore.Document { } if plDoc, cs.err = bsoncore.AppendDocumentEnd(plDoc, plDocIdx); cs.err != nil { - return nil + return nil, cs.err } - return plDoc + return plDoc, nil } func (cs *ChangeStream) pipelineToBSON() (bsoncore.Document, error) { @@ -482,6 +532,22 @@ func (cs *ChangeStream) ID() int64 { return cs.cursor.ID() } +// RemainingBatchLength returns the number of documents left in the current batch. If this returns zero, the subsequent +// call to Next or TryNext will do a network request to fetch the next batch. +func (cs *ChangeStream) RemainingBatchLength() int { + return len(cs.batch) +} + +// SetBatchSize sets the number of documents to fetch from the database with +// each iteration of the ChangeStream's "Next" or "TryNext" method. This setting +// only affects subsequent document batches fetched from the database. +func (cs *ChangeStream) SetBatchSize(size int32) { + // Set batch size on the cursor options also so any "resumed" change stream + // cursors will pick up the latest batch size setting. + cs.cursorOptions.BatchSize = size + cs.cursor.SetBatchSize(size) +} + // Decode will unmarshal the current event document into val and return any errors from the unmarshalling process // without any modification. If val is nil or is a typed nil, an error will be returned. func (cs *ChangeStream) Decode(val interface{}) error { @@ -489,7 +555,11 @@ func (cs *ChangeStream) Decode(val interface{}) error { return ErrNilCursor } - return bson.UnmarshalWithRegistry(cs.registry, cs.Current, val) + dec, err := getDecoder(cs.Current, cs.bsonOpts, cs.registry) + if err != nil { + return fmt.Errorf("error configuring BSON decoder: %w", err) + } + return dec.Decode(val) } // Err returns the last error seen by the change stream, or nil if no errors has occurred. @@ -626,8 +696,8 @@ func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) { } func (cs *ChangeStream) isResumableError() bool { - commandErr, ok := cs.err.(CommandError) - if !ok || commandErr.HasErrorLabel(networkErrorLabel) { + var commandErr CommandError + if !errors.As(cs.err, &commandErr) || commandErr.HasErrorLabel(networkErrorLabel) { // All non-server errors or network errors are resumable. return true } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go index 36c6e25..4dca59f 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go @@ -8,7 +8,6 @@ package mongo import ( "context" - "time" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -36,8 +35,8 @@ func (c *changeStreamDeployment) Connection(context.Context) (driver.Connection, return c.conn, nil } -func (c *changeStreamDeployment) MinRTT() time.Duration { - return c.server.MinRTT() +func (c *changeStreamDeployment) RTTMonitor() driver.RTTMonitor { + return c.server.RTTMonitor() } func (c *changeStreamDeployment) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult { diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/client.go b/vendor/go.mongodb.org/mongo-driver/mongo/client.go index ddc08bd..232d0a3 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/client.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/client.go @@ -8,15 +8,17 @@ package mongo import ( "context" - "crypto/tls" "errors" "fmt" - "strings" + "net/http" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/internal/httputil" + "go.mongodb.org/mongo-driver/internal/logger" + "go.mongodb.org/mongo-driver/internal/uuid" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -25,14 +27,17 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" - "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" + "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" + mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" - "go.mongodb.org/mongo-driver/x/mongo/driver/uuid" ) -const defaultLocalThreshold = 15 * time.Millisecond +const ( + defaultLocalThreshold = 15 * time.Millisecond + defaultMaxPoolSize = 100 +) var ( // keyVaultCollOpts specifies options used to communicate with the key vault collection @@ -48,29 +53,34 @@ var ( // The Client type opens and closes connections automatically and maintains a pool of idle connections. For // connection pool configuration options, see documentation for the ClientOptions type in the mongo/options package. type Client struct { - id uuid.UUID - topologyOptions []topology.Option - deployment driver.Deployment - localThreshold time.Duration - retryWrites bool - retryReads bool - clock *session.ClusterClock - readPreference *readpref.ReadPref - readConcern *readconcern.ReadConcern - writeConcern *writeconcern.WriteConcern - registry *bsoncodec.Registry - monitor *event.CommandMonitor - serverAPI *driver.ServerAPIOptions - serverMonitor *event.ServerMonitor - sessionPool *session.Pool + id uuid.UUID + deployment driver.Deployment + localThreshold time.Duration + retryWrites bool + retryReads bool + clock *session.ClusterClock + readPreference *readpref.ReadPref + readConcern *readconcern.ReadConcern + writeConcern *writeconcern.WriteConcern + bsonOpts *options.BSONOptions + registry *bsoncodec.Registry + monitor *event.CommandMonitor + serverAPI *driver.ServerAPIOptions + serverMonitor *event.ServerMonitor + sessionPool *session.Pool + timeout *time.Duration + httpClient *http.Client + logger *logger.Logger // client-side encryption fields - keyVaultClientFLE *Client - keyVaultCollFLE *Collection - mongocryptdFLE *mcryptClient - cryptFLE driver.Crypt - metadataClientFLE *Client - internalClientFLE *Client + keyVaultClientFLE *Client + keyVaultCollFLE *Collection + mongocryptdFLE *mongocryptdClient + cryptFLE driver.Crypt + metadataClientFLE *Client + internalClientFLE *Client + encryptedFieldsMap map[string]interface{} + authenticator driver.Authenticator } // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling @@ -120,6 +130,8 @@ func Connect(ctx context.Context, opts ...*options.ClientOptions) (*Client, erro // option fields of previous options, there is no partial overwriting. For example, if Username is // set in the Auth field for the first option, and Password is set for the second but with no // Username, after the merge the Username field will be empty. +// +// Deprecated: Use [Connect] instead. func NewClient(opts ...*options.ClientOptions) (*Client, error) { clientOpt := options.MergeClientOptions(opts...) @@ -129,17 +141,107 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { } client := &Client{id: id} - err = client.configure(clientOpt) + // ClusterClock + client.clock = new(session.ClusterClock) + + // LocalThreshold + client.localThreshold = defaultLocalThreshold + if clientOpt.LocalThreshold != nil { + client.localThreshold = *clientOpt.LocalThreshold + } + // Monitor + if clientOpt.Monitor != nil { + client.monitor = clientOpt.Monitor + } + // ServerMonitor + if clientOpt.ServerMonitor != nil { + client.serverMonitor = clientOpt.ServerMonitor + } + // ReadConcern + client.readConcern = readconcern.New() + if clientOpt.ReadConcern != nil { + client.readConcern = clientOpt.ReadConcern + } + // ReadPreference + client.readPreference = readpref.Primary() + if clientOpt.ReadPreference != nil { + client.readPreference = clientOpt.ReadPreference + } + // BSONOptions + if clientOpt.BSONOptions != nil { + client.bsonOpts = clientOpt.BSONOptions + } + // Registry + client.registry = bson.DefaultRegistry + if clientOpt.Registry != nil { + client.registry = clientOpt.Registry + } + // RetryWrites + client.retryWrites = true // retry writes on by default + if clientOpt.RetryWrites != nil { + client.retryWrites = *clientOpt.RetryWrites + } + client.retryReads = true + if clientOpt.RetryReads != nil { + client.retryReads = *clientOpt.RetryReads + } + // Timeout + client.timeout = clientOpt.Timeout + client.httpClient = clientOpt.HTTPClient + // WriteConcern + if clientOpt.WriteConcern != nil { + client.writeConcern = clientOpt.WriteConcern + } + // AutoEncryptionOptions + if clientOpt.AutoEncryptionOptions != nil { + if err := client.configureAutoEncryption(clientOpt); err != nil { + return nil, err + } + } else { + client.cryptFLE = clientOpt.Crypt + } + + // Deployment + if clientOpt.Deployment != nil { + client.deployment = clientOpt.Deployment + } + + // Set default options + if clientOpt.MaxPoolSize == nil { + clientOpt.SetMaxPoolSize(defaultMaxPoolSize) + } + + if clientOpt.Auth != nil { + client.authenticator, err = auth.CreateAuthenticator( + clientOpt.Auth.AuthMechanism, + topology.ConvertCreds(clientOpt.Auth), + clientOpt.HTTPClient, + ) + if err != nil { + return nil, fmt.Errorf("error creating authenticator: %w", err) + } + } + + cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator) if err != nil { return nil, err } + client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts) + if client.deployment == nil { - client.deployment, err = topology.New(client.topologyOptions...) + client.deployment, err = topology.New(cfg) if err != nil { return nil, replaceErrors(err) } } + + // Create a logger for the client. + client.logger, err = newLogger(clientOpt.LoggerOptions) + if err != nil { + return nil, fmt.Errorf("invalid logger options: %w", err) + } + return client, nil } @@ -148,6 +250,8 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { // // Connect starts background goroutines to monitor the state of the deployment and does not do any I/O in the main // goroutine. The Client.Ping method can be used to verify that the connection was created successfully. +// +// Deprecated: Use [mongo.Connect] instead. func (c *Client) Connect(ctx context.Context) error { if connector, ok := c.deployment.(driver.Connector); ok { err := connector.Connect() @@ -201,10 +305,18 @@ func (c *Client) Connect(ctx context.Context) error { // or write operations. If this method returns with no errors, all connections // associated with this Client have been closed. func (c *Client) Disconnect(ctx context.Context) error { + if c.logger != nil { + defer c.logger.Close() + } + if ctx == nil { ctx = context.Background() } + if c.httpClient == httputil.DefaultHTTPClient { + defer httputil.CloseIdleHTTPConnections(c.httpClient) + } + c.endSessions(ctx) if c.mongocryptdFLE != nil { if err := c.mongocryptdFLE.disconnect(ctx); err != nil { @@ -235,6 +347,7 @@ func (c *Client) Disconnect(ctx context.Context) error { if disconnector, ok := c.deployment.(driver.Disconnector); ok { return replaceErrors(disconnector.Disconnect(ctx)) } + return nil } @@ -271,6 +384,9 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { // StartSession does not actually communicate with the server and will not error if the client is // disconnected. // +// StartSession is safe to call from multiple goroutines concurrently. However, Sessions returned by StartSession are +// not safe for concurrent use by multiple goroutines. +// // If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, the client's read // concern, write concern, or read preference will be used, respectively. func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) { @@ -303,7 +419,7 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) coreOpts.Snapshot = sopts.Snapshot } - sess, err := session.NewClientSession(c.sessionPool, c.id, session.Explicit, coreOpts) + sess, err := session.NewClientSession(c.sessionPool, c.id, coreOpts) if err != nil { return nil, replaceErrors(err) } @@ -347,370 +463,33 @@ func (c *Client) endSessions(ctx context.Context) { } } -func (c *Client) configure(opts *options.ClientOptions) error { - if err := opts.Validate(); err != nil { +func (c *Client) configureAutoEncryption(clientOpts *options.ClientOptions) error { + c.encryptedFieldsMap = clientOpts.AutoEncryptionOptions.EncryptedFieldsMap + if err := c.configureKeyVaultClientFLE(clientOpts); err != nil { return err } - - var connOpts []topology.ConnectionOption - var serverOpts []topology.ServerOption - var topologyOpts []topology.Option - - // TODO(GODRIVER-814): Add tests for topology, server, and connection related options. - - // ServerAPIOptions need to be handled early as other client and server options below reference - // c.serverAPI and serverOpts.serverAPI. - if opts.ServerAPIOptions != nil { - // convert passed in options to driver form for client. - c.serverAPI = convertToDriverAPIOptions(opts.ServerAPIOptions) - - serverOpts = append(serverOpts, topology.WithServerAPI(func(*driver.ServerAPIOptions) *driver.ServerAPIOptions { - return c.serverAPI - })) - } - - // ClusterClock - c.clock = new(session.ClusterClock) - - // Pass down URI, SRV service name, and SRV max hosts so topology can poll SRV records correctly. - topologyOpts = append(topologyOpts, - topology.WithURI(func(uri string) string { return opts.GetURI() }), - topology.WithSRVServiceName(func(srvName string) string { - if opts.SRVServiceName != nil { - return *opts.SRVServiceName - } - return "" - }), - topology.WithSRVMaxHosts(func(srvMaxHosts int) int { - if opts.SRVMaxHosts != nil { - return *opts.SRVMaxHosts - } - return 0 - }), - ) - - // AppName - var appName string - if opts.AppName != nil { - appName = *opts.AppName - - serverOpts = append(serverOpts, topology.WithServerAppName(func(string) string { - return appName - })) - } - // Compressors & ZlibLevel - var comps []string - if len(opts.Compressors) > 0 { - comps = opts.Compressors - - connOpts = append(connOpts, topology.WithCompressors( - func(compressors []string) []string { - return append(compressors, comps...) - }, - )) - - for _, comp := range comps { - switch comp { - case "zlib": - connOpts = append(connOpts, topology.WithZlibLevel(func(level *int) *int { - return opts.ZlibLevel - })) - case "zstd": - connOpts = append(connOpts, topology.WithZstdLevel(func(level *int) *int { - return opts.ZstdLevel - })) - } - } - - serverOpts = append(serverOpts, topology.WithCompressionOptions( - func(opts ...string) []string { return append(opts, comps...) }, - )) + if err := c.configureMetadataClientFLE(clientOpts); err != nil { + return err } - var loadBalanced bool - if opts.LoadBalanced != nil { - loadBalanced = *opts.LoadBalanced + mc, err := c.newMongoCrypt(clientOpts.AutoEncryptionOptions) + if err != nil { + return err } - // Handshaker - var handshaker = func(driver.Handshaker) driver.Handshaker { - return operation.NewHello().AppName(appName).Compressors(comps).ClusterClock(c.clock). - ServerAPI(c.serverAPI).LoadBalanced(loadBalanced) - } - // Auth & Database & Password & Username - if opts.Auth != nil { - cred := &auth.Cred{ - Username: opts.Auth.Username, - Password: opts.Auth.Password, - PasswordSet: opts.Auth.PasswordSet, - Props: opts.Auth.AuthMechanismProperties, - Source: opts.Auth.AuthSource, - } - mechanism := opts.Auth.AuthMechanism - - if len(cred.Source) == 0 { - switch strings.ToUpper(mechanism) { - case auth.MongoDBX509, auth.GSSAPI, auth.PLAIN: - cred.Source = "$external" - default: - cred.Source = "admin" - } - } - - authenticator, err := auth.CreateAuthenticator(mechanism, cred) + // If the crypt_shared library was not loaded, try to spawn and connect to mongocryptd. + if mc.CryptSharedLibVersionString() == "" { + mongocryptdFLE, err := newMongocryptdClient(clientOpts.AutoEncryptionOptions) if err != nil { return err } - - handshakeOpts := &auth.HandshakeOptions{ - AppName: appName, - Authenticator: authenticator, - Compressors: comps, - ClusterClock: c.clock, - ServerAPI: c.serverAPI, - LoadBalanced: loadBalanced, - } - if mechanism == "" { - // Required for SASL mechanism negotiation during handshake - handshakeOpts.DBUser = cred.Source + "." + cred.Username - } - if opts.AuthenticateToAnything != nil && *opts.AuthenticateToAnything { - // Authenticate arbiters - handshakeOpts.PerformAuthentication = func(serv description.Server) bool { - return true - } - } - - handshaker = func(driver.Handshaker) driver.Handshaker { - return auth.Handshaker(nil, handshakeOpts) - } - } - connOpts = append(connOpts, topology.WithHandshaker(handshaker)) - // ConnectTimeout - if opts.ConnectTimeout != nil { - serverOpts = append(serverOpts, topology.WithHeartbeatTimeout( - func(time.Duration) time.Duration { return *opts.ConnectTimeout }, - )) - connOpts = append(connOpts, topology.WithConnectTimeout( - func(time.Duration) time.Duration { return *opts.ConnectTimeout }, - )) - } - // Dialer - if opts.Dialer != nil { - connOpts = append(connOpts, topology.WithDialer( - func(topology.Dialer) topology.Dialer { return opts.Dialer }, - )) - } - // Direct - if opts.Direct != nil && *opts.Direct { - topologyOpts = append(topologyOpts, topology.WithMode( - func(topology.MonitorMode) topology.MonitorMode { return topology.SingleMode }, - )) - } - // HeartbeatInterval - if opts.HeartbeatInterval != nil { - serverOpts = append(serverOpts, topology.WithHeartbeatInterval( - func(time.Duration) time.Duration { return *opts.HeartbeatInterval }, - )) - } - // Hosts - hosts := []string{"localhost:27017"} // default host - if len(opts.Hosts) > 0 { - hosts = opts.Hosts - } - topologyOpts = append(topologyOpts, topology.WithSeedList( - func(...string) []string { return hosts }, - )) - // LocalThreshold - c.localThreshold = defaultLocalThreshold - if opts.LocalThreshold != nil { - c.localThreshold = *opts.LocalThreshold - } - // MaxConIdleTime - if opts.MaxConnIdleTime != nil { - connOpts = append(connOpts, topology.WithIdleTimeout( - func(time.Duration) time.Duration { return *opts.MaxConnIdleTime }, - )) - } - // MaxPoolSize - if opts.MaxPoolSize != nil { - serverOpts = append( - serverOpts, - topology.WithMaxConnections(func(uint64) uint64 { return *opts.MaxPoolSize }), - ) - } - // MinPoolSize - if opts.MinPoolSize != nil { - serverOpts = append( - serverOpts, - topology.WithMinConnections(func(uint64) uint64 { return *opts.MinPoolSize }), - ) - } - // MaxConnecting - if opts.MaxConnecting != nil { - serverOpts = append( - serverOpts, - topology.WithMaxConnecting(func(uint64) uint64 { return *opts.MaxConnecting }), - ) - } - // PoolMonitor - if opts.PoolMonitor != nil { - serverOpts = append( - serverOpts, - topology.WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return opts.PoolMonitor }), - ) - } - // Monitor - if opts.Monitor != nil { - c.monitor = opts.Monitor - connOpts = append(connOpts, topology.WithMonitor( - func(*event.CommandMonitor) *event.CommandMonitor { return opts.Monitor }, - )) - } - // ServerMonitor - if opts.ServerMonitor != nil { - c.serverMonitor = opts.ServerMonitor - serverOpts = append( - serverOpts, - topology.WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return opts.ServerMonitor }), - ) - - topologyOpts = append( - topologyOpts, - topology.WithTopologyServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return opts.ServerMonitor }), - ) - } - // ReadConcern - c.readConcern = readconcern.New() - if opts.ReadConcern != nil { - c.readConcern = opts.ReadConcern - } - // ReadPreference - c.readPreference = readpref.Primary() - if opts.ReadPreference != nil { - c.readPreference = opts.ReadPreference - } - // Registry - c.registry = bson.DefaultRegistry - if opts.Registry != nil { - c.registry = opts.Registry - } - // ReplicaSet - if opts.ReplicaSet != nil { - topologyOpts = append(topologyOpts, topology.WithReplicaSetName( - func(string) string { return *opts.ReplicaSet }, - )) - } - // RetryWrites - c.retryWrites = true // retry writes on by default - if opts.RetryWrites != nil { - c.retryWrites = *opts.RetryWrites - } - c.retryReads = true - if opts.RetryReads != nil { - c.retryReads = *opts.RetryReads - } - // ServerSelectionTimeout - if opts.ServerSelectionTimeout != nil { - topologyOpts = append(topologyOpts, topology.WithServerSelectionTimeout( - func(time.Duration) time.Duration { return *opts.ServerSelectionTimeout }, - )) - } - // SocketTimeout - if opts.SocketTimeout != nil { - connOpts = append( - connOpts, - topology.WithReadTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), - topology.WithWriteTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), - ) - } - // TLSConfig - if opts.TLSConfig != nil { - connOpts = append(connOpts, topology.WithTLSConfig( - func(*tls.Config) *tls.Config { - return opts.TLSConfig - }, - )) - } - // WriteConcern - if opts.WriteConcern != nil { - c.writeConcern = opts.WriteConcern - } - // AutoEncryptionOptions - if opts.AutoEncryptionOptions != nil { - if err := c.configureAutoEncryption(opts); err != nil { - return err - } - } else { - c.cryptFLE = opts.Crypt - } - - // OCSP cache - ocspCache := ocsp.NewCache() - connOpts = append( - connOpts, - topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache { return ocspCache }), - ) - - // Disable communication with external OCSP responders. - if opts.DisableOCSPEndpointCheck != nil { - connOpts = append( - connOpts, - topology.WithDisableOCSPEndpointCheck(func(bool) bool { return *opts.DisableOCSPEndpointCheck }), - ) - } - - // LoadBalanced - if opts.LoadBalanced != nil { - topologyOpts = append( - topologyOpts, - topology.WithLoadBalanced(func(bool) bool { return *opts.LoadBalanced }), - ) - serverOpts = append( - serverOpts, - topology.WithServerLoadBalanced(func(bool) bool { return *opts.LoadBalanced }), - ) - connOpts = append( - connOpts, - topology.WithConnectionLoadBalanced(func(bool) bool { return *opts.LoadBalanced }), - ) - } - - serverOpts = append( - serverOpts, - topology.WithClock(func(*session.ClusterClock) *session.ClusterClock { return c.clock }), - topology.WithConnectionOptions(func(...topology.ConnectionOption) []topology.ConnectionOption { return connOpts }), - ) - c.topologyOptions = append(topologyOpts, topology.WithServerOptions( - func(...topology.ServerOption) []topology.ServerOption { return serverOpts }, - )) - - // Deployment - if opts.Deployment != nil { - // topology options: WithSeedlist, WithURI, WithSRVServiceName and WithSRVMaxHosts - // server options: WithClock and WithConnectionOptions - if len(serverOpts) > 2 || len(topologyOpts) > 4 { - return errors.New("cannot specify topology or server options with a deployment") - } - c.deployment = opts.Deployment + c.mongocryptdFLE = mongocryptdFLE } + c.configureCryptFLE(mc, clientOpts.AutoEncryptionOptions) return nil } -func (c *Client) configureAutoEncryption(clientOpts *options.ClientOptions) error { - if err := c.configureKeyVaultClientFLE(clientOpts); err != nil { - return err - } - if err := c.configureMetadataClientFLE(clientOpts); err != nil { - return err - } - if err := c.configureMongocryptdClientFLE(clientOpts.AutoEncryptionOptions); err != nil { - return err - } - return c.configureCryptFLE(clientOpts.AutoEncryptionOptions) -} - func (c *Client) getOrCreateInternalClient(clientOpts *options.ClientOptions) (*Client, error) { if c.internalClientFLE != nil { return c.internalClientFLE, nil @@ -763,32 +542,91 @@ func (c *Client) configureMetadataClientFLE(clientOpts *options.ClientOptions) e return err } -func (c *Client) configureMongocryptdClientFLE(opts *options.AutoEncryptionOptions) error { - var err error - c.mongocryptdFLE, err = newMcryptClient(opts) - return err -} - -func (c *Client) configureCryptFLE(opts *options.AutoEncryptionOptions) error { +func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt.MongoCrypt, error) { // convert schemas in SchemaMap to bsoncore documents cryptSchemaMap := make(map[string]bsoncore.Document) for k, v := range opts.SchemaMap { - schema, err := transformBsoncoreDocument(c.registry, v, true, "schemaMap") + schema, err := marshal(v, c.bsonOpts, c.registry) if err != nil { - return err + return nil, err } cryptSchemaMap[k] = schema } - kmsProviders, err := transformBsoncoreDocument(c.registry, opts.KmsProviders, true, "kmsProviders") + + // convert schemas in EncryptedFieldsMap to bsoncore documents + cryptEncryptedFieldsMap := make(map[string]bsoncore.Document) + for k, v := range opts.EncryptedFieldsMap { + encryptedFields, err := marshal(v, c.bsonOpts, c.registry) + if err != nil { + return nil, err + } + cryptEncryptedFieldsMap[k] = encryptedFields + } + + kmsProviders, err := marshal(opts.KmsProviders, c.bsonOpts, c.registry) + if err != nil { + return nil, fmt.Errorf("error creating KMS providers document: %w", err) + } + + // Set the crypt_shared library override path from the "cryptSharedLibPath" extra option if one + // was set. + cryptSharedLibPath := "" + if val, ok := opts.ExtraOptions["cryptSharedLibPath"]; ok { + str, ok := val.(string) + if !ok { + return nil, fmt.Errorf( + `expected AutoEncryption extra option "cryptSharedLibPath" to be a string, but is a %T`, val) + } + cryptSharedLibPath = str + } + + // Explicitly disable loading the crypt_shared library if requested. Note that this is ONLY + // intended for use from tests; there is no supported public API for explicitly disabling + // loading the crypt_shared library. + cryptSharedLibDisabled := false + if v, ok := opts.ExtraOptions["__cryptSharedLibDisabledForTestOnly"]; ok { + cryptSharedLibDisabled = v.(bool) + } + + bypassAutoEncryption := opts.BypassAutoEncryption != nil && *opts.BypassAutoEncryption + bypassQueryAnalysis := opts.BypassQueryAnalysis != nil && *opts.BypassQueryAnalysis + + mc, err := mongocrypt.NewMongoCrypt(mcopts.MongoCrypt(). + SetKmsProviders(kmsProviders). + SetLocalSchemaMap(cryptSchemaMap). + SetBypassQueryAnalysis(bypassQueryAnalysis). + SetEncryptedFieldsMap(cryptEncryptedFieldsMap). + SetCryptSharedLibDisabled(cryptSharedLibDisabled || bypassAutoEncryption). + SetCryptSharedLibOverridePath(cryptSharedLibPath). + SetHTTPClient(opts.HTTPClient)) if err != nil { - return fmt.Errorf("error creating KMS providers document: %v", err) + return nil, err + } + + var cryptSharedLibRequired bool + if val, ok := opts.ExtraOptions["cryptSharedLibRequired"]; ok { + b, ok := val.(bool) + if !ok { + return nil, fmt.Errorf( + `expected AutoEncryption extra option "cryptSharedLibRequired" to be a bool, but is a %T`, val) + } + cryptSharedLibRequired = b } - // configure options - var bypass bool - if opts.BypassAutoEncryption != nil { - bypass = *opts.BypassAutoEncryption + // If the "cryptSharedLibRequired" extra option is set to true, check the MongoCrypt version + // string to confirm that the library was successfully loaded. If the version string is empty, + // return an error indicating that we couldn't load the crypt_shared library. + if cryptSharedLibRequired && mc.CryptSharedLibVersionString() == "" { + return nil, errors.New( + `AutoEncryption extra option "cryptSharedLibRequired" is true, but we failed to load the crypt_shared library`) } + + return mc, nil +} + +//nolint:unused // the unused linter thinks that this function is unreachable because "c.newMongoCrypt" always panics without the "cse" build tag set. +func (c *Client) configureCryptFLE(mc *mongocrypt.MongoCrypt, opts *options.AutoEncryptionOptions) { + bypass := opts.BypassAutoEncryption != nil && *opts.BypassAutoEncryption kr := keyRetriever{coll: c.keyVaultCollFLE} var cir collInfoRetriever // If bypass is true, c.metadataClientFLE is nil and the collInfoRetriever @@ -798,40 +636,24 @@ func (c *Client) configureCryptFLE(opts *options.AutoEncryptionOptions) error { cir = collInfoRetriever{client: c.metadataClientFLE} } - cryptOpts := &driver.CryptOptions{ + c.cryptFLE = driver.NewCrypt(&driver.CryptOptions{ + MongoCrypt: mc, CollInfoFn: cir.cryptCollInfo, KeyFn: kr.cryptKeys, MarkFn: c.mongocryptdFLE.markCommand, - KmsProviders: kmsProviders, TLSConfig: opts.TLSConfig, BypassAutoEncryption: bypass, - SchemaMap: cryptSchemaMap, - } - - c.cryptFLE, err = driver.NewCrypt(cryptOpts) - return err + }) } // validSession returns an error if the session doesn't belong to the client func (c *Client) validSession(sess *session.Client) error { - if sess != nil && !uuid.Equal(sess.ClientID, c.id) { + if sess != nil && sess.ClientID != c.id { return ErrWrongClient } return nil } -// convertToDriverAPIOptions converts a options.ServerAPIOptions instance to a driver.ServerAPIOptions. -func convertToDriverAPIOptions(s *options.ServerAPIOptions) *driver.ServerAPIOptions { - driverOpts := driver.NewServerAPIOptions(string(s.ServerAPIVersion)) - if s.Strict != nil { - driverOpts.SetStrict(*s.Strict) - } - if s.DeprecationErrors != nil { - driverOpts.SetDeprecationErrors(*s.DeprecationErrors) - } - return driverOpts -} - // Database returns a handle for a database with the given name configured with the given DatabaseOptions. func (c *Client) Database(name string, opts ...*options.DatabaseOptions) *Database { return newDatabase(c, name, opts...) @@ -845,7 +667,7 @@ func (c *Client) Database(name string, opts ...*options.DatabaseOptions) *Databa // // The opts parameter can be used to specify options for this operation (see the options.ListDatabasesOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/listDatabases/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/listDatabases/. func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ...*options.ListDatabasesOptions) (ListDatabasesResult, error) { if ctx == nil { ctx = context.Background() @@ -858,10 +680,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... return ListDatabasesResult{}, err } if sess == nil && c.sessionPool != nil { - sess, err = session.NewClientSession(c.sessionPool, c.id, session.Implicit) - if err != nil { - return ListDatabasesResult{}, err - } + sess = session.NewImplicitClientSession(c.sessionPool, c.id) defer sess.EndSession() } @@ -870,7 +689,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... return ListDatabasesResult{}, err } - filterDoc, err := transformBsoncoreDocument(c.registry, filter, true, "filter") + filterDoc, err := marshal(filter, c.bsonOpts, c.registry) if err != nil { return ListDatabasesResult{}, err } @@ -885,7 +704,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.cryptFLE). - ServerAPI(c.serverAPI) + ServerAPI(c.serverAPI).Timeout(c.timeout).Authenticator(c.authenticator) if ldo.NameOnly != nil { op = op.NameOnly(*ldo.NameOnly) @@ -918,7 +737,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... // The opts parameter can be used to specify options for this operation (see the options.ListDatabasesOptions // documentation.) // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/listDatabases/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/listDatabases/. func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts ...*options.ListDatabasesOptions) ([]string, error) { opts = append(opts, options.ListDatabases().SetNameOnly(true)) @@ -939,6 +758,9 @@ func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts // SessionContext must be used as the Context parameter for any operations in the fn callback that should be executed // under the session. // +// WithSession is safe to call from multiple goroutines concurrently. However, the SessionContext passed to the +// WithSession callback function is not safe for concurrent use by multiple goroutines. +// // If the ctx parameter already contains a Session, that Session will be replaced with the one provided. // // Any error returned by the fn callback will be returned without any modifications. @@ -951,6 +773,9 @@ func WithSession(ctx context.Context, sess Session, fn func(SessionContext) erro // be executed under a session. After the callback returns, the created Session is ended, meaning that any in-progress // transactions started by fn will be aborted even if fn returns an error. // +// UseSession is safe to call from multiple goroutines concurrently. However, the SessionContext passed to the +// UseSession callback function is not safe for concurrent use by multiple goroutines. +// // If the ctx parameter already contains a Session, that Session will be replaced with the newly created one. // // Any error returned by the fn callback will be returned without any modifications. @@ -959,6 +784,9 @@ func (c *Client) UseSession(ctx context.Context, fn func(SessionContext) error) } // UseSessionWithOptions operates like UseSession but uses the given SessionOptions to create the Session. +// +// UseSessionWithOptions is safe to call from multiple goroutines concurrently. However, the SessionContext passed to +// the UseSessionWithOptions callback function is not safe for concurrent use by multiple goroutines. func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.SessionOptions, fn func(SessionContext) error) error { defaultSess, err := c.StartSession(opts) if err != nil { @@ -970,13 +798,13 @@ func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.Sessio } // Watch returns a change stream for all changes on the deployment. See -// https://docs.mongodb.com/manual/changeStreams/ for more information about change streams. +// https://www.mongodb.com/docs/manual/changeStreams/ for more information about change streams. // // The client must be configured with read concern majority or no read concern for a change stream to be created // successfully. // // The pipeline parameter must be an array of documents, each representing a pipeline stage. The pipeline cannot be -// nil or empty. The stage documents must all be non-nil. See https://docs.mongodb.com/manual/changeStreams/ for a list +// nil or empty. The stage documents must all be non-nil. See https://www.mongodb.com/docs/manual/changeStreams/ for a list // of pipeline stages that can be used with change streams. For a pipeline of bson.D documents, the mongo.Pipeline{} // type can be used. // @@ -992,6 +820,7 @@ func (c *Client) Watch(ctx context.Context, pipeline interface{}, readConcern: c.readConcern, readPreference: c.readPreference, client: c, + bsonOpts: c.bsonOpts, registry: c.registry, streamType: ClientStream, crypt: c.cryptFLE, @@ -1003,7 +832,15 @@ func (c *Client) Watch(ctx context.Context, pipeline interface{}, // NumberSessionsInProgress returns the number of sessions that have been started for this client but have not been // closed (i.e. EndSession has not been called). func (c *Client) NumberSessionsInProgress() int { - return c.sessionPool.CheckedOut() + // The underlying session pool uses an int64 for checkedOut to allow atomic + // access. We convert to an int here to maintain backward compatibility with + // older versions of the driver that did not atomically access checkedOut. + return int(c.sessionPool.CheckedOut()) +} + +// Timeout returns the timeout set for this client. +func (c *Client) Timeout() *time.Duration { + return c.timeout } func (c *Client) createBaseCursorOptions() driver.CursorOptions { @@ -1013,3 +850,28 @@ func (c *Client) createBaseCursorOptions() driver.CursorOptions { ServerAPI: c.serverAPI, } } + +// newLogger will use the LoggerOptions to create an internal logger and publish +// messages using a LogSink. +func newLogger(opts *options.LoggerOptions) (*logger.Logger, error) { + // If there are no logger options, then create a default logger. + if opts == nil { + opts = options.Logger() + } + + // If there are no component-level options and the environment does not + // contain component variables, then do nothing. + if (len(opts.ComponentLevels) == 0) && + !logger.EnvHasComponentVariables() { + + return nil, nil + } + + // Otherwise, collect the component-level options and create a logger. + componentLevels := make(map[logger.Component]logger.Level) + for component, level := range opts.ComponentLevels { + componentLevels[logger.Component(component)] = logger.Level(level) + } + + return logger.New(opts.Sink, opts.MaxDocumentLength, componentLevels) +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go b/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go index fe4646b..352dac1 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go @@ -8,16 +8,18 @@ package mongo import ( "context" + "errors" "fmt" "strings" - "github.com/pkg/errors" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - cryptOpts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" + "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" + mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" ) // ClientEncryption is used to create data keys and explicitly encrypt and decrypt BSON values. @@ -42,41 +44,121 @@ func NewClientEncryption(keyVaultClient *Client, opts ...*options.ClientEncrypti db, coll := splitNamespace(ceo.KeyVaultNamespace) ce.keyVaultColl = ce.keyVaultClient.Database(db).Collection(coll, keyVaultCollOpts) - kmsProviders, err := transformBsoncoreDocument(bson.DefaultRegistry, ceo.KmsProviders, true, "kmsProviders") + kmsProviders, err := marshal(ceo.KmsProviders, nil, nil) if err != nil { - return nil, fmt.Errorf("error creating KMS providers map: %v", err) + return nil, fmt.Errorf("error creating KMS providers map: %w", err) + } + + mc, err := mongocrypt.NewMongoCrypt(mcopts.MongoCrypt(). + SetKmsProviders(kmsProviders). + // Explicitly disable loading the crypt_shared library for the Crypt used for + // ClientEncryption because it's only needed for AutoEncryption and we don't expect users to + // have the crypt_shared library installed if they're using ClientEncryption. + SetCryptSharedLibDisabled(true). + SetHTTPClient(ceo.HTTPClient)) + if err != nil { + return nil, err } // create Crypt kr := keyRetriever{coll: ce.keyVaultColl} cir := collInfoRetriever{client: ce.keyVaultClient} - ce.crypt, err = driver.NewCrypt(&driver.CryptOptions{ - KeyFn: kr.cryptKeys, - CollInfoFn: cir.cryptCollInfo, - KmsProviders: kmsProviders, - TLSConfig: ceo.TLSConfig, + ce.crypt = driver.NewCrypt(&driver.CryptOptions{ + MongoCrypt: mc, + KeyFn: kr.cryptKeys, + CollInfoFn: cir.cryptCollInfo, + TLSConfig: ceo.TLSConfig, }) + + return ce, nil +} + +// CreateEncryptedCollection creates a new collection for Queryable Encryption with the help of automatic generation of new encryption data keys for null keyIds. +// It returns the created collection and the encrypted fields document used to create it. +func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context, + db *Database, coll string, createOpts *options.CreateCollectionOptions, + kmsProvider string, masterKey interface{}) (*Collection, bson.M, error) { + if createOpts == nil { + return nil, nil, errors.New("nil CreateCollectionOptions") + } + ef := createOpts.EncryptedFields + if ef == nil { + return nil, nil, errors.New("no EncryptedFields defined for the collection") + } + + efBSON, err := marshal(ef, db.bsonOpts, db.registry) if err != nil { - return nil, err + return nil, nil, err + } + r := bsonrw.NewBSONDocumentReader(efBSON) + dec, err := bson.NewDecoder(r) + if err != nil { + return nil, nil, err + } + var m bson.M + err = dec.Decode(&m) + if err != nil { + return nil, nil, err } - return ce, nil + if v, ok := m["fields"]; ok { + if fields, ok := v.(bson.A); ok { + for _, field := range fields { + if f, ok := field.(bson.M); !ok { + continue + } else if v, ok := f["keyId"]; ok && v == nil { + dkOpts := options.DataKey() + if masterKey != nil { + dkOpts.SetMasterKey(masterKey) + } + keyid, err := ce.CreateDataKey(ctx, kmsProvider, dkOpts) + if err != nil { + createOpts.EncryptedFields = m + return nil, m, err + } + f["keyId"] = keyid + } + } + createOpts.EncryptedFields = m + } + } + err = db.CreateCollection(ctx, coll, createOpts) + if err != nil { + return nil, m, err + } + return db.Collection(coll), m, nil +} + +// AddKeyAltName adds a keyAltName to the keyAltNames array of the key document in the key vault collection with the +// given UUID (BSON binary subtype 0x04). Returns the previous version of the key document. +func (ce *ClientEncryption) AddKeyAltName(ctx context.Context, id primitive.Binary, keyAltName string) *SingleResult { + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() + keyAltNameDoc := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build() + update := bsoncore.NewDocumentBuilder().AppendDocument("$addToSet", keyAltNameDoc).Build() + return ce.keyVaultColl.FindOneAndUpdate(ctx, filter, update) } -// CreateDataKey creates a new key document and inserts it into the key vault collection. Returns the _id of the -// created document. -func (ce *ClientEncryption) CreateDataKey(ctx context.Context, kmsProvider string, opts ...*options.DataKeyOptions) (primitive.Binary, error) { - // translate opts to cryptOpts.DataKeyOptions +// CreateDataKey creates a new key document and inserts into the key vault collection. Returns the _id of the created +// document as a UUID (BSON binary subtype 0x04). +func (ce *ClientEncryption) CreateDataKey(ctx context.Context, kmsProvider string, + opts ...*options.DataKeyOptions) (primitive.Binary, error) { + + // translate opts to mcopts.DataKeyOptions dko := options.MergeDataKeyOptions(opts...) - co := cryptOpts.DataKey().SetKeyAltNames(dko.KeyAltNames) + co := mcopts.DataKey().SetKeyAltNames(dko.KeyAltNames) if dko.MasterKey != nil { - keyDoc, err := transformBsoncoreDocument(ce.keyVaultClient.registry, dko.MasterKey, true, "masterKey") + keyDoc, err := marshal( + dko.MasterKey, + ce.keyVaultClient.bsonOpts, + ce.keyVaultClient.registry) if err != nil { return primitive.Binary{}, err } - co.SetMasterKey(keyDoc) } + if dko.KeyMaterial != nil { + co.SetKeyMaterial(dko.KeyMaterial) + } // create data key document dataKeyDoc, err := ce.crypt.CreateDataKey(ctx, kmsProvider, co) @@ -94,10 +176,10 @@ func (ce *ClientEncryption) CreateDataKey(ctx context.Context, kmsProvider strin return primitive.Binary{Subtype: subtype, Data: data}, nil } -// Encrypt encrypts a BSON value with the given key and algorithm. Returns an encrypted value (BSON binary of subtype 6). -func (ce *ClientEncryption) Encrypt(ctx context.Context, val bson.RawValue, opts ...*options.EncryptOptions) (primitive.Binary, error) { +// transformExplicitEncryptionOptions creates explicit encryption options to be passed to libmongocrypt. +func transformExplicitEncryptionOptions(opts ...*options.EncryptOptions) *mcopts.ExplicitEncryptionOptions { eo := options.MergeEncryptOptions(opts...) - transformed := cryptOpts.ExplicitEncryption() + transformed := mcopts.ExplicitEncryption() if eo.KeyID != nil { transformed.SetKeyID(*eo.KeyID) } @@ -105,7 +187,39 @@ func (ce *ClientEncryption) Encrypt(ctx context.Context, val bson.RawValue, opts transformed.SetKeyAltName(*eo.KeyAltName) } transformed.SetAlgorithm(eo.Algorithm) + transformed.SetQueryType(eo.QueryType) + + if eo.ContentionFactor != nil { + transformed.SetContentionFactor(*eo.ContentionFactor) + } + + if eo.RangeOptions != nil { + var transformedRange mcopts.ExplicitRangeOptions + if eo.RangeOptions.Min != nil { + transformedRange.Min = &bsoncore.Value{Type: eo.RangeOptions.Min.Type, Data: eo.RangeOptions.Min.Value} + } + if eo.RangeOptions.Max != nil { + transformedRange.Max = &bsoncore.Value{Type: eo.RangeOptions.Max.Type, Data: eo.RangeOptions.Max.Value} + } + if eo.RangeOptions.Precision != nil { + transformedRange.Precision = eo.RangeOptions.Precision + } + if eo.RangeOptions.Sparsity != nil { + transformedRange.Sparsity = eo.RangeOptions.Sparsity + } + if eo.RangeOptions.TrimFactor != nil { + transformedRange.TrimFactor = eo.RangeOptions.TrimFactor + } + transformed.SetRangeOptions(transformedRange) + } + return transformed +} + +// Encrypt encrypts a BSON value with the given key and algorithm. Returns an encrypted value (BSON binary of subtype 6). +func (ce *ClientEncryption) Encrypt(ctx context.Context, val bson.RawValue, + opts ...*options.EncryptOptions) (primitive.Binary, error) { + transformed := transformExplicitEncryptionOptions(opts...) subtype, data, err := ce.crypt.EncryptExplicit(ctx, bsoncore.Value{Type: val.Type, Data: val.Value}, transformed) if err != nil { return primitive.Binary{}, err @@ -113,6 +227,39 @@ func (ce *ClientEncryption) Encrypt(ctx context.Context, val bson.RawValue, opts return primitive.Binary{Subtype: subtype, Data: data}, nil } +// EncryptExpression encrypts an expression to query a range index. +// On success, `result` is populated with the resulting BSON document. +// `expr` is expected to be a BSON document of one of the following forms: +// 1. A Match Expression of this form: +// {$and: [{: {$gt: }}, {: {$lt: }}]} +// 2. An Aggregate Expression of this form: +// {$and: [{$gt: [, ]}, {$lt: [, ]}] +// $gt may also be $gte. $lt may also be $lte. +// Only supported for queryType "range" +func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interface{}, result interface{}, opts ...*options.EncryptOptions) error { + transformed := transformExplicitEncryptionOptions(opts...) + + exprDoc, err := marshal(expr, nil, nil) + if err != nil { + return err + } + + encryptedExprDoc, err := ce.crypt.EncryptExplicitExpression(ctx, exprDoc, transformed) + if err != nil { + return err + } + if raw, ok := result.(*bson.Raw); ok { + // Avoid the cost of Unmarshal. + *raw = bson.Raw(encryptedExprDoc) + return nil + } + err = bson.Unmarshal([]byte(encryptedExprDoc), result) + if err != nil { + return err + } + return nil +} + // Decrypt decrypts an encrypted value (BSON binary of subtype 6) and returns the original BSON value. func (ce *ClientEncryption) Decrypt(ctx context.Context, val primitive.Binary) (bson.RawValue, error) { decrypted, err := ce.crypt.DecryptExplicit(ctx, val.Subtype, val.Data) @@ -130,6 +277,154 @@ func (ce *ClientEncryption) Close(ctx context.Context) error { return ce.keyVaultClient.Disconnect(ctx) } +// DeleteKey removes the key document with the given UUID (BSON binary subtype 0x04) from the key vault collection. +// Returns the result of the internal deleteOne() operation on the key vault collection. +func (ce *ClientEncryption) DeleteKey(ctx context.Context, id primitive.Binary) (*DeleteResult, error) { + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() + return ce.keyVaultColl.DeleteOne(ctx, filter) +} + +// GetKeyByAltName returns a key document in the key vault collection with the given keyAltName. +func (ce *ClientEncryption) GetKeyByAltName(ctx context.Context, keyAltName string) *SingleResult { + filter := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build() + return ce.keyVaultColl.FindOne(ctx, filter) +} + +// GetKey finds a single key document with the given UUID (BSON binary subtype 0x04). Returns the result of the +// internal find() operation on the key vault collection. +func (ce *ClientEncryption) GetKey(ctx context.Context, id primitive.Binary) *SingleResult { + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() + return ce.keyVaultColl.FindOne(ctx, filter) +} + +// GetKeys finds all documents in the key vault collection. Returns the result of the internal find() operation on the +// key vault collection. +func (ce *ClientEncryption) GetKeys(ctx context.Context) (*Cursor, error) { + return ce.keyVaultColl.Find(ctx, bson.D{}) +} + +// RemoveKeyAltName removes a keyAltName from the keyAltNames array of the key document in the key vault collection with +// the given UUID (BSON binary subtype 0x04). Returns the previous version of the key document. +func (ce *ClientEncryption) RemoveKeyAltName(ctx context.Context, id primitive.Binary, keyAltName string) *SingleResult { + filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build() + update := bson.A{bson.D{{"$set", bson.D{{"keyAltNames", bson.D{{"$cond", bson.A{bson.D{{"$eq", + bson.A{"$keyAltNames", bson.A{keyAltName}}}}, "$$REMOVE", bson.D{{"$filter", + bson.D{{"input", "$keyAltNames"}, {"cond", bson.D{{"$ne", bson.A{"$$this", keyAltName}}}}}}}}}}}}}}} + return ce.keyVaultColl.FindOneAndUpdate(ctx, filter, update) +} + +// setRewrapManyDataKeyWriteModels will prepare the WriteModel slice for a bulk updating rewrapped documents. +func setRewrapManyDataKeyWriteModels(rewrappedDocuments []bsoncore.Document, writeModels *[]WriteModel) error { + const idKey = "_id" + const keyMaterial = "keyMaterial" + const masterKey = "masterKey" + + if writeModels == nil { + return fmt.Errorf("writeModels pointer not set for location referenced") + } + + // Append a slice of WriteModel with the update document per each rewrappedDoc _id filter. + for _, rewrappedDocument := range rewrappedDocuments { + // Prepare the new master key for update. + masterKeyValue, err := rewrappedDocument.LookupErr(masterKey) + if err != nil { + return err + } + masterKeyDoc := masterKeyValue.Document() + + // Prepare the new material key for update. + keyMaterialValue, err := rewrappedDocument.LookupErr(keyMaterial) + if err != nil { + return err + } + keyMaterialSubtype, keyMaterialData := keyMaterialValue.Binary() + keyMaterialBinary := primitive.Binary{Subtype: keyMaterialSubtype, Data: keyMaterialData} + + // Prepare the _id filter for documents to update. + id, err := rewrappedDocument.LookupErr(idKey) + if err != nil { + return err + } + + idSubtype, idData, ok := id.BinaryOK() + if !ok { + return fmt.Errorf("expected to assert %q as binary, got type %T", idKey, id) + } + binaryID := primitive.Binary{Subtype: idSubtype, Data: idData} + + // Append the mutable document to the slice for bulk update. + *writeModels = append(*writeModels, NewUpdateOneModel(). + SetFilter(bson.D{{idKey, binaryID}}). + SetUpdate( + bson.D{ + {"$set", bson.D{{keyMaterial, keyMaterialBinary}, {masterKey, masterKeyDoc}}}, + {"$currentDate", bson.D{{"updateDate", true}}}, + }, + )) + } + return nil +} + +// RewrapManyDataKey decrypts and encrypts all matching data keys with a possibly new masterKey value. For all +// matching documents, this method will overwrite the "masterKey", "updateDate", and "keyMaterial". On error, some +// matching data keys may have been rewrapped. +// libmongocrypt 1.5.2 is required. An error is returned if the detected version of libmongocrypt is less than 1.5.2. +func (ce *ClientEncryption) RewrapManyDataKey(ctx context.Context, filter interface{}, + opts ...*options.RewrapManyDataKeyOptions) (*RewrapManyDataKeyResult, error) { + + // libmongocrypt versions 1.5.0 and 1.5.1 have a severe bug in RewrapManyDataKey. + // Check if the version string starts with 1.5.0 or 1.5.1. This accounts for pre-release versions, like 1.5.0-rc0. + libmongocryptVersion := mongocrypt.Version() + if strings.HasPrefix(libmongocryptVersion, "1.5.0") || strings.HasPrefix(libmongocryptVersion, "1.5.1") { + return nil, fmt.Errorf("RewrapManyDataKey requires libmongocrypt 1.5.2 or newer. Detected version: %v", libmongocryptVersion) + } + + rmdko := options.MergeRewrapManyDataKeyOptions(opts...) + if ctx == nil { + ctx = context.Background() + } + + // Transfer rmdko options to /x/ package options to publish the mongocrypt feed. + co := mcopts.RewrapManyDataKey() + if rmdko.MasterKey != nil { + keyDoc, err := marshal( + rmdko.MasterKey, + ce.keyVaultClient.bsonOpts, + ce.keyVaultClient.registry) + if err != nil { + return nil, err + } + co.SetMasterKey(keyDoc) + } + if rmdko.Provider != nil { + co.SetProvider(*rmdko.Provider) + } + + // Prepare the filters and rewrap the data key using mongocrypt. + filterdoc, err := marshal(filter, ce.keyVaultClient.bsonOpts, ce.keyVaultClient.registry) + if err != nil { + return nil, err + } + + rewrappedDocuments, err := ce.crypt.RewrapDataKey(ctx, filterdoc, co) + if err != nil { + return nil, err + } + if len(rewrappedDocuments) == 0 { + // If there are no documents to rewrap, then do nothing. + return new(RewrapManyDataKeyResult), nil + } + + // Prepare the WriteModel slice for bulk updating the rewrapped data keys. + models := []WriteModel{} + if err := setRewrapManyDataKeyWriteModels(rewrappedDocuments, &models); err != nil { + return nil, err + } + + bulkWriteResults, err := ce.keyVaultColl.BulkWrite(ctx, models) + return &RewrapManyDataKeyResult{BulkWriteResult: bulkWriteResults}, err +} + // splitNamespace takes a namespace in the form "database.collection" and returns (database name, collection name) func splitNamespace(ns string) (string, string) { firstDot := strings.Index(ns, ".") diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/collection.go b/vendor/go.mongodb.org/mongo-driver/mongo/collection.go index 37c6676..95889c8 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/collection.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/collection.go @@ -10,12 +10,15 @@ import ( "context" "errors" "fmt" + "reflect" "strings" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal/csfle" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -37,6 +40,7 @@ type Collection struct { readPreference *readpref.ReadPref readSelector description.ServerSelector writeSelector description.ServerSelector + bsonOpts *options.BSONOptions registry *bsoncodec.Registry } @@ -45,6 +49,7 @@ type aggregateParams struct { ctx context.Context pipeline interface{} client *Client + bsonOpts *options.BSONOptions registry *bsoncodec.Registry readConcern *readconcern.ReadConcern writeConcern *writeconcern.WriteConcern @@ -58,7 +63,7 @@ type aggregateParams struct { } func closeImplicitSession(sess *session.Client) { - if sess != nil && sess.SessionType == session.Implicit { + if sess != nil && sess.IsImplicit { sess.EndSession() } } @@ -81,6 +86,11 @@ func newCollection(db *Database, name string, opts ...*options.CollectionOptions rp = collOpt.ReadPreference } + bsonOpts := db.bsonOpts + if collOpt.BSONOptions != nil { + bsonOpts = collOpt.BSONOptions + } + reg := db.registry if collOpt.Registry != nil { reg = collOpt.Registry @@ -105,6 +115,7 @@ func newCollection(db *Database, name string, opts ...*options.CollectionOptions writeConcern: wc, readSelector: readSelector, writeSelector: writeSelector, + bsonOpts: bsonOpts, registry: reg, } @@ -166,7 +177,7 @@ func (coll *Collection) Database() *Database { return coll.db } -// BulkWrite performs a bulk write operation (https://docs.mongodb.com/manual/core/bulk-write-operations/). +// BulkWrite performs a bulk write operation (https://www.mongodb.com/docs/manual/core/bulk-write-operations/). // // The models parameter must be a slice of operations to be executed in this bulk write. It cannot be nil or empty. // All of the models must be non-nil. See the mongo.WriteModel documentation for a list of valid model types and @@ -186,11 +197,7 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, sess := sessionFromContext(ctx) if sess == nil && coll.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -218,6 +225,7 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, bwo := options.MergeBulkWriteOptions(opts...) op := bulkWrite{ + comment: bwo.Comment, ordered: bwo.Ordered, bypassDocumentValidation: bwo.BypassDocumentValidation, models: models, @@ -226,6 +234,7 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, selector: selector, writeConcern: wc, let: bwo.Let, + bypassEmptyTsReplacement: bwo.BypassEmptyTsReplacement, } err = op.execute(ctx) @@ -244,20 +253,22 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, docs := make([]bsoncore.Document, len(documents)) for i, doc := range documents { - var err error - docs[i], result[i], err = transformAndEnsureID(coll.registry, doc) + bsoncoreDoc, err := marshal(doc, coll.bsonOpts, coll.registry) if err != nil { return nil, err } + bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NilObjectID, coll.bsonOpts, coll.registry) + if err != nil { + return nil, err + } + + docs[i] = bsoncoreDoc + result[i] = id } sess := sessionFromContext(ctx) if sess == nil && coll.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -281,14 +292,25 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) imo := options.MergeInsertManyOptions(opts...) if imo.BypassDocumentValidation != nil && *imo.BypassDocumentValidation { op = op.BypassDocumentValidation(*imo.BypassDocumentValidation) } + if imo.Comment != nil { + comment, err := marshalValue(imo.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return nil, err + } + op = op.Comment(comment) + } if imo.Ordered != nil { op = op.Ordered(*imo.Ordered) } + if imo.BypassEmptyTsReplacement != nil { + op = op.BypassEmptyTsReplacement(*imo.BypassEmptyTsReplacement) + } retry := driver.RetryNone if coll.client.retryWrites { retry = driver.RetryOncePerCommand @@ -296,8 +318,8 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, op = op.Retry(retry) err = op.Execute(ctx) - wce, ok := err.(driver.WriteCommandError) - if !ok { + var wce driver.WriteCommandError + if !errors.As(err, &wce) { return result, err } @@ -324,7 +346,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, // // The opts parameter can be used to specify options for the operation (see the options.InsertOneOptions documentation.) // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/insert/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/insert/. func (coll *Collection) InsertOne(ctx context.Context, document interface{}, opts ...*options.InsertOneOptions) (*InsertOneResult, error) { @@ -334,6 +356,12 @@ func (coll *Collection) InsertOne(ctx context.Context, document interface{}, if ioOpts.BypassDocumentValidation != nil && *ioOpts.BypassDocumentValidation { imOpts.SetBypassDocumentValidation(*ioOpts.BypassDocumentValidation) } + if ioOpts.Comment != nil { + imOpts.SetComment(ioOpts.Comment) + } + if ioOpts.BypassEmptyTsReplacement != nil { + imOpts.BypassEmptyTsReplacement = ioOpts.BypassEmptyTsReplacement + } res, err := coll.insert(ctx, []interface{}{document}, imOpts) rr, err := processWriteError(err) @@ -353,7 +381,7 @@ func (coll *Collection) InsertOne(ctx context.Context, document interface{}, // // The opts parameter can be used to specify options for the operation (see the options.InsertManyOptions documentation.) // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/insert/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/insert/. func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, opts ...*options.InsertManyOptions) (*InsertManyResult, error) { @@ -368,8 +396,8 @@ func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, } imResult := &InsertManyResult{InsertedIDs: result} - writeException, ok := err.(WriteException) - if !ok { + var writeException WriteException + if !errors.As(err, &writeException) { return imResult, err } @@ -396,17 +424,14 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } sess := sessionFromContext(ctx) if sess == nil && coll.client.sessionPool != nil { - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -437,7 +462,10 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn doc = bsoncore.AppendDocumentElement(doc, "collation", do.Collation.ToDocument()) } if do.Hint != nil { - hint, err := transformValue(coll.registry, do.Hint, false, "hint") + if isUnorderedMap(do.Hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hint, err := marshalValue(do.Hint, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -451,12 +479,20 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) + if do.Comment != nil { + comment, err := marshalValue(do.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return nil, err + } + op = op.Comment(comment) + } if do.Hint != nil { op = op.Hint(true) } if do.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, do.Let, true, "let") + let, err := marshal(do.Let, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -473,7 +509,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn if rr&expectedRr == 0 { return nil, err } - return &DeleteResult{DeletedCount: int64(op.Result().N)}, err + return &DeleteResult{DeletedCount: op.Result().N}, err } // DeleteOne executes a delete command to delete at most one document from the collection. @@ -485,7 +521,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn // // The opts parameter can be used to specify options for the operation (see the options.DeleteOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/delete/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/delete/. func (coll *Collection) DeleteOne(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*DeleteResult, error) { @@ -501,7 +537,7 @@ func (coll *Collection) DeleteOne(ctx context.Context, filter interface{}, // // The opts parameter can be used to specify options for the operation (see the options.DeleteOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/delete/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/delete/. func (coll *Collection) DeleteMany(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*DeleteResult, error) { @@ -519,19 +555,24 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc // collation, arrayFilters, upsert, and hint are included on the individual update documents rather than as part of the // command - updateDoc, err := createUpdateDoc(filter, update, uo.Hint, uo.ArrayFilters, uo.Collation, uo.Upsert, multi, - checkDollarKey, coll.registry) + updateDoc, err := createUpdateDoc( + filter, + update, + uo.Hint, + uo.ArrayFilters, + uo.Collation, + uo.Upsert, + multi, + checkDollarKey, + coll.bsonOpts, + coll.registry) if err != nil { return nil, err } sess := sessionFromContext(ctx) if sess == nil && coll.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -555,9 +596,10 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Hint(uo.Hint != nil). - ArrayFilters(uo.ArrayFilters != nil).Ordered(true).ServerAPI(coll.client.serverAPI) + ArrayFilters(uo.ArrayFilters != nil).Ordered(true).ServerAPI(coll.client.serverAPI). + Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) if uo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, uo.Let, true, "let") + let, err := marshal(uo.Let, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -567,6 +609,16 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc if uo.BypassDocumentValidation != nil && *uo.BypassDocumentValidation { op = op.BypassDocumentValidation(*uo.BypassDocumentValidation) } + if uo.Comment != nil { + comment, err := marshalValue(uo.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return nil, err + } + op = op.Comment(comment) + } + if uo.BypassEmptyTsReplacement != nil { + op.BypassEmptyTsReplacement(*uo.BypassEmptyTsReplacement) + } retry := driver.RetryNone // retryable writes are only enabled updateOne/replaceOne operations if !multi && coll.client.retryWrites { @@ -582,8 +634,8 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc opRes := op.Result() res := &UpdateResult{ - MatchedCount: int64(opRes.N), - ModifiedCount: int64(opRes.NModified), + MatchedCount: opRes.N, + ModifiedCount: opRes.NModified, UpsertedCount: int64(len(opRes.Upserted)), } if len(opRes.Upserted) > 0 { @@ -601,12 +653,12 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc // the operation will succeed and an UpdateResult with a MatchedCount of 0 will be returned. // // The update parameter must be a document containing update operators -// (https://docs.mongodb.com/manual/reference/operator/update/) and can be used to specify the modifications to be +// (https://www.mongodb.com/docs/manual/reference/operator/update/) and can be used to specify the modifications to be // made to the selected document. It cannot be nil or empty. // // The opts parameter can be used to specify options for the operation (see the options.UpdateOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/update/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateByID(ctx context.Context, id interface{}, update interface{}, opts ...*options.UpdateOptions) (*UpdateResult, error) { if id == nil { @@ -623,12 +675,12 @@ func (coll *Collection) UpdateByID(ctx context.Context, id interface{}, update i // matched set and MatchedCount will equal 1. // // The update parameter must be a document containing update operators -// (https://docs.mongodb.com/manual/reference/operator/update/) and can be used to specify the modifications to be +// (https://www.mongodb.com/docs/manual/reference/operator/update/) and can be used to specify the modifications to be // made to the selected document. It cannot be nil or empty. // // The opts parameter can be used to specify options for the operation (see the options.UpdateOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/update/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*UpdateResult, error) { @@ -636,7 +688,7 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -651,12 +703,12 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat // with a MatchedCount of 0 will be returned. // // The update parameter must be a document containing update operators -// (https://docs.mongodb.com/manual/reference/operator/update/) and can be used to specify the modifications to be made +// (https://www.mongodb.com/docs/manual/reference/operator/update/) and can be used to specify the modifications to be made // to the selected documents. It cannot be nil or empty. // // The opts parameter can be used to specify options for the operation (see the options.UpdateOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/update/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*UpdateResult, error) { @@ -664,7 +716,7 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -680,11 +732,11 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda // selected from the matched set and MatchedCount will equal 1. // // The replacement parameter must be a document that will be used to replace the selected document. It cannot be nil -// and cannot contain any update operators (https://docs.mongodb.com/manual/reference/operator/update/). +// and cannot contain any update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). // // The opts parameter can be used to specify options for the operation (see the options.ReplaceOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/update/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/update/. func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, replacement interface{}, opts ...*options.ReplaceOptions) (*UpdateResult, error) { @@ -692,12 +744,12 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } - r, err := transformBsoncoreDocument(coll.registry, replacement, true, "replacement") + r, err := marshal(replacement, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -717,6 +769,8 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, uOpts.Upsert = opt.Upsert uOpts.Hint = opt.Hint uOpts.Let = opt.Let + uOpts.Comment = opt.Comment + uOpts.BypassEmptyTsReplacement = opt.BypassEmptyTsReplacement updateOptions = append(updateOptions, uOpts) } @@ -728,12 +782,12 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, // The pipeline parameter must be an array of documents, each representing an aggregation stage. The pipeline cannot // be nil but can be empty. The stage documents must all be non-nil. For a pipeline of bson.D documents, the // mongo.Pipeline type can be used. See -// https://docs.mongodb.com/manual/reference/operator/aggregation-pipeline/#db-collection-aggregate-stages for a list of +// https://www.mongodb.com/docs/manual/reference/operator/aggregation-pipeline/#db-collection-aggregate-stages for a list of // valid stages in aggregations. // // The opts parameter can be used to specify options for the operation (see the options.AggregateOptions documentation.) // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/aggregate/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/aggregate/. func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, opts ...*options.AggregateOptions) (*Cursor, error) { a := aggregateParams{ @@ -743,6 +797,7 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, registry: coll.registry, readConcern: coll.readConcern, writeConcern: coll.writeConcern, + bsonOpts: coll.bsonOpts, retryRead: coll.client.retryReads, db: coll.db.name, col: coll.name, @@ -754,13 +809,13 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, return aggregate(a) } -// aggreate is the helper method for Aggregate +// aggregate is the helper method for Aggregate func aggregate(a aggregateParams) (cur *Cursor, err error) { if a.ctx == nil { a.ctx = context.Background() } - pipelineArr, hasOutputStage, err := transformAggregatePipeline(a.registry, a.pipeline) + pipelineArr, hasOutputStage, err := marshalAggregatePipeline(a.pipeline, a.bsonOpts, a.registry) if err != nil { return nil, err } @@ -773,10 +828,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { } }() if sess == nil && a.client.sessionPool != nil { - sess, err = session.NewClientSession(a.client.sessionPool, a.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(a.client.sessionPool, a.client.id) } if err = a.client.validSession(sess); err != nil { return nil, err @@ -802,8 +854,11 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { } ao := options.MergeAggregateOptions(a.opts...) + cursorOpts := a.client.createBaseCursorOptions() + cursorOpts.MarshalValueEncoderFn = newEncoderFn(a.bsonOpts, a.registry) + op := operation.NewAggregate(pipelineArr). Session(sess). WriteConcern(wc). @@ -817,7 +872,19 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { Deployment(a.client.deployment). Crypt(a.client.cryptFLE). ServerAPI(a.client.serverAPI). - HasOutputStage(hasOutputStage) + HasOutputStage(hasOutputStage). + Timeout(a.client.timeout). + MaxTime(ao.MaxTime). + Authenticator(a.client.authenticator) + + // Omit "maxTimeMS" from operations that return a user-managed cursor to + // prevent confusing "cursor not found" errors. To maintain existing + // behavior for users who set "timeoutMS" with no context deadline, only + // omit "maxTimeMS" when a context deadline is set. + // + // See DRIVERS-2722 for more detail. + _, deadlineSet := a.ctx.Deadline() + op.OmitCSOTMaxTimeMS(deadlineSet) if ao.AllowDiskUse != nil { op.AllowDiskUse(*ao.AllowDiskUse) @@ -833,24 +900,30 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { if ao.Collation != nil { op.Collation(bsoncore.Document(ao.Collation.ToDocument())) } - if ao.MaxTime != nil { - op.MaxTimeMS(int64(*ao.MaxTime / time.Millisecond)) - } if ao.MaxAwaitTime != nil { cursorOpts.MaxTimeMS = int64(*ao.MaxAwaitTime / time.Millisecond) } if ao.Comment != nil { op.Comment(*ao.Comment) + + commentVal, err := marshalValue(ao.Comment, a.bsonOpts, a.registry) + if err != nil { + return nil, err + } + cursorOpts.Comment = commentVal } if ao.Hint != nil { - hintVal, err := transformValue(a.registry, ao.Hint, false, "hint") + if isUnorderedMap(ao.Hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(ao.Hint, a.bsonOpts, a.registry) if err != nil { return nil, err } op.Hint(hintVal) } if ao.Let != nil { - let, err := transformBsoncoreDocument(a.registry, ao.Let, true, "let") + let, err := marshal(ao.Let, a.bsonOpts, a.registry) if err != nil { return nil, err } @@ -889,7 +962,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { if err != nil { return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, a.registry, sess) + cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess) return cursor, replaceErrors(err) } @@ -910,17 +983,14 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, countOpts := options.MergeCountOptions(opts...) - pipelineArr, err := countDocumentsAggregatePipeline(coll.registry, filter, countOpts) + pipelineArr, err := countDocumentsAggregatePipeline(filter, coll.bsonOpts, coll.registry, countOpts) if err != nil { return 0, err } sess := sessionFromContext(ctx) if sess == nil && coll.client.sessionPool != nil { - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return 0, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } if err = coll.client.validSession(sess); err != nil { @@ -935,15 +1005,19 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). - Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI) + Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). + Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime).Authenticator(coll.client.authenticator) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } - if countOpts.MaxTime != nil { - op.MaxTimeMS(int64(*countOpts.MaxTime / time.Millisecond)) + if countOpts.Comment != nil { + op.Comment(*countOpts.Comment) } if countOpts.Hint != nil { - hintVal, err := transformValue(coll.registry, countOpts.Hint, false, "hint") + if isUnorderedMap(countOpts.Hint) { + return 0, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(countOpts.Hint, coll.bsonOpts, coll.registry) if err != nil { return 0, err } @@ -984,7 +1058,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, // The opts parameter can be used to specify options for the operation (see the options.EstimatedDocumentCountOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/count/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/count/. func (coll *Collection) EstimatedDocumentCount(ctx context.Context, opts ...*options.EstimatedDocumentCountOptions) (int64, error) { @@ -996,10 +1070,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, var err error if sess == nil && coll.client.sessionPool != nil { - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return 0, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -1013,16 +1084,23 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, rc = nil } + co := options.MergeEstimatedDocumentCountOptions(opts...) + selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewCount().Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). - ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI) + ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). + Timeout(coll.client.timeout).MaxTime(co.MaxTime).Authenticator(coll.client.authenticator) - co := options.MergeEstimatedDocumentCountOptions(opts...) - if co.MaxTime != nil { - op = op.MaxTimeMS(int64(*co.MaxTime / time.Millisecond)) + if co.Comment != nil { + comment, err := marshalValue(co.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return 0, err + } + op = op.Comment(comment) } + retry := driver.RetryNone if coll.client.retryReads { retry = driver.RetryOncePerCommand @@ -1030,7 +1108,6 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, op.Retry(retry) err = op.Execute(ctx) - return op.Result().N, replaceErrors(err) } @@ -1043,7 +1120,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, // // The opts parameter can be used to specify options for the operation (see the options.DistinctOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/distinct/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/distinct/. func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter interface{}, opts ...*options.DistinctOptions) ([]interface{}, error) { @@ -1051,7 +1128,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1059,10 +1136,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i sess := sessionFromContext(ctx) if sess == nil && coll.client.sessionPool != nil { - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -1083,13 +1157,18 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). - ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI) + ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). + Timeout(coll.client.timeout).MaxTime(option.MaxTime).Authenticator(coll.client.authenticator) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) } - if option.MaxTime != nil { - op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond)) + if option.Comment != nil { + comment, err := marshalValue(option.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return nil, err + } + op.Comment(comment) } retry := driver.RetryNone if coll.client.retryReads { @@ -1132,7 +1211,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i // // The opts parameter can be used to specify options for the operation (see the options.FindOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/find/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) Find(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (cur *Cursor, err error) { @@ -1140,7 +1219,24 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + // Omit "maxTimeMS" from operations that return a user-managed cursor to + // prevent confusing "cursor not found" errors. To maintain existing + // behavior for users who set "timeoutMS" with no context deadline, only + // omit "maxTimeMS" when a context deadline is set. + // + // See DRIVERS-2722 for more detail. + _, deadlineSet := ctx.Deadline() + return coll.find(ctx, filter, deadlineSet, opts...) +} + +func (coll *Collection) find( + ctx context.Context, + filter interface{}, + omitCSOTMaxTimeMS bool, + opts ...*options.FindOptions, +) (cur *Cursor, err error) { + + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1153,11 +1249,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, } }() if sess == nil && coll.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) } err = coll.client.validSession(sess) @@ -1170,16 +1262,21 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, rc = nil } + fo := options.MergeFindOptions(opts...) + selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewFind(f). Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI) + Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). + Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger). + OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS).Authenticator(coll.client.authenticator) - fo := options.MergeFindOptions(opts...) cursorOpts := coll.client.createBaseCursorOptions() + cursorOpts.MarshalValueEncoderFn = newEncoderFn(coll.bsonOpts, coll.registry) + if fo.AllowDiskUse != nil { op.AllowDiskUse(*fo.AllowDiskUse) } @@ -1195,6 +1292,12 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, } if fo.Comment != nil { op.Comment(*fo.Comment) + + commentVal, err := marshalValue(fo.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return nil, err + } + cursorOpts.Comment = commentVal } if fo.CursorType != nil { switch *fo.CursorType { @@ -1206,14 +1309,17 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, } } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint, false, "hint") + if isUnorderedMap(fo.Hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hint, err := marshalValue(fo.Hint, coll.bsonOpts, coll.registry) if err != nil { return nil, err } op.Hint(hint) } if fo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fo.Let, true, "let") + let, err := marshal(fo.Let, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1229,7 +1335,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Limit(limit) } if fo.Max != nil { - max, err := transformBsoncoreDocument(coll.registry, fo.Max, true, "max") + max, err := marshal(fo.Max, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1238,11 +1344,8 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if fo.MaxAwaitTime != nil { cursorOpts.MaxTimeMS = int64(*fo.MaxAwaitTime / time.Millisecond) } - if fo.MaxTime != nil { - op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) - } if fo.Min != nil { - min, err := transformBsoncoreDocument(coll.registry, fo.Min, true, "min") + min, err := marshal(fo.Min, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1255,7 +1358,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.OplogReplay(*fo.OplogReplay) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") + proj, err := marshal(fo.Projection, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1274,7 +1377,10 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Snapshot(*fo.Snapshot) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") + if isUnorderedMap(fo.Sort) { + return nil, ErrMapForOrderedArgument{"sort"} + } + sort, err := marshal(fo.Sort, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1294,7 +1400,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if err != nil { return nil, replaceErrors(err) } - return newCursorWithSession(bc, coll.registry, sess) + return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess) } // FindOne executes a find command and returns a SingleResult for one document in the collection. @@ -1305,7 +1411,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, // // The opts parameter can be used to specify options for this operation (see the options.FindOneOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/find/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/find/. func (coll *Collection) FindOne(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *SingleResult { @@ -1343,8 +1449,14 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, // by the server. findOpts = append(findOpts, options.Find().SetLimit(-1)) - cursor, err := coll.Find(ctx, filter, findOpts...) - return &SingleResult{cur: cursor, reg: coll.registry, err: replaceErrors(err)} + cursor, err := coll.find(ctx, filter, false, findOpts...) + return &SingleResult{ + ctx: ctx, + cur: cursor, + bsonOpts: coll.bsonOpts, + reg: coll.registry, + err: replaceErrors(err), + } } func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAndModify) *SingleResult { @@ -1355,10 +1467,7 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd sess := sessionFromContext(ctx) var err error if sess == nil && coll.client.sessionPool != nil { - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return &SingleResult{err: err} - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -1398,7 +1507,12 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd return &SingleResult{err: err} } - return &SingleResult{rdr: bson.Raw(op.Result().Value), reg: coll.registry} + return &SingleResult{ + ctx: ctx, + rdr: bson.Raw(op.Result().Value), + bsonOpts: coll.bsonOpts, + reg: coll.registry, + } } // FindOneAndDelete executes a findAndModify command to delete at most one document in the collection. and returns the @@ -1411,45 +1525,56 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd // The opts parameter can be used to specify options for the operation (see the options.FindOneAndDeleteOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/findAndModify/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}, opts ...*options.FindOneAndDeleteOptions) *SingleResult { - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } fod := options.MergeFindOneAndDeleteOptions(opts...) - op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI) + op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + MaxTime(fod.MaxTime).Authenticator(coll.client.authenticator) if fod.Collation != nil { op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } - if fod.MaxTime != nil { - op = op.MaxTimeMS(int64(*fod.MaxTime / time.Millisecond)) + if fod.Comment != nil { + comment, err := marshalValue(fod.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return &SingleResult{err: err} + } + op = op.Comment(comment) } if fod.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fod.Projection, true, "projection") + proj, err := marshal(fod.Projection, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Fields(proj) } if fod.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fod.Sort, false, "sort") + if isUnorderedMap(fod.Sort) { + return &SingleResult{err: ErrMapForOrderedArgument{"sort"}} + } + sort, err := marshal(fod.Sort, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Sort(sort) } if fod.Hint != nil { - hint, err := transformValue(coll.registry, fod.Hint, false, "hint") + if isUnorderedMap(fod.Hint) { + return &SingleResult{err: ErrMapForOrderedArgument{"hint"}} + } + hint, err := marshalValue(fod.Hint, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Hint(hint) } if fod.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fod.Let, true, "let") + let, err := marshal(fod.Let, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1467,20 +1592,20 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} // ErrNoDocuments wil be returned. If the filter matches multiple documents, one will be selected from the matched set. // // The replacement parameter must be a document that will be used to replace the selected document. It cannot be nil -// and cannot contain any update operators (https://docs.mongodb.com/manual/reference/operator/update/). +// and cannot contain any update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). // // The opts parameter can be used to specify options for the operation (see the options.FindOneAndReplaceOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/findAndModify/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{}, replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *SingleResult { - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } - r, err := transformBsoncoreDocument(coll.registry, replacement, true, "replacement") + r, err := marshal(replacement, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1490,18 +1615,23 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ fo := options.MergeFindOneAndReplaceOptions(opts...) op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) + if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) } if fo.Collation != nil { op = op.Collation(bsoncore.Document(fo.Collation.ToDocument())) } - if fo.MaxTime != nil { - op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) + if fo.Comment != nil { + comment, err := marshalValue(fo.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return &SingleResult{err: err} + } + op = op.Comment(comment) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") + proj, err := marshal(fo.Projection, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1511,7 +1641,10 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.NewDocument(*fo.ReturnDocument == options.After) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") + if isUnorderedMap(fo.Sort) { + return &SingleResult{err: ErrMapForOrderedArgument{"sort"}} + } + sort, err := marshal(fo.Sort, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1521,19 +1654,25 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.Upsert(*fo.Upsert) } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint, false, "hint") + if isUnorderedMap(fo.Hint) { + return &SingleResult{err: ErrMapForOrderedArgument{"hint"}} + } + hint, err := marshalValue(fo.Hint, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Hint(hint) } if fo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fo.Let, true, "let") + let, err := marshal(fo.Let, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Let(let) } + if fo.BypassEmptyTsReplacement != nil { + op = op.BypassEmptyTsReplacement(*fo.BypassEmptyTsReplacement) + } return coll.findAndModify(ctx, op) } @@ -1546,13 +1685,13 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ // ErrNoDocuments wil be returned. If the filter matches multiple documents, one will be selected from the matched set. // // The update parameter must be a document containing update operators -// (https://docs.mongodb.com/manual/reference/operator/update/) and can be used to specify the modifications to be made +// (https://www.mongodb.com/docs/manual/reference/operator/update/) and can be used to specify the modifications to be made // to the selected document. It cannot be nil or empty. // // The opts parameter can be used to specify options for the operation (see the options.FindOneAndUpdateOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/findAndModify/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/findAndModify/. func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{}, update interface{}, opts ...*options.FindOneAndUpdateOptions) *SingleResult { @@ -1560,26 +1699,32 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } fo := options.MergeFindOneAndUpdateOptions(opts...) - op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI) + op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) - u, err := transformUpdateValue(coll.registry, update, true) + u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { return &SingleResult{err: err} } op = op.Update(u) if fo.ArrayFilters != nil { - filtersDoc, err := fo.ArrayFilters.ToArrayDocument() + af := fo.ArrayFilters + reg := coll.registry + if af.Registry != nil { + reg = af.Registry + } + filtersDoc, err := marshalValue(af.Filters, coll.bsonOpts, reg) if err != nil { return &SingleResult{err: err} } - op = op.ArrayFilters(bsoncore.Document(filtersDoc)) + op = op.ArrayFilters(filtersDoc.Data) } if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) @@ -1587,11 +1732,15 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} if fo.Collation != nil { op = op.Collation(bsoncore.Document(fo.Collation.ToDocument())) } - if fo.MaxTime != nil { - op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) + if fo.Comment != nil { + comment, err := marshalValue(fo.Comment, coll.bsonOpts, coll.registry) + if err != nil { + return &SingleResult{err: err} + } + op = op.Comment(comment) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") + proj, err := marshal(fo.Projection, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1601,7 +1750,10 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.NewDocument(*fo.ReturnDocument == options.After) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") + if isUnorderedMap(fo.Sort) { + return &SingleResult{err: ErrMapForOrderedArgument{"sort"}} + } + sort, err := marshal(fo.Sort, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1611,31 +1763,37 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.Upsert(*fo.Upsert) } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint, false, "hint") + if isUnorderedMap(fo.Hint) { + return &SingleResult{err: ErrMapForOrderedArgument{"hint"}} + } + hint, err := marshalValue(fo.Hint, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Hint(hint) } if fo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fo.Let, true, "let") + let, err := marshal(fo.Let, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Let(let) } + if fo.BypassEmptyTsReplacement != nil { + op = op.BypassEmptyTsReplacement(*fo.BypassEmptyTsReplacement) + } return coll.findAndModify(ctx, op) } // Watch returns a change stream for all changes on the corresponding collection. See -// https://docs.mongodb.com/manual/changeStreams/ for more information about change streams. +// https://www.mongodb.com/docs/manual/changeStreams/ for more information about change streams. // // The Collection must be configured with read concern majority or no read concern for a change stream to be created // successfully. // // The pipeline parameter must be an array of documents, each representing a pipeline stage. The pipeline cannot be -// nil but can be empty. The stage documents must all be non-nil. See https://docs.mongodb.com/manual/changeStreams/ for +// nil but can be empty. The stage documents must all be non-nil. See https://www.mongodb.com/docs/manual/changeStreams/ for // a list of pipeline stages that can be used with change streams. For a pipeline of bson.D documents, the // mongo.Pipeline{} type can be used. // @@ -1648,6 +1806,7 @@ func (coll *Collection) Watch(ctx context.Context, pipeline interface{}, readConcern: coll.readConcern, readPreference: coll.readPreference, client: coll.client, + bsonOpts: coll.bsonOpts, registry: coll.registry, streamType: CollectionStream, collectionName: coll.Name(), @@ -1662,20 +1821,77 @@ func (coll *Collection) Indexes() IndexView { return IndexView{coll: coll} } +// SearchIndexes returns a SearchIndexView instance that can be used to perform operations on the search indexes for the collection. +func (coll *Collection) SearchIndexes() SearchIndexView { + c, _ := coll.Clone() // Clone() always return a nil error. + c.readConcern = nil + c.writeConcern = nil + return SearchIndexView{ + coll: c, + } +} + // Drop drops the collection on the server. This method ignores "namespace not found" errors so it is safe to drop // a collection that does not exist on the server. func (coll *Collection) Drop(ctx context.Context) error { + // Follow Client-Side Encryption specification to check for encryptedFields. + // Drop does not have an encryptedFields option. See: GODRIVER-2413. + // Check for encryptedFields from the client EncryptedFieldsMap. + // Check for encryptedFields from the server if EncryptedFieldsMap is set. + ef := coll.db.getEncryptedFieldsFromMap(coll.name) + if ef == nil && coll.db.client.encryptedFieldsMap != nil { + var err error + if ef, err = coll.db.getEncryptedFieldsFromServer(ctx, coll.name); err != nil { + return err + } + } + + if ef != nil { + return coll.dropEncryptedCollection(ctx, ef) + } + + return coll.drop(ctx) +} + +// dropEncryptedCollection drops a collection with EncryptedFields. +func (coll *Collection) dropEncryptedCollection(ctx context.Context, ef interface{}) error { + efBSON, err := marshal(ef, coll.bsonOpts, coll.registry) + if err != nil { + return fmt.Errorf("error transforming document: %w", err) + } + + // Drop the two encryption-related, associated collections: `escCollection` and `ecocCollection`. + // Drop ESCCollection. + escCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, coll.name, csfle.EncryptedStateCollection) + if err != nil { + return err + } + if err := coll.db.Collection(escCollection).drop(ctx); err != nil { + return err + } + + // Drop ECOCCollection. + ecocCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, coll.name, csfle.EncryptedCompactionCollection) + if err != nil { + return err + } + if err := coll.db.Collection(ecocCollection).drop(ctx); err != nil { + return err + } + + // Drop the data collection. + return coll.drop(ctx) +} + +// drop drops a collection without EncryptedFields. +func (coll *Collection) drop(ctx context.Context) error { if ctx == nil { ctx = context.Background() } sess := sessionFromContext(ctx) if sess == nil && coll.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(coll.client.sessionPool, coll.client.id, session.Implicit) - if err != nil { - return err - } + sess = session.NewImplicitClientSession(coll.client.sessionPool, coll.client.id) defer sess.EndSession() } @@ -1699,10 +1915,11 @@ func (coll *Collection) Drop(ctx context.Context) error { ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE). - ServerAPI(coll.client.serverAPI) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + Authenticator(coll.client.authenticator) err = op.Execute(ctx) - // ignore namespace not found erorrs + // ignore namespace not found errors driverErr, ok := err.(driver.Error) if !ok || (ok && !driverErr.NamespaceNotFound()) { return replaceErrors(err) @@ -1710,26 +1927,52 @@ func (coll *Collection) Drop(ctx context.Context) error { return nil } -// makePinnedSelector makes a selector for a pinned session with a pinned server. Will attempt to do server selection on -// the pinned server but if that fails it will go through a list of default selectors -func makePinnedSelector(sess *session.Client, defaultSelector description.ServerSelector) description.ServerSelectorFunc { - return func(t description.Topology, svrs []description.Server) ([]description.Server, error) { - if sess != nil && sess.PinnedServer != nil { - // If there is a pinned server, try to find it in the list of candidates. - for _, candidate := range svrs { - if candidate.Addr == sess.PinnedServer.Addr { - return []description.Server{candidate}, nil - } - } +type pinnedServerSelector struct { + stringer fmt.Stringer + fallback description.ServerSelector + session *session.Client +} + +func (pss pinnedServerSelector) String() string { + if pss.stringer == nil { + return "" + } + + return pss.stringer.String() +} - return nil, nil +func (pss pinnedServerSelector) SelectServer( + t description.Topology, + svrs []description.Server, +) ([]description.Server, error) { + if pss.session != nil && pss.session.PinnedServer != nil { + // If there is a pinned server, try to find it in the list of candidates. + for _, candidate := range svrs { + if candidate.Addr == pss.session.PinnedServer.Addr { + return []description.Server{candidate}, nil + } } - return defaultSelector.SelectServer(t, svrs) + return nil, nil + } + + return pss.fallback.SelectServer(t, svrs) +} + +func makePinnedSelector(sess *session.Client, fallback description.ServerSelector) description.ServerSelector { + pss := pinnedServerSelector{ + session: sess, + fallback: fallback, } + + if srvSelectorStringer, ok := fallback.(fmt.Stringer); ok { + pss.stringer = srvSelectorStringer + } + + return pss } -func makeReadPrefSelector(sess *session.Client, selector description.ServerSelector, localThreshold time.Duration) description.ServerSelectorFunc { +func makeReadPrefSelector(sess *session.Client, selector description.ServerSelector, localThreshold time.Duration) description.ServerSelector { if sess != nil && sess.TransactionRunning() { selector = description.CompositeSelector([]description.ServerSelector{ description.ReadPrefSelector(sess.CurrentRp), @@ -1740,7 +1983,7 @@ func makeReadPrefSelector(sess *session.Client, selector description.ServerSelec return makePinnedSelector(sess, selector) } -func makeOutputAggregateSelector(sess *session.Client, rp *readpref.ReadPref, localThreshold time.Duration) description.ServerSelectorFunc { +func makeOutputAggregateSelector(sess *session.Client, rp *readpref.ReadPref, localThreshold time.Duration) description.ServerSelector { if sess != nil && sess.TransactionRunning() { // Use current transaction's read preference if available rp = sess.CurrentRp @@ -1752,3 +1995,11 @@ func makeOutputAggregateSelector(sess *session.Client, rp *readpref.ReadPref, lo }) return makePinnedSelector(sess, selector) } + +// isUnorderedMap returns true if val is a map with more than 1 element. It is typically used to +// check for unordered Go values that are used in nested command documents where different field +// orders mean different things. Examples are the "sort" and "hint" fields. +func isUnorderedMap(val interface{}) bool { + refValue := reflect.ValueOf(val) + return refValue.Kind() == reflect.Map && refValue.Len() > 1 +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/cursor.go b/vendor/go.mongodb.org/mongo-driver/mongo/cursor.go index 533cfce..1e01e39 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/cursor.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/cursor.go @@ -12,10 +12,12 @@ import ( "fmt" "io" "reflect" + "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" - "go.mongodb.org/mongo-driver/x/bsonx" + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/session" @@ -32,17 +34,27 @@ type Cursor struct { bc batchCursor batch *bsoncore.DocumentSequence batchLength int + bsonOpts *options.BSONOptions registry *bsoncodec.Registry clientSession *session.Client err error } -func newCursor(bc batchCursor, registry *bsoncodec.Registry) (*Cursor, error) { - return newCursorWithSession(bc, registry, nil) +func newCursor( + bc batchCursor, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (*Cursor, error) { + return newCursorWithSession(bc, bsonOpts, registry, nil) } -func newCursorWithSession(bc batchCursor, registry *bsoncodec.Registry, clientSession *session.Client) (*Cursor, error) { +func newCursorWithSession( + bc batchCursor, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, + clientSession *session.Client, +) (*Cursor, error) { if registry == nil { registry = bson.DefaultRegistry } @@ -51,6 +63,7 @@ func newCursorWithSession(bc batchCursor, registry *bsoncodec.Registry, clientSe } c := &Cursor{ bc: bc, + bsonOpts: bsonOpts, registry: registry, clientSession: clientSession, } @@ -83,8 +96,6 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco switch t := doc.(type) { case nil: return nil, ErrNilDocument - case bsonx.Doc: - doc = t.Copy() case []byte: // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. doc = bson.Raw(t) @@ -115,8 +126,8 @@ func (c *Cursor) ID() int64 { return c.bc.ID() } // Next gets the next document for this cursor. It returns true if there were no errors and the cursor has not been // exhausted. // -// Next blocks until a document is available, an error occurs, or ctx expires. If ctx expires, the -// error will be set to ctx.Err(). In an error case, Next will return false. +// Next blocks until a document is available or an error occurs. If the context expires, the cursor's error will +// be set to ctx.Err(). In case of an error, Next will return false. // // If Next returns false, subsequent calls will also return false. func (c *Cursor) Next(ctx context.Context) bool { @@ -125,10 +136,10 @@ func (c *Cursor) Next(ctx context.Context) bool { // TryNext attempts to get the next document for this cursor. It returns true if there were no errors and the next // document is available. This is only recommended for use with tailable cursors as a non-blocking alternative to -// Next. See https://docs.mongodb.com/manual/core/tailable-cursors/ for more information about tailable cursors. +// Next. See https://www.mongodb.com/docs/manual/core/tailable-cursors/ for more information about tailable cursors. // // TryNext returns false if the cursor is exhausted, an error occurs when getting results from the server, the next -// document is not yet available, or ctx expires. If ctx expires, the error will be set to ctx.Err(). +// document is not yet available, or ctx expires. If the context expires, the cursor's error will be set to ctx.Err(). // // If TryNext returns false and an error occurred or the cursor has been exhausted (i.e. c.Err() != nil || c.ID() == 0), // subsequent attempts will also return false. Otherwise, it is safe to call TryNext again until a document is @@ -149,13 +160,13 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { ctx = context.Background() } doc, err := c.batch.Next() - switch err { - case nil: + switch { + case err == nil: // Consume the next document in the current batch. c.batchLength-- c.Current = bson.Raw(doc) return true - case io.EOF: // Need to do a getMore + case errors.Is(err, io.EOF): // Need to do a getMore default: c.err = err return false @@ -193,12 +204,12 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { c.batch = c.bc.Batch() c.batchLength = c.batch.DocumentCount() doc, err = c.batch.Next() - switch err { - case nil: + switch { + case err == nil: c.batchLength-- c.Current = bson.Raw(doc) return true - case io.EOF: // Empty batch so we continue + case errors.Is(err, io.EOF): // Empty batch so we continue default: c.err = err return false @@ -206,10 +217,62 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { } } +func getDecoder( + data []byte, + opts *options.BSONOptions, + reg *bsoncodec.Registry, +) (*bson.Decoder, error) { + dec, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data)) + if err != nil { + return nil, err + } + + if opts != nil { + if opts.AllowTruncatingDoubles { + dec.AllowTruncatingDoubles() + } + if opts.BinaryAsSlice { + dec.BinaryAsSlice() + } + if opts.DefaultDocumentD { + dec.DefaultDocumentD() + } + if opts.DefaultDocumentM { + dec.DefaultDocumentM() + } + if opts.UseJSONStructTags { + dec.UseJSONStructTags() + } + if opts.UseLocalTimeZone { + dec.UseLocalTimeZone() + } + if opts.ZeroMaps { + dec.ZeroMaps() + } + if opts.ZeroStructs { + dec.ZeroStructs() + } + } + + if reg != nil { + // TODO:(GODRIVER-2719): Remove error handling. + if err := dec.SetRegistry(reg); err != nil { + return nil, err + } + } + + return dec, nil +} + // Decode will unmarshal the current document into val and return any errors from the unmarshalling process without any // modification. If val is nil or is a typed nil, an error will be returned. func (c *Cursor) Decode(val interface{}) error { - return bson.UnmarshalWithRegistry(c.registry, c.Current, val) + dec, err := getDecoder(c.Current, c.bsonOpts, c.registry) + if err != nil { + return fmt.Errorf("error configuring BSON decoder: %w", err) + } + + return dec.Decode(val) } // Err returns the last error seen by the Cursor, or nil if no error has occurred. @@ -223,8 +286,9 @@ func (c *Cursor) Close(ctx context.Context) error { } // All iterates the cursor and decodes each document into results. The results parameter must be a pointer to a slice. -// The slice pointed to by results will be completely overwritten. This method will close the cursor after retrieving -// all documents. If the cursor has been iterated, any previously iterated documents will not be included in results. +// The slice pointed to by results will be completely overwritten. A nil slice pointer will not be modified if the cursor +// has been closed, exhausted, or is empty. This method will close the cursor after retrieving all documents. If the +// cursor has been iterated, any previously iterated documents will not be included in results. // // This method requires driver version >= 1.1.0. func (c *Cursor) All(ctx context.Context, results interface{}) error { @@ -246,7 +310,10 @@ func (c *Cursor) All(ctx context.Context, results interface{}) error { var index int var err error - defer c.Close(ctx) + // Defer a call to Close to try to clean up the cursor server-side when all + // documents have not been exhausted. Use context.Background() to ensure Close + // completes even if the context passed to All has errored. + defer c.Close(context.Background()) batch := c.batch // exhaust the current batch before iterating the batch cursor for { @@ -295,7 +362,12 @@ func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, bat } currElem := sliceVal.Index(index).Addr().Interface() - if err = bson.UnmarshalWithRegistry(c.registry, doc, currElem); err != nil { + dec, err := getDecoder(doc, c.bsonOpts, c.registry) + if err != nil { + return sliceVal, index, fmt.Errorf("error configuring BSON decoder: %w", err) + } + err = dec.Decode(currElem) + if err != nil { return sliceVal, index, err } @@ -306,11 +378,35 @@ func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, bat } func (c *Cursor) closeImplicitSession() { - if c.clientSession != nil && c.clientSession.SessionType == session.Implicit { + if c.clientSession != nil && c.clientSession.IsImplicit { c.clientSession.EndSession() } } +// SetBatchSize sets the number of documents to fetch from the database with +// each iteration of the cursor's "Next" method. Note that some operations set +// an initial cursor batch size, so this setting only affects subsequent +// document batches fetched from the database. +func (c *Cursor) SetBatchSize(batchSize int32) { + c.bc.SetBatchSize(batchSize) +} + +// SetMaxTime will set the maximum amount of time the server will allow the +// operations to execute. The server will error if this field is set but the +// cursor is not configured with awaitData=true. +// +// The time.Duration value passed by this setter will be converted and rounded +// down to the nearest millisecond. +func (c *Cursor) SetMaxTime(dur time.Duration) { + c.bc.SetMaxTime(dur) +} + +// SetComment will set a user-configurable comment that can be used to identify +// the operation in server logs. +func (c *Cursor) SetComment(comment interface{}) { + c.bc.SetComment(comment) +} + // BatchCursorFromCursor returns a driver.BatchCursor for the given Cursor. If there is no underlying // driver.BatchCursor, nil is returned. // diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/database.go b/vendor/go.mongodb.org/mongo-driver/mongo/database.go index b0066f0..5344c96 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/database.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/database.go @@ -10,15 +10,16 @@ import ( "context" "errors" "fmt" + "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/internal/csfle" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" - "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" @@ -38,6 +39,7 @@ type Database struct { readPreference *readpref.ReadPref readSelector description.ServerSelector writeSelector description.ServerSelector + bsonOpts *options.BSONOptions registry *bsoncodec.Registry } @@ -59,6 +61,11 @@ func newDatabase(client *Client, name string, opts ...*options.DatabaseOptions) wc = dbOpt.WriteConcern } + bsonOpts := client.bsonOpts + if dbOpt.BSONOptions != nil { + bsonOpts = dbOpt.BSONOptions + } + reg := client.registry if dbOpt.Registry != nil { reg = dbOpt.Registry @@ -70,6 +77,7 @@ func newDatabase(client *Client, name string, opts ...*options.DatabaseOptions) readPreference: rp, readConcern: rc, writeConcern: wc, + bsonOpts: bsonOpts, registry: reg, } @@ -107,12 +115,12 @@ func (db *Database) Collection(name string, opts ...*options.CollectionOptions) // The pipeline parameter must be a slice of documents, each representing an aggregation stage. The pipeline // cannot be nil but can be empty. The stage documents must all be non-nil. For a pipeline of bson.D documents, the // mongo.Pipeline type can be used. See -// https://docs.mongodb.com/manual/reference/operator/aggregation-pipeline/#db-aggregate-stages for a list of valid +// https://www.mongodb.com/docs/manual/reference/operator/aggregation-pipeline/#db-aggregate-stages for a list of valid // stages in database-level aggregations. // // The opts parameter can be used to specify options for this operation (see the options.AggregateOptions documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/aggregate/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/aggregate/. func (db *Database) Aggregate(ctx context.Context, pipeline interface{}, opts ...*options.AggregateOptions) (*Cursor, error) { a := aggregateParams{ @@ -136,11 +144,7 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, cursorCommand bool, opts ...*options.RunCmdOptions) (*operation.Command, *session.Client, error) { sess := sessionFromContext(ctx) if sess == nil && db.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit) - if err != nil { - return nil, sess, err - } + sess = session.NewImplicitClientSession(db.client.sessionPool, db.client.id) } err := db.client.validSession(sess) @@ -153,7 +157,11 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, return nil, sess, errors.New("read preference in a transaction must be primary") } - runCmdDoc, err := transformBsoncoreDocument(db.registry, cmd, false, "cmd") + if isUnorderedMap(cmd) { + return nil, sess, ErrMapForOrderedArgument{"cmd"} + } + + runCmdDoc, err := marshal(cmd, db.bsonOpts, db.registry) if err != nil { return nil, sess, err } @@ -169,26 +177,39 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, switch cursorCommand { case true: cursorOpts := db.client.createBaseCursorOptions() + + cursorOpts.MarshalValueEncoderFn = newEncoderFn(db.bsonOpts, db.registry) + op = operation.NewCursorCommand(runCmdDoc, cursorOpts) default: op = operation.NewCommand(runCmdDoc) } + return op.Session(sess).CommandMonitor(db.client.monitor). ServerSelector(readSelect).ClusterClock(db.client.clock). - Database(db.name).Deployment(db.client.deployment).ReadConcern(db.readConcern). - Crypt(db.client.cryptFLE).ReadPreference(ro.ReadPreference).ServerAPI(db.client.serverAPI), sess, nil + Database(db.name).Deployment(db.client.deployment). + Crypt(db.client.cryptFLE).ReadPreference(ro.ReadPreference).ServerAPI(db.client.serverAPI). + Timeout(db.client.timeout).Logger(db.client.logger).Authenticator(db.client.authenticator), sess, nil } -// RunCommand executes the given command against the database. This function does not obey the Database's read -// preference. To specify a read preference, the RunCmdOptions.ReadPreference option must be used. +// RunCommand executes the given command against the database. +// +// This function does not obey the Database's readPreference. To specify a read +// preference, the RunCmdOptions.ReadPreference option must be used. +// +// This function does not obey the Database's readConcern or writeConcern. A +// user must supply these values manually in the user-provided runCommand +// parameter. // // The runCommand parameter must be a document for the command to be executed. It cannot be nil. // This must be an order-preserving type such as bson.D. Map types such as bson.M are not valid. -// If the command document contains a session ID or any transaction-specific fields, the behavior is undefined. -// Specifying API versioning options in the command document and declaring an API version on the client is not supported. -// The behavior of RunCommand is undefined in this case. // // The opts parameter can be used to specify options for this operation (see the options.RunCmdOptions documentation). +// +// The behavior of RunCommand is undefined if the command document contains any of the following: +// - A session ID or any transaction-specific fields +// - API versioning options when an API version is already declared on the Client +// - maxTimeMS when Timeout is set on the Client func (db *Database) RunCommand(ctx context.Context, runCommand interface{}, opts ...*options.RunCmdOptions) *SingleResult { if ctx == nil { ctx = context.Background() @@ -204,9 +225,11 @@ func (db *Database) RunCommand(ctx context.Context, runCommand interface{}, opts // RunCommand can be used to run a write, thus execute may return a write error _, convErr := processWriteError(err) return &SingleResult{ - err: convErr, - rdr: bson.Raw(op.Result()), - reg: db.registry, + ctx: ctx, + err: convErr, + rdr: bson.Raw(op.Result()), + bsonOpts: db.bsonOpts, + reg: db.registry, } } @@ -217,9 +240,13 @@ func (db *Database) RunCommand(ctx context.Context, runCommand interface{}, opts // // The runCommand parameter must be a document for the command to be executed. It cannot be nil. // This must be an order-preserving type such as bson.D. Map types such as bson.M are not valid. -// If the command document contains a session ID or any transaction-specific fields, the behavior is undefined. // // The opts parameter can be used to specify options for this operation (see the options.RunCmdOptions documentation). +// +// The behavior of RunCommandCursor is undefined if the command document contains any of the following: +// - A session ID or any transaction-specific fields +// - API versioning options when an API version is already declared on the Client +// - maxTimeMS when Timeout is set on the Client func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{}, opts ...*options.RunCmdOptions) (*Cursor, error) { if ctx == nil { ctx = context.Background() @@ -233,6 +260,10 @@ func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{} if err = op.Execute(ctx); err != nil { closeImplicitSession(sess) + if errors.Is(err, driver.ErrNoCursor) { + return nil, errors.New( + "database response does not contain a cursor; try using RunCommand instead") + } return nil, replaceErrors(err) } @@ -241,7 +272,7 @@ func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{} closeImplicitSession(sess) return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) return cursor, replaceErrors(err) } @@ -254,11 +285,7 @@ func (db *Database) Drop(ctx context.Context) error { sess := sessionFromContext(ctx) if sess == nil && db.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit) - if err != nil { - return err - } + sess = session.NewImplicitClientSession(db.client.sessionPool, db.client.id) defer sess.EndSession() } @@ -281,7 +308,7 @@ func (db *Database) Drop(ctx context.Context) error { Session(sess).WriteConcern(wc).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) err = op.Execute(ctx) @@ -302,7 +329,7 @@ func (db *Database) Drop(ctx context.Context) error { // The opts parameter can be used to specify options for the operation (see the options.ListCollectionsOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/listCollections/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/listCollections/. // // BUG(benjirewis): ListCollectionSpecifications prevents listing more than 100 collections per database when running // against MongoDB version 2.6. @@ -339,7 +366,7 @@ func (db *Database) ListCollectionSpecifications(ctx context.Context, filter int // The opts parameter can be used to specify options for the operation (see the options.ListCollectionsOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/listCollections/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/listCollections/. // // BUG(benjirewis): ListCollections prevents listing more than 100 collections per database when running against // MongoDB version 2.6. @@ -348,17 +375,14 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt ctx = context.Background() } - filterDoc, err := transformBsoncoreDocument(db.registry, filter, true, "filter") + filterDoc, err := marshal(filter, db.bsonOpts, db.registry) if err != nil { return nil, err } sess := sessionFromContext(ctx) if sess == nil && db.client.sessionPool != nil { - sess, err = session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(db.client.sessionPool, db.client.id) } err = db.client.validSession(sess) @@ -378,9 +402,12 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI).Timeout(db.client.timeout).Authenticator(db.client.authenticator) cursorOpts := db.client.createBaseCursorOptions() + + cursorOpts.MarshalValueEncoderFn = newEncoderFn(db.bsonOpts, db.registry) + if lco.NameOnly != nil { op = op.NameOnly(*lco.NameOnly) } @@ -409,7 +436,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt closeImplicitSession(sess) return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) return cursor, replaceErrors(err) } @@ -423,7 +450,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt // The opts parameter can be used to specify options for the operation (see the options.ListCollectionsOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/listCollections/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/listCollections/. // // BUG(benjirewis): ListCollectionNames prevents listing more than 100 collections per database when running against // MongoDB version 2.6. @@ -439,19 +466,13 @@ func (db *Database) ListCollectionNames(ctx context.Context, filter interface{}, names := make([]string, 0) for res.Next(ctx) { - next := &bsonx.Doc{} - err = res.Decode(next) + elem, err := res.Current.LookupErr("name") if err != nil { return nil, err } - elem, err := next.LookupErr("name") - if err != nil { - return nil, err - } - - if elem.Type() != bson.TypeString { - return nil, fmt.Errorf("incorrect type for 'name'. got %v. want %v", elem.Type(), bson.TypeString) + if elem.Type != bson.TypeString { + return nil, fmt.Errorf("incorrect type for 'name'. got %v. want %v", elem.Type, bson.TypeString) } elemName := elem.StringValue() @@ -478,13 +499,13 @@ func (db *Database) WriteConcern() *writeconcern.WriteConcern { } // Watch returns a change stream for all changes to the corresponding database. See -// https://docs.mongodb.com/manual/changeStreams/ for more information about change streams. +// https://www.mongodb.com/docs/manual/changeStreams/ for more information about change streams. // // The Database must be configured with read concern majority or no read concern for a change stream to be created // successfully. // // The pipeline parameter must be a slice of documents, each representing a pipeline stage. The pipeline cannot be -// nil but can be empty. The stage documents must all be non-nil. See https://docs.mongodb.com/manual/changeStreams/ for +// nil but can be empty. The stage documents must all be non-nil. See https://www.mongodb.com/docs/manual/changeStreams/ for // a list of pipeline stages that can be used with change streams. For a pipeline of bson.D documents, the // mongo.Pipeline{} type can be used. // @@ -512,10 +533,153 @@ func (db *Database) Watch(ctx context.Context, pipeline interface{}, // The opts parameter can be used to specify options for the operation (see the options.CreateCollectionOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/create/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/create/. func (db *Database) CreateCollection(ctx context.Context, name string, opts ...*options.CreateCollectionOptions) error { cco := options.MergeCreateCollectionOptions(opts...) - op := operation.NewCreate(name).ServerAPI(db.client.serverAPI) + // Follow Client-Side Encryption specification to check for encryptedFields. + // Check for encryptedFields from create options. + ef := cco.EncryptedFields + // Check for encryptedFields from the client EncryptedFieldsMap. + if ef == nil { + ef = db.getEncryptedFieldsFromMap(name) + } + if ef != nil { + return db.createCollectionWithEncryptedFields(ctx, name, ef, opts...) + } + + return db.createCollection(ctx, name, opts...) +} + +// getEncryptedFieldsFromServer tries to get an "encryptedFields" document associated with collectionName by running the "listCollections" command. +// Returns nil and no error if the listCollections command succeeds, but "encryptedFields" is not present. +func (db *Database) getEncryptedFieldsFromServer(ctx context.Context, collectionName string) (interface{}, error) { + // Check if collection has an EncryptedFields configured server-side. + collSpecs, err := db.ListCollectionSpecifications(ctx, bson.D{{"name", collectionName}}) + if err != nil { + return nil, err + } + if len(collSpecs) == 0 { + return nil, nil + } + if len(collSpecs) > 1 { + return nil, fmt.Errorf("expected 1 or 0 results from listCollections, got %v", len(collSpecs)) + } + collSpec := collSpecs[0] + rawValue, err := collSpec.Options.LookupErr("encryptedFields") + if errors.Is(err, bsoncore.ErrElementNotFound) { + return nil, nil + } else if err != nil { + return nil, err + } + + encryptedFields, ok := rawValue.DocumentOK() + if !ok { + return nil, fmt.Errorf("expected encryptedFields of %v to be document, got %v", collectionName, rawValue.Type) + } + + return encryptedFields, nil +} + +// getEncryptedFieldsFromMap tries to get an "encryptedFields" document associated with collectionName by checking the client EncryptedFieldsMap. +// Returns nil and no error if an EncryptedFieldsMap is not configured, or does not contain an entry for collectionName. +func (db *Database) getEncryptedFieldsFromMap(collectionName string) interface{} { + // Check the EncryptedFieldsMap + efMap := db.client.encryptedFieldsMap + if efMap == nil { + return nil + } + + namespace := db.name + "." + collectionName + + ef, ok := efMap[namespace] + if ok { + return ef + } + return nil +} + +// createCollectionWithEncryptedFields creates a collection with an EncryptedFields. +func (db *Database) createCollectionWithEncryptedFields(ctx context.Context, name string, ef interface{}, opts ...*options.CreateCollectionOptions) error { + efBSON, err := marshal(ef, db.bsonOpts, db.registry) + if err != nil { + return fmt.Errorf("error transforming document: %w", err) + } + + // Check the wire version to ensure server is 7.0.0 or newer. + // After the wire version check, and before creating the collections, it is possible the server state changes. + // That is OK. This wire version check is a best effort to inform users earlier if using a QEv2 driver with a QEv1 server. + { + const QEv2WireVersion = 21 + server, err := db.client.deployment.SelectServer(ctx, description.WriteSelector()) + if err != nil { + return fmt.Errorf("error selecting server to check maxWireVersion: %w", err) + } + conn, err := server.Connection(ctx) + if err != nil { + return fmt.Errorf("error getting connection to check maxWireVersion: %w", err) + } + defer conn.Close() + wireVersionRange := conn.Description().WireVersion + if wireVersionRange.Max < QEv2WireVersion { + return fmt.Errorf("Driver support of Queryable Encryption is incompatible with server. Upgrade server to use Queryable Encryption. Got maxWireVersion %v but need maxWireVersion >= %v", wireVersionRange.Max, QEv2WireVersion) + } + } + + // Create the two encryption-related, associated collections: `escCollection` and `ecocCollection`. + + stateCollectionOpts := options.CreateCollection(). + SetClusteredIndex(bson.D{{"key", bson.D{{"_id", 1}}}, {"unique", true}}) + // Create ESCCollection. + escCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, name, csfle.EncryptedStateCollection) + if err != nil { + return err + } + + if err := db.createCollection(ctx, escCollection, stateCollectionOpts); err != nil { + return err + } + + // Create ECOCCollection. + ecocCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, name, csfle.EncryptedCompactionCollection) + if err != nil { + return err + } + + if err := db.createCollection(ctx, ecocCollection, stateCollectionOpts); err != nil { + return err + } + + // Create a data collection with the 'encryptedFields' option. + op, err := db.createCollectionOperation(name, opts...) + if err != nil { + return err + } + + op.EncryptedFields(efBSON) + if err := db.executeCreateOperation(ctx, op); err != nil { + return err + } + + // Create an index on the __safeContent__ field in the collection @collectionName. + if _, err := db.Collection(name).Indexes().CreateOne(ctx, IndexModel{Keys: bson.D{{"__safeContent__", 1}}}); err != nil { + return fmt.Errorf("error creating safeContent index: %w", err) + } + + return nil +} + +// createCollection creates a collection without EncryptedFields. +func (db *Database) createCollection(ctx context.Context, name string, opts ...*options.CreateCollectionOptions) error { + op, err := db.createCollectionOperation(name, opts...) + if err != nil { + return err + } + return db.executeCreateOperation(ctx, op) +} + +func (db *Database) createCollectionOperation(name string, opts ...*options.CreateCollectionOptions) (*operation.Create, error) { + cco := options.MergeCreateCollectionOptions(opts...) + op := operation.NewCreate(name).ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) if cco.Capped != nil { op.Capped(*cco.Capped) @@ -523,19 +687,26 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* if cco.Collation != nil { op.Collation(bsoncore.Document(cco.Collation.ToDocument())) } + if cco.ChangeStreamPreAndPostImages != nil { + csppi, err := marshal(cco.ChangeStreamPreAndPostImages, db.bsonOpts, db.registry) + if err != nil { + return nil, err + } + op.ChangeStreamPreAndPostImages(csppi) + } if cco.DefaultIndexOptions != nil { idx, doc := bsoncore.AppendDocumentStart(nil) if cco.DefaultIndexOptions.StorageEngine != nil { - storageEngine, err := transformBsoncoreDocument(db.registry, cco.DefaultIndexOptions.StorageEngine, true, "storageEngine") + storageEngine, err := marshal(cco.DefaultIndexOptions.StorageEngine, db.bsonOpts, db.registry) if err != nil { - return err + return nil, err } doc = bsoncore.AppendDocumentElement(doc, "storageEngine", storageEngine) } doc, err := bsoncore.AppendDocumentEnd(doc, idx) if err != nil { - return err + return nil, err } op.IndexOptionDefaults(doc) @@ -547,9 +718,9 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* op.Size(*cco.SizeInBytes) } if cco.StorageEngine != nil { - storageEngine, err := transformBsoncoreDocument(db.registry, cco.StorageEngine, true, "storageEngine") + storageEngine, err := marshal(cco.StorageEngine, db.bsonOpts, db.registry) if err != nil { - return err + return nil, err } op.StorageEngine(storageEngine) } @@ -560,9 +731,9 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* op.ValidationLevel(*cco.ValidationLevel) } if cco.Validator != nil { - validator, err := transformBsoncoreDocument(db.registry, cco.Validator, true, "validator") + validator, err := marshal(cco.Validator, db.bsonOpts, db.registry) if err != nil { - return err + return nil, err } op.Validator(validator) } @@ -580,24 +751,43 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* doc = bsoncore.AppendStringElement(doc, "granularity", *cco.TimeSeriesOptions.Granularity) } + if cco.TimeSeriesOptions.BucketMaxSpan != nil { + bmss := int64(*cco.TimeSeriesOptions.BucketMaxSpan / time.Second) + + doc = bsoncore.AppendInt64Element(doc, "bucketMaxSpanSeconds", bmss) + } + + if cco.TimeSeriesOptions.BucketRounding != nil { + brs := int64(*cco.TimeSeriesOptions.BucketRounding / time.Second) + + doc = bsoncore.AppendInt64Element(doc, "bucketRoundingSeconds", brs) + } + doc, err := bsoncore.AppendDocumentEnd(doc, idx) if err != nil { - return err + return nil, err } op.TimeSeries(doc) } + if cco.ClusteredIndex != nil { + clusteredIndex, err := marshal(cco.ClusteredIndex, db.bsonOpts, db.registry) + if err != nil { + return nil, err + } + op.ClusteredIndex(clusteredIndex) + } - return db.executeCreateOperation(ctx, op) + return op, nil } // CreateView executes a create command to explicitly create a view on the server. See -// https://docs.mongodb.com/manual/core/views/ for more information about views. This method requires driver version >= +// https://www.mongodb.com/docs/manual/core/views/ for more information about views. This method requires driver version >= // 1.4.0 and MongoDB version >= 3.4. // // The viewName parameter specifies the name of the view to create. // -// The viewOn parameter specifies the name of the collection or view on which this view will be created +// # The viewOn parameter specifies the name of the collection or view on which this view will be created // // The pipeline parameter specifies an aggregation pipeline that will be exececuted against the source collection or // view to create this view. @@ -607,7 +797,7 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pipeline interface{}, opts ...*options.CreateViewOptions) error { - pipelineArray, _, err := transformAggregatePipeline(db.registry, pipeline) + pipelineArray, _, err := marshalAggregatePipeline(pipeline, db.bsonOpts, db.registry) if err != nil { return err } @@ -615,7 +805,8 @@ func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pip op := operation.NewCreate(viewName). ViewOn(viewOn). Pipeline(pipelineArray). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI). + Authenticator(db.client.authenticator) cvo := options.MergeCreateViewOptions(opts...) if cvo.Collation != nil { op.Collation(bsoncore.Document(cvo.Collation.ToDocument())) @@ -627,11 +818,7 @@ func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pip func (db *Database) executeCreateOperation(ctx context.Context, op *operation.Create) error { sess := sessionFromContext(ctx) if sess == nil && db.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit) - if err != nil { - return err - } + sess = session.NewImplicitClientSession(db.client.sessionPool, db.client.id) defer sess.EndSession() } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/description/description.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/description.go index 40b1af1..e750e33 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/description/description.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/description.go @@ -4,6 +4,7 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +// Package description contains types and functions for describing the state of MongoDB clusters. package description // import "go.mongodb.org/mongo-driver/mongo/description" // Unknown is an unknown server or topology kind. diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/description/server.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/server.go index 405efe9..19f2760 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/description/server.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/server.go @@ -13,7 +13,9 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/internal" + "go.mongodb.org/mongo-driver/internal/bsonutil" + "go.mongodb.org/mongo-driver/internal/handshake" + "go.mongodb.org/mongo-driver/internal/ptrutil" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/tag" ) @@ -31,34 +33,37 @@ type SelectedServer struct { type Server struct { Addr address.Address - Arbiters []string - AverageRTT time.Duration - AverageRTTSet bool - Compression []string // compression methods returned by server - CanonicalAddr address.Address - ElectionID primitive.ObjectID - HeartbeatInterval time.Duration - HelloOK bool - Hosts []string - LastError error - LastUpdateTime time.Time - LastWriteTime time.Time - MaxBatchCount uint32 - MaxDocumentSize uint32 - MaxMessageSize uint32 - Members []address.Address - Passives []string - Passive bool - Primary address.Address - ReadOnly bool - ServiceID *primitive.ObjectID // Only set for servers that are deployed behind a load balancer. - SessionTimeoutMinutes uint32 - SetName string - SetVersion uint32 - Tags tag.Set - TopologyVersion *TopologyVersion - Kind ServerKind - WireVersion *VersionRange + Arbiters []string + AverageRTT time.Duration + AverageRTTSet bool + Compression []string // compression methods returned by server + CanonicalAddr address.Address + ElectionID primitive.ObjectID + HeartbeatInterval time.Duration + HelloOK bool + Hosts []string + IsCryptd bool + LastError error + LastUpdateTime time.Time + LastWriteTime time.Time + MaxBatchCount uint32 + MaxDocumentSize uint32 + MaxMessageSize uint32 + Members []address.Address + Passives []string + Passive bool + Primary address.Address + ReadOnly bool + ServiceID *primitive.ObjectID // Only set for servers that are deployed behind a load balancer. + // Deprecated: Use SessionTimeoutMinutesPtr instead. + SessionTimeoutMinutes uint32 + SessionTimeoutMinutesPtr *int64 + SetName string + SetVersion uint32 + Tags tag.Set + TopologyVersion *TopologyVersion + Kind ServerKind + WireVersion *VersionRange } // NewServer creates a new server description from the given hello command response. @@ -72,12 +77,12 @@ func NewServer(addr address.Address, response bson.Raw) Server { var ok bool var isReplicaSet, isWritablePrimary, hidden, secondary, arbiterOnly bool var msg string - var version VersionRange + var versionRange VersionRange for _, element := range elements { switch element.Key() { case "arbiters": var err error - desc.Arbiters, err = internal.StringSliceFromRawElement(element) + desc.Arbiters, err = stringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc @@ -90,7 +95,7 @@ func NewServer(addr address.Address, response bson.Raw) Server { } case "compression": var err error - desc.Compression, err = internal.StringSliceFromRawElement(element) + desc.Compression, err = stringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc @@ -101,6 +106,12 @@ func NewServer(addr address.Address, response bson.Raw) Server { desc.LastError = fmt.Errorf("expected 'electionId' to be a objectID but it's a BSON %s", element.Value().Type) return desc } + case "iscryptd": + desc.IsCryptd, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'iscryptd' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } case "helloOk": desc.HelloOK, ok = element.Value().BooleanOK() if !ok { @@ -115,7 +126,7 @@ func NewServer(addr address.Address, response bson.Raw) Server { } case "hosts": var err error - desc.Hosts, err = internal.StringSliceFromRawElement(element) + desc.Hosts, err = stringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc @@ -126,7 +137,7 @@ func NewServer(addr address.Address, response bson.Raw) Server { desc.LastError = fmt.Errorf("expected 'isWritablePrimary' to be a boolean but it's a BSON %s", element.Value().Type) return desc } - case internal.LegacyHelloLowercase: + case handshake.LegacyHelloLowercase: isWritablePrimary, ok = element.Value().BooleanOK() if !ok { desc.LastError = fmt.Errorf("expected legacy hello to be a boolean but it's a BSON %s", element.Value().Type) @@ -159,7 +170,9 @@ func NewServer(addr address.Address, response bson.Raw) Server { desc.LastError = fmt.Errorf("expected 'logicalSessionTimeoutMinutes' to be an integer but it's a BSON %s", element.Value().Type) return desc } + desc.SessionTimeoutMinutes = uint32(i64) + desc.SessionTimeoutMinutesPtr = &i64 case "maxBsonObjectSize": i64, ok := element.Value().AsInt64OK() if !ok { @@ -189,13 +202,13 @@ func NewServer(addr address.Address, response bson.Raw) Server { } desc.CanonicalAddr = address.Address(me).Canonicalize() case "maxWireVersion": - version.Max, ok = element.Value().AsInt32OK() + versionRange.Max, ok = element.Value().AsInt32OK() if !ok { desc.LastError = fmt.Errorf("expected 'maxWireVersion' to be an integer but it's a BSON %s", element.Value().Type) return desc } case "minWireVersion": - version.Min, ok = element.Value().AsInt32OK() + versionRange.Min, ok = element.Value().AsInt32OK() if !ok { desc.LastError = fmt.Errorf("expected 'minWireVersion' to be an integer but it's a BSON %s", element.Value().Type) return desc @@ -218,7 +231,7 @@ func NewServer(addr address.Address, response bson.Raw) Server { } case "passives": var err error - desc.Passives, err = internal.StringSliceFromRawElement(element) + desc.Passives, err = stringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc @@ -303,25 +316,27 @@ func NewServer(addr address.Address, response bson.Raw) Server { desc.Kind = Standalone - if isReplicaSet { + switch { + case isReplicaSet: desc.Kind = RSGhost - } else if desc.SetName != "" { - if isWritablePrimary { + case desc.SetName != "": + switch { + case isWritablePrimary: desc.Kind = RSPrimary - } else if hidden { + case hidden: desc.Kind = RSMember - } else if secondary { + case secondary: desc.Kind = RSSecondary - } else if arbiterOnly { + case arbiterOnly: desc.Kind = RSArbiter - } else { + default: desc.Kind = RSMember } - } else if msg == "isdbgrid" { + case msg == "isdbgrid": desc.Kind = Mongos } - desc.WireVersion = &version + desc.WireVersion = &versionRange return desc } @@ -455,7 +470,7 @@ func (s Server) Equal(other Server) bool { return false } - if s.SessionTimeoutMinutes != other.SessionTimeoutMinutes { + if ptrutil.CompareInt64(s.SessionTimeoutMinutesPtr, other.SessionTimeoutMinutesPtr) != 0 { return false } @@ -479,3 +494,11 @@ func sliceStringEqual(a []string, b []string) bool { } return true } + +// stringSliceFromRawElement decodes the provided BSON element into a []string. +// This internally calls StringSliceFromRawValue on the element's value. The +// error conditions outlined in that function's documentation apply for this +// function as well. +func stringSliceFromRawElement(element bson.RawElement) ([]string, error) { + return bsonutil.StringSliceFromRawValue(element.Key(), element.Value()) +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/description/server_selector.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/server_selector.go index 8e810cb..176f0fb 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/description/server_selector.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/server_selector.go @@ -7,6 +7,7 @@ package description import ( + "encoding/json" "fmt" "math" "time" @@ -30,10 +31,48 @@ func (ssf ServerSelectorFunc) SelectServer(t Topology, s []Server) ([]Server, er return ssf(t, s) } +// serverSelectorInfo contains metadata concerning the server selector for the +// purpose of publication. +type serverSelectorInfo struct { + Type string + Data string `json:",omitempty"` + Selectors []serverSelectorInfo `json:",omitempty"` +} + +// String returns the JSON string representation of the serverSelectorInfo. +func (sss serverSelectorInfo) String() string { + bytes, _ := json.Marshal(sss) + + return string(bytes) +} + +// serverSelectorInfoGetter is an interface that defines an info() method to +// get the serverSelectorInfo. +type serverSelectorInfoGetter interface { + info() serverSelectorInfo +} + type compositeSelector struct { selectors []ServerSelector } +func (cs *compositeSelector) info() serverSelectorInfo { + csInfo := serverSelectorInfo{Type: "compositeSelector"} + + for _, sel := range cs.selectors { + if getter, ok := sel.(serverSelectorInfoGetter); ok { + csInfo.Selectors = append(csInfo.Selectors, getter.info()) + } + } + + return csInfo +} + +// String returns the JSON string representation of the compositeSelector. +func (cs *compositeSelector) String() string { + return cs.info().String() +} + // CompositeSelector combines multiple selectors into a single selector by applying them in order to the candidates // list. // @@ -68,8 +107,16 @@ func LatencySelector(latency time.Duration) ServerSelector { return &latencySelector{latency: latency} } -func (ls *latencySelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { - if ls.latency < 0 { +func (latencySelector) info() serverSelectorInfo { + return serverSelectorInfo{Type: "latencySelector"} +} + +func (selector latencySelector) String() string { + return selector.info().String() +} + +func (selector *latencySelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { + if selector.latency < 0 { return candidates, nil } if t.Kind == LoadBalanced { @@ -94,90 +141,119 @@ func (ls *latencySelector) SelectServer(t Topology, candidates []Server) ([]Serv return candidates, nil } - max := min + ls.latency + max := min + selector.latency - var result []Server - for _, candidate := range candidates { + viableIndexes := make([]int, 0, len(candidates)) + for i, candidate := range candidates { if candidate.AverageRTTSet { if candidate.AverageRTT <= max { - result = append(result, candidate) + viableIndexes = append(viableIndexes, i) } } } - + if len(viableIndexes) == len(candidates) { + return candidates, nil + } + result := make([]Server, len(viableIndexes)) + for i, idx := range viableIndexes { + result[i] = candidates[idx] + } return result, nil } } +type writeServerSelector struct{} + // WriteSelector selects all the writable servers. func WriteSelector() ServerSelector { - return ServerSelectorFunc(func(t Topology, candidates []Server) ([]Server, error) { - switch t.Kind { - case Single, LoadBalanced: - return candidates, nil - default: - result := []Server{} - for _, candidate := range candidates { - switch candidate.Kind { - case Mongos, RSPrimary, Standalone: - result = append(result, candidate) - } - } - return result, nil - } - }) + return writeServerSelector{} } -// ReadPrefSelector selects servers based on the provided read preference. -func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector { - return readPrefSelector(rp, false) +func (writeServerSelector) info() serverSelectorInfo { + return serverSelectorInfo{Type: "writeSelector"} } -// OutputAggregateSelector selects servers based on the provided read preference given that the underlying operation is -// aggregate with an output stage. -func OutputAggregateSelector(rp *readpref.ReadPref) ServerSelector { - return readPrefSelector(rp, true) +func (selector writeServerSelector) String() string { + return selector.info().String() } -func readPrefSelector(rp *readpref.ReadPref, isOutputAggregate bool) ServerSelector { - return ServerSelectorFunc(func(t Topology, candidates []Server) ([]Server, error) { - if t.Kind == LoadBalanced { - // In LoadBalanced mode, there should only be one server in the topology and it must be selected. We check - // this before checking MaxStaleness support because there's no monitoring in this mode, so the candidate - // server wouldn't have a wire version set, which would result in an error. - return candidates, nil +func (writeServerSelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { + switch t.Kind { + case Single, LoadBalanced: + return candidates, nil + default: + // Determine the capacity of the results slice. + selected := 0 + for _, candidate := range candidates { + switch candidate.Kind { + case Mongos, RSPrimary, Standalone: + selected++ + } } - if _, set := rp.MaxStaleness(); set { - for _, s := range candidates { - if s.Kind != Unknown { - if err := maxStalenessSupported(s.WireVersion); err != nil { - return nil, err - } - } + // Append candidates to the results slice. + result := make([]Server, 0, selected) + for _, candidate := range candidates { + switch candidate.Kind { + case Mongos, RSPrimary, Standalone: + result = append(result, candidate) } } + return result, nil + } +} - switch t.Kind { - case Single: - return candidates, nil - case ReplicaSetNoPrimary, ReplicaSetWithPrimary: - return selectForReplicaSet(rp, isOutputAggregate, t, candidates) - case Sharded: - return selectByKind(candidates, Mongos), nil - } +type readPrefServerSelector struct { + rp *readpref.ReadPref + isOutputAggregate bool +} - return nil, nil - }) +// ReadPrefSelector selects servers based on the provided read preference. +func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector { + return readPrefServerSelector{ + rp: rp, + isOutputAggregate: false, + } } -// maxStalenessSupported returns an error if the given server version does not support max staleness. -func maxStalenessSupported(wireVersion *VersionRange) error { - if wireVersion != nil && wireVersion.Max < 5 { - return fmt.Errorf("max staleness is only supported for servers 3.4 or newer") +func (selector readPrefServerSelector) info() serverSelectorInfo { + return serverSelectorInfo{ + Type: "readPrefSelector", + Data: selector.rp.String(), } +} - return nil +func (selector readPrefServerSelector) String() string { + return selector.info().String() +} + +func (selector readPrefServerSelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { + if t.Kind == LoadBalanced { + // In LoadBalanced mode, there should only be one server in the topology and it must be selected. We check + // this before checking MaxStaleness support because there's no monitoring in this mode, so the candidate + // server wouldn't have a wire version set, which would result in an error. + return candidates, nil + } + + switch t.Kind { + case Single: + return candidates, nil + case ReplicaSetNoPrimary, ReplicaSetWithPrimary: + return selectForReplicaSet(selector.rp, selector.isOutputAggregate, t, candidates) + case Sharded: + return selectByKind(candidates, Mongos), nil + } + + return nil, nil +} + +// OutputAggregateSelector selects servers based on the provided read preference +// given that the underlying operation is aggregate with an output stage. +func OutputAggregateSelector(rp *readpref.ReadPref) ServerSelector { + return readPrefServerSelector{ + rp: rp, + isOutputAggregate: true, + } } func selectForReplicaSet(rp *readpref.ReadPref, isOutputAggregate bool, t Topology, candidates []Server) ([]Server, error) { @@ -296,13 +372,21 @@ func selectByTagSet(candidates []Server, tagSets []tag.Set) []Server { } func selectByKind(candidates []Server, kind ServerKind) []Server { - var result []Server - for _, s := range candidates { + // Record the indices of viable candidates first and then append those to the returned slice + // to avoid appending costly Server structs directly as an optimization. + viableIndexes := make([]int, 0, len(candidates)) + for i, s := range candidates { if s.Kind == kind { - result = append(result, s) + viableIndexes = append(viableIndexes, i) } } - + if len(viableIndexes) == len(candidates) { + return candidates + } + result := make([]Server, len(viableIndexes)) + for i, idx := range viableIndexes { + result[i] = candidates[idx] + } return result } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go index 8544548..b082515 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go @@ -14,11 +14,13 @@ import ( // Topology contains information about a MongoDB cluster. type Topology struct { - Servers []Server - SetName string - Kind TopologyKind - SessionTimeoutMinutes uint32 - CompatibilityErr error + Servers []Server + SetName string + Kind TopologyKind + // Deprecated: Use SessionTimeoutMinutesPtr instead. + SessionTimeoutMinutes uint32 + SessionTimeoutMinutesPtr *int64 + CompatibilityErr error } // String implements the Stringer interface. diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/doc.go b/vendor/go.mongodb.org/mongo-driver/mongo/doc.go index 669aa14..e0a5d66 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/doc.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/doc.go @@ -11,67 +11,67 @@ // Basic usage of the driver starts with creating a Client from a connection // string. To do so, call Connect: // -// ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) -// defer cancel() -// client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://foo:bar@localhost:27017")) -// if err != nil { return err } +// ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) +// defer cancel() +// client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://foo:bar@localhost:27017")) +// if err != nil { return err } // // This will create a new client and start monitoring the MongoDB server on localhost. // The Database and Collection types can be used to access the database: // -// collection := client.Database("baz").Collection("qux") +// collection := client.Database("baz").Collection("qux") // // A Collection can be used to query the database or insert documents: // -// res, err := collection.InsertOne(context.Background(), bson.M{"hello": "world"}) -// if err != nil { return err } -// id := res.InsertedID +// res, err := collection.InsertOne(context.Background(), bson.M{"hello": "world"}) +// if err != nil { return err } +// id := res.InsertedID // // Several methods return a cursor, which can be used like this: // -// cur, err := collection.Find(context.Background(), bson.D{}) -// if err != nil { log.Fatal(err) } -// defer cur.Close(context.Background()) -// for cur.Next(context.Background()) { -// // To decode into a struct, use cursor.Decode() -// result := struct{ -// Foo string -// Bar int32 -// }{} -// err := cur.Decode(&result) -// if err != nil { log.Fatal(err) } -// // do something with result... -// -// // To get the raw bson bytes use cursor.Current -// raw := cur.Current -// // do something with raw... -// } -// if err := cur.Err(); err != nil { -// return err -// } +// cur, err := collection.Find(context.Background(), bson.D{}) +// if err != nil { log.Fatal(err) } +// defer cur.Close(context.Background()) +// for cur.Next(context.Background()) { +// // To decode into a struct, use cursor.Decode() +// result := struct{ +// Foo string +// Bar int32 +// }{} +// err := cur.Decode(&result) +// if err != nil { log.Fatal(err) } +// // do something with result... +// +// // To get the raw bson bytes use cursor.Current +// raw := cur.Current +// // do something with raw... +// } +// if err := cur.Err(); err != nil { +// return err +// } // // Cursor.All will decode all of the returned elements at once: // -// var results []struct{ -// Foo string -// Bar int32 -// } -// if err = cur.All(context.Background(), &results); err != nil { -// log.Fatal(err) -// } -// // do something with results... +// var results []struct{ +// Foo string +// Bar int32 +// } +// if err = cur.All(context.Background(), &results); err != nil { +// log.Fatal(err) +// } +// // do something with results... // // Methods that only return a single document will return a *SingleResult, which works // like a *sql.Row: // -// result := struct{ -// Foo string -// Bar int32 -// }{} -// filter := bson.D{{"hello", "world"}} -// err := collection.FindOne(context.Background(), filter).Decode(&result) -// if err != nil { return err } -// // do something with result... +// result := struct{ +// Foo string +// Bar int32 +// }{} +// filter := bson.D{{"hello", "world"}} +// err := collection.FindOne(context.Background(), filter).Decode(&result) +// if err != nil { return err } +// // do something with result... // // All Client, Collection, and Database methods that take parameters of type interface{} // will return ErrNilDocument if nil is passed in for an interface{}. @@ -79,7 +79,7 @@ // Additional examples can be found under the examples directory in the driver's repository and // on the MongoDB website. // -// Error Handling +// # Error Handling // // Errors from the MongoDB server will implement the ServerError interface, which has functions to check for specific // error codes, labels, and message substrings. These can be used to check for and handle specific errors. Some methods, @@ -87,26 +87,47 @@ // functions will return true if any of the contained errors satisfy the check. // // There are also helper functions to check for certain specific types of errors: -// IsDuplicateKeyError(error) -// IsNetworkError(error) -// IsTimeout(error) // -// Potential DNS Issues +// IsDuplicateKeyError(error) +// IsNetworkError(error) +// IsTimeout(error) // -// Building with Go 1.11+ and using connection strings with the "mongodb+srv"[1] scheme is +// # Potential DNS Issues +// +// Building with Go 1.11+ and using connection strings with the "mongodb+srv"[1] scheme is unfortunately // incompatible with some DNS servers in the wild due to the change introduced in -// https://github.com/golang/go/issues/10622. If you receive an error with the message "cannot -// unmarshal DNS message" while running an operation, we suggest you use a different DNS server. +// https://github.com/golang/go/issues/10622. You may receive an error with the message "cannot unmarshal DNS message" +// while running an operation when using DNS servers that non-compliantly compress SRV records. Old versions of kube-dns +// and the native DNS resolver (systemd-resolver) on Ubuntu 18.04 are known to be non-compliant in this manner. We suggest +// using a different DNS server (8.8.8.8 is the common default), and, if that's not possible, avoiding the "mongodb+srv" +// scheme. // -// Client Side Encryption +// # Client Side Encryption // // Client-side encryption is a new feature in MongoDB 4.2 that allows specific data fields to be encrypted. Using this -// feature requires specifying the "cse" build tag during compilation. +// feature requires specifying the "cse" build tag during compilation: +// +// go build -tags cse +// +// Note: Auto encryption is an enterprise- and Atlas-only feature. +// +// The libmongocrypt C library is required when using client-side encryption. Specific versions of libmongocrypt +// are required for different versions of the Go Driver: +// +// - Go Driver v1.2.0 requires libmongocrypt v1.0.0 or higher +// +// - Go Driver v1.5.0 requires libmongocrypt v1.1.0 or higher +// +// - Go Driver v1.8.0 requires libmongocrypt v1.3.0 or higher // -// Note: Auto encryption is an enterprise-only feature. +// - Go Driver v1.10.0 requires libmongocrypt v1.5.0 or higher. +// There is a severe bug when calling RewrapManyDataKey with libmongocrypt versions less than 1.5.2. +// This bug may result in data corruption. +// Please use libmongocrypt 1.5.2 or higher when calling RewrapManyDataKey. // -// The libmongocrypt C library is required when using client-side encryption. libmongocrypt version 1.3.0 or higher is -// required when using driver version 1.8.0 or higher. To install libmongocrypt, follow the instructions for your +// - Go Driver v1.12.0 requires libmongocrypt v1.8.0 or higher. +// +// To install libmongocrypt, follow the instructions for your // operating system: // // 1. Linux: follow the instructions listed at @@ -117,29 +138,20 @@ // to install packages via brew and compile the libmongocrypt source code. // // 3. Windows: -// mkdir -p c:/libmongocrypt/bin -// mkdir -p c:/libmongocrypt/include -// -// // Run the curl command in an empty directory as it will create new directories when unpacked. -// curl https://s3.amazonaws.com/mciuploads/libmongocrypt/windows/latest_release/libmongocrypt.tar.gz --output libmongocrypt.tar.gz -// tar -xvzf libmongocrypt.tar.gz -// -// cp ./bin/mongocrypt.dll c:/libmongocrypt/bin -// cp ./include/mongocrypt/*.h c:/libmongocrypt/include -// export PATH=$PATH:/cygdrive/c/libmongocrypt/bin -// -// libmongocrypt communicates with the mongocryptd process for automatic encryption. This process can be started manually -// or auto-spawned by the driver itself. To enable auto-spawning, ensure the process binary is on the PATH. To start it -// manually, use AutoEncryptionOptions: -// -// aeo := options.AutoEncryption() -// mongocryptdOpts := map[string]interface{}{ -// "mongocryptdBypassSpawn": true, -// } -// aeo.SetExtraOptions(mongocryptdOpts) -// To specify a process URI for mongocryptd, the "mongocryptdURI" option can be passed in the ExtraOptions map as well. -// See the ClientSideEncryption and ClientSideEncryptionCreateKey examples below for code samples about using this -// feature. -// -// [1] See https://docs.mongodb.com/manual/reference/connection-string/#dns-seedlist-connection-format +// +// mkdir -p c:/libmongocrypt/bin +// mkdir -p c:/libmongocrypt/include +// +// // Run the curl command in an empty directory as it will create new directories when unpacked. +// curl https://s3.amazonaws.com/mciuploads/libmongocrypt/windows/latest_release/libmongocrypt.tar.gz --output libmongocrypt.tar.gz +// tar -xvzf libmongocrypt.tar.gz +// +// cp ./bin/mongocrypt.dll c:/libmongocrypt/bin +// cp ./include/mongocrypt/*.h c:/libmongocrypt/include +// export PATH=$PATH:/cygdrive/c/libmongocrypt/bin +// +// libmongocrypt communicates with the mongocryptd process or mongo_crypt shared library for automatic encryption. +// See AutoEncryptionOpts.SetExtraOptions for options to configure use of mongocryptd or mongo_crypt. +// +// [1] See https://www.mongodb.com/docs/manual/reference/connection-string/#dns-seedlist-connection-format package mongo diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/errors.go b/vendor/go.mongodb.org/mongo-driver/mongo/errors.go index a16efab..d92c9ca 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/errors.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/errors.go @@ -15,6 +15,7 @@ import ( "strings" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/internal/codecutil" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" @@ -46,7 +47,12 @@ func (e ErrMapForOrderedArgument) Error() string { } func replaceErrors(err error) error { - if err == topology.ErrTopologyClosed { + // Return nil when err is nil to avoid costly reflection logic below. + if err == nil { + return nil + } + + if errors.Is(err, topology.ErrTopologyClosed) { return ErrClientDisconnected } if de, ok := err.(driver.Error); ok { @@ -82,36 +88,69 @@ func replaceErrors(err error) error { return MongocryptError{Code: me.Code, Message: me.Message} } + if errors.Is(err, codecutil.ErrNilValue) { + return ErrNilValue + } + + if marshalErr, ok := err.(codecutil.MarshalError); ok { + return MarshalError{ + Value: marshalErr.Value, + Err: marshalErr.Err, + } + } + return err } -// IsDuplicateKeyError returns true if err is a duplicate key error +// IsDuplicateKeyError returns true if err is a duplicate key error. func IsDuplicateKeyError(err error) bool { - // handles SERVER-7164 and SERVER-11493 - for ; err != nil; err = unwrap(err) { - if e, ok := err.(ServerError); ok { - return e.HasErrorCode(11000) || e.HasErrorCode(11001) || e.HasErrorCode(12582) || - e.HasErrorCodeWithMessage(16460, " E11000 ") - } + if se := ServerError(nil); errors.As(err, &se) { + return se.HasErrorCode(11000) || // Duplicate key error. + se.HasErrorCode(11001) || // Duplicate key error on update. + // Duplicate key error in a capped collection. See SERVER-7164. + se.HasErrorCode(12582) || + // Mongos insert error caused by a duplicate key error. See + // SERVER-11493. + se.HasErrorCodeWithMessage(16460, " E11000 ") } return false } -// IsTimeout returns true if err is from a timeout +// timeoutErrs is a list of error values that indicate a timeout happened. +var timeoutErrs = [...]error{ + context.DeadlineExceeded, + driver.ErrDeadlineWouldBeExceeded, + topology.ErrServerSelectionTimeout, +} + +// IsTimeout returns true if err was caused by a timeout. For error chains, +// IsTimeout returns true if any error in the chain was caused by a timeout. func IsTimeout(err error) bool { - for ; err != nil; err = unwrap(err) { - // check unwrappable errors together - if err == context.DeadlineExceeded { + // Check if the error chain contains any of the timeout error values. + for _, target := range timeoutErrs { + if errors.Is(err, target) { return true } - if ne, ok := err.(net.Error); ok { - return ne.Timeout() - } - //timeout error labels - if le, ok := err.(labeledError); ok { - if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") { - return true - } + } + + // Check if the error chain contains any error types that can indicate + // timeout. + if errors.As(err, &topology.WaitQueueTimeoutError{}) { + return true + } + if ce := (CommandError{}); errors.As(err, &ce) && ce.IsMaxTimeMSExpiredError() { + return true + } + if we := (WriteException{}); errors.As(err, &we) && we.WriteConcernError != nil && we.WriteConcernError.IsMaxTimeMSExpiredError() { + return true + } + if ne := net.Error(nil); errors.As(err, &ne) { + return ne.Timeout() + } + // Check timeout error labels. + if le := LabeledError(nil); errors.As(err, &le) { + if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") { + return true } } @@ -132,7 +171,7 @@ func unwrap(err error) error { // errorHasLabel returns true if err contains the specified label func errorHasLabel(err error, label string) bool { for ; err != nil; err = unwrap(err) { - if le, ok := err.(labeledError); ok && le.HasErrorLabel(label) { + if le, ok := err.(LabeledError); ok && le.HasErrorLabel(label) { return true } } @@ -186,7 +225,8 @@ func (e MongocryptdError) Unwrap() error { return e.Wrapped } -type labeledError interface { +// LabeledError is an interface for errors with labels. +type LabeledError interface { error // HasErrorLabel returns true if the error contains the specified label. HasErrorLabel(string) bool @@ -195,11 +235,9 @@ type labeledError interface { // ServerError is the interface implemented by errors returned from the server. Custom implementations of this // interface should not be used in production. type ServerError interface { - error + LabeledError // HasErrorCode returns true if the error has the specified code. HasErrorCode(int) bool - // HasErrorLabel returns true if the error contains the specified label. - HasErrorLabel(string) bool // HasErrorMessage returns true if the error contains the specified message. HasErrorMessage(string) bool // HasErrorCodeWithMessage returns true if any of the contained errors have the specified code and message. @@ -300,7 +338,7 @@ func (we WriteError) HasErrorCode(code int) bool { // HasErrorLabel returns true if the error contains the specified label. WriteErrors do not contain labels, // so we always return false. -func (we WriteError) HasErrorLabel(label string) bool { +func (we WriteError) HasErrorLabel(string) bool { return false } @@ -362,6 +400,11 @@ func (wce WriteConcernError) Error() string { return wce.Message } +// IsMaxTimeMSExpiredError returns true if the error is a MaxTimeMSExpired error. +func (wce WriteConcernError) IsMaxTimeMSExpiredError() bool { + return wce.Code == 50 +} + // WriteException is the error type returned by the InsertOne, DeleteOne, DeleteMany, UpdateOne, UpdateMany, and // ReplaceOne operations. type WriteException struct { @@ -587,7 +630,7 @@ const ( // WriteConcernError will be returned over WriteErrors if both are present. func processWriteError(err error) (returnResult, error) { switch { - case err == driver.ErrUnacknowledgedWrite: + case errors.Is(err, driver.ErrUnacknowledgedWrite): return rrAll, ErrUnacknowledgedWrite case err != nil: switch tt := err.(type) { @@ -616,7 +659,8 @@ const batchErrorsTargetLength = 2000 // to the end. // // Example format: -// "[message 1, message 2, +8 more errors...]" +// +// "[message 1, message 2, +8 more errors...]" func joinBatchErrors(errs []error) string { var buf bytes.Buffer fmt.Fprint(&buf, "[") diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go b/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go index e8e260f..db65f75 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go @@ -12,7 +12,6 @@ import ( "errors" "fmt" "strconv" - "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsontype" @@ -45,7 +44,7 @@ type IndexView struct { // IndexModel represents a new index to be created. type IndexModel struct { // A document describing which keys should be used for the index. It cannot be nil. This must be an order-preserving - // type such as bson.D. Map types such as bson.M are not valid. See https://docs.mongodb.com/manual/indexes/#indexes + // type such as bson.D. Map types such as bson.M are not valid. See https://www.mongodb.com/docs/manual/indexes/#indexes // for examples of valid documents. Keys interface{} @@ -65,7 +64,7 @@ func isNamespaceNotFoundError(err error) bool { // The opts parameter can be used to specify options for this operation (see the options.ListIndexesOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/listIndexes/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/listIndexes/. func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOptions) (*Cursor, error) { if ctx == nil { ctx = context.Background() @@ -73,11 +72,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption sess := sessionFromContext(ctx) if sess == nil && iv.coll.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(iv.coll.client.sessionPool, iv.coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(iv.coll.client.sessionPool, iv.coll.client.id) } err := iv.coll.client.validSession(sess) @@ -91,21 +86,26 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption description.LatencySelector(iv.coll.client.localThreshold), }) selector = makeReadPrefSelector(sess, selector, iv.coll.client.localThreshold) + + // TODO(GODRIVER-3038): This operation should pass CSE to the ListIndexes + // Crypt setter to be applied to the operation. op := operation.NewListIndexes(). Session(sess).CommandMonitor(iv.coll.client.monitor). ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). - Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI) + Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). + Timeout(iv.coll.client.timeout).Authenticator(iv.coll.client.authenticator) cursorOpts := iv.coll.client.createBaseCursorOptions() + + cursorOpts.MarshalValueEncoderFn = newEncoderFn(iv.coll.bsonOpts, iv.coll.registry) + lio := options.MergeListIndexesOptions(opts...) if lio.BatchSize != nil { op = op.BatchSize(*lio.BatchSize) cursorOpts.BatchSize = *lio.BatchSize } - if lio.MaxTime != nil { - op = op.MaxTimeMS(int64(*lio.MaxTime / time.Millisecond)) - } + op = op.MaxTime(lio.MaxTime) retry := driver.RetryNone if iv.coll.client.retryReads { retry = driver.RetryOncePerCommand @@ -128,7 +128,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption closeImplicitSession(sess) return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, iv.coll.registry, sess) + cursor, err := newCursorWithSession(bc, iv.coll.bsonOpts, iv.coll.registry, sess) return cursor, replaceErrors(err) } @@ -175,7 +175,7 @@ func (iv IndexView) CreateOne(ctx context.Context, model IndexModel, opts ...*op // The opts parameter can be used to specify options for this operation (see the options.CreateIndexesOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/createIndexes/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/createIndexes/. func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts ...*options.CreateIndexesOptions) ([]string, error) { names := make([]string, 0, len(models)) @@ -187,7 +187,11 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. return nil, fmt.Errorf("index model keys cannot be nil") } - keys, err := transformBsoncoreDocument(iv.coll.registry, model.Keys, false, "keys") + if isUnorderedMap(model.Keys) { + return nil, ErrMapForOrderedArgument{"keys"} + } + + keys, err := marshal(model.Keys, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -229,10 +233,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. sess := sessionFromContext(ctx) if sess == nil && iv.coll.client.sessionPool != nil { - sess, err = session.NewClientSession(iv.coll.client.sessionPool, iv.coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(iv.coll.client.sessionPool, iv.coll.client.id) defer sess.EndSession() } @@ -253,16 +254,17 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. option := options.MergeCreateIndexesOptions(opts...) + // TODO(GODRIVER-3038): This operation should pass CSE to the CreateIndexes + // Crypt setter to be applied to the operation. + // + // This was added in GODRIVER-2413 for the 2.0 major release. op := operation.NewCreateIndexes(indexes). Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). - Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI) - - if option.MaxTime != nil { - op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond)) - } + Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). + Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime).Authenticator(iv.coll.client.authenticator) if option.CommitQuorum != nil { - commitQuorum, err := transformValue(iv.coll.registry, option.CommitQuorum, true, "commitQuorum") + commitQuorum, err := marshalValue(option.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -294,7 +296,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendBooleanElement(optsDoc, "sparse", *opts.Sparse) } if opts.StorageEngine != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.StorageEngine, true, "storageEngine") + doc, err := marshal(opts.StorageEngine, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -317,7 +319,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendInt32Element(optsDoc, "textIndexVersion", *opts.TextVersion) } if opts.Weights != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.Weights, true, "weights") + doc, err := marshal(opts.Weights, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -340,7 +342,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendInt32Element(optsDoc, "bucketSize", *opts.BucketSize) } if opts.PartialFilterExpression != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.PartialFilterExpression, true, "partialFilterExpression") + doc, err := marshal(opts.PartialFilterExpression, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -351,7 +353,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendDocumentElement(optsDoc, "collation", bsoncore.Document(opts.Collation.ToDocument())) } if opts.WildcardProjection != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.WildcardProjection, true, "wildcardProjection") + doc, err := marshal(opts.WildcardProjection, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -365,18 +367,14 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum return optsDoc, nil } -func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { +func (iv IndexView) drop(ctx context.Context, index any, opts ...*options.DropIndexesOptions) (bson.Raw, error) { if ctx == nil { ctx = context.Background() } sess := sessionFromContext(ctx) if sess == nil && iv.coll.client.sessionPool != nil { - var err error - sess, err = session.NewClientSession(iv.coll.client.sessionPool, iv.coll.client.id, session.Implicit) - if err != nil { - return nil, err - } + sess = session.NewImplicitClientSession(iv.coll.client.sessionPool, iv.coll.client.id) defer sess.EndSession() } @@ -396,14 +394,15 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop selector := makePinnedSelector(sess, iv.coll.writeSelector) dio := options.MergeDropIndexesOptions(opts...) - op := operation.NewDropIndexes(name). - Session(sess).WriteConcern(wc).CommandMonitor(iv.coll.client.monitor). + + // TODO(GODRIVER-3038): This operation should pass CSE to the DropIndexes + // Crypt setter to be applied to the operation. + op := operation.NewDropIndexes(index).Session(sess).WriteConcern(wc).CommandMonitor(iv.coll.client.monitor). ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). - Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI) - if dio.MaxTime != nil { - op.MaxTimeMS(int64(*dio.MaxTime / time.Millisecond)) - } + Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). + Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime). + Authenticator(iv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { @@ -427,7 +426,7 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop // The opts parameter can be used to specify options for this operation (see the options.DropIndexesOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/dropIndexes/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/dropIndexes/. func (iv IndexView) DropOne(ctx context.Context, name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { if name == "*" { return nil, ErrMultipleIndexDrop @@ -436,6 +435,20 @@ func (iv IndexView) DropOne(ctx context.Context, name string, opts ...*options.D return iv.drop(ctx, name, opts...) } +// DropOneWithKey drops a collection index by key using the dropIndexes operation. If the operation succeeds, this returns +// a BSON document in the form {nIndexesWas: }. The "nIndexesWas" field in the response contains the number of +// indexes that existed prior to the drop. +// +// This function is useful to drop an index using its key specification instead of its name. +func (iv IndexView) DropOneWithKey(ctx context.Context, keySpecDocument interface{}, opts ...*options.DropIndexesOptions) (bson.Raw, error) { + doc, err := marshal(keySpecDocument, iv.coll.bsonOpts, iv.coll.registry) + if err != nil { + return nil, err + } + + return iv.drop(ctx, doc, opts...) +} + // DropAll executes a dropIndexes operation to drop all indexes on the collection. If the operation succeeds, this // returns a BSON document in the form {nIndexesWas: }. The "nIndexesWas" field in the response contains the // number of indexes that existed prior to the drop. @@ -443,7 +456,7 @@ func (iv IndexView) DropOne(ctx context.Context, name string, opts ...*options.D // The opts parameter can be used to specify options for this operation (see the options.DropIndexesOptions // documentation). // -// For more information about the command, see https://docs.mongodb.com/manual/reference/command/dropIndexes/. +// For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/dropIndexes/. func (iv IndexView) DropAll(ctx context.Context, opts ...*options.DropIndexesOptions) (bson.Raw, error) { return iv.drop(ctx, "*", opts...) } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go b/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go index da29175..ec8e817 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go @@ -7,20 +7,23 @@ package mongo // import "go.mongodb.org/mongo-driver/mongo" import ( + "bytes" "context" "errors" "fmt" + "io" "net" "reflect" "strconv" "strings" + "go.mongodb.org/mongo-driver/internal/codecutil" "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" ) @@ -34,6 +37,8 @@ type Dialer interface { // provided type into BSON bytes and append those bytes to the provided []byte. // The AppendBSON can return a non-nil error and non-nil []byte. The AppendBSON // method may also write incomplete BSON to the []byte. +// +// Deprecated: BSONAppender is unused and will be removed in Go Driver 2.0. type BSONAppender interface { AppendBSON([]byte, interface{}) ([]byte, error) } @@ -41,14 +46,18 @@ type BSONAppender interface { // BSONAppenderFunc is an adapter function that allows any function that // satisfies the AppendBSON method signature to be used where a BSONAppender is // used. +// +// Deprecated: BSONAppenderFunc is unused and will be removed in Go Driver 2.0. type BSONAppenderFunc func([]byte, interface{}) ([]byte, error) // AppendBSON implements the BSONAppender interface +// +// Deprecated: BSONAppenderFunc is unused and will be removed in Go Driver 2.0. func (baf BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) { return baf(dst, val) } -// MarshalError is returned when attempting to transform a value into a document +// MarshalError is returned when attempting to marshal a value into a document // results in an error. type MarshalError struct { Value interface{} @@ -57,7 +66,7 @@ type MarshalError struct { // Error implements the error interface. func (me MarshalError) Error() string { - return fmt.Sprintf("cannot transform type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err) + return fmt.Sprintf("cannot marshal type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err) } // Pipeline is a type that makes creating aggregation pipelines easier. It is a @@ -65,76 +74,83 @@ func (me MarshalError) Error() string { // // Example usage: // -// mongo.Pipeline{ -// {{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}}, -// {{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}}, -// } -// +// mongo.Pipeline{ +// {{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}}, +// {{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}}, +// } type Pipeline []bson.D -// transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value. -// It will also add an ObjectID _id as the first key if it not already present in the passed-in val. -func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, interface{}, error) { - if registry == nil { - registry = bson.NewRegistryBuilder().Build() - } - switch tt := val.(type) { - case nil: - return nil, nil, ErrNilDocument - case bsonx.Doc: - val = tt.Copy() - case []byte: - // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. - val = bson.Raw(tt) - } - - // TODO(skriptble): Use a pool of these instead. - doc := make(bsoncore.Document, 0, 256) - doc, err := bson.MarshalAppendWithRegistry(registry, doc, val) +// bvwPool is a pool of BSON value writers. BSON value writers +var bvwPool = bsonrw.NewBSONValueWriterPool() + +// getEncoder takes a writer, BSON options, and a BSON registry and returns a properly configured +// bson.Encoder that writes to the given writer. +func getEncoder( + w io.Writer, + opts *options.BSONOptions, + reg *bsoncodec.Registry, +) (*bson.Encoder, error) { + vw := bvwPool.Get(w) + enc, err := bson.NewEncoder(vw) if err != nil { - return nil, nil, MarshalError{Value: val, Err: err} + return nil, err } - var id interface{} - - value := doc.Lookup("_id") - switch value.Type { - case bsontype.Type(0): - value = bsoncore.Value{Type: bsontype.ObjectID, Data: bsoncore.AppendObjectID(nil, primitive.NewObjectID())} - olddoc := doc - doc = make(bsoncore.Document, 0, len(olddoc)+17) // type byte + _id + null byte + object ID - _, doc = bsoncore.ReserveLength(doc) - doc = bsoncore.AppendValueElement(doc, "_id", value) - doc = append(doc, olddoc[4:]...) // remove the length - doc = bsoncore.UpdateLength(doc, 0, int32(len(doc))) - default: - // We copy the bytes here to ensure that any bytes returned to the user aren't modified - // later. - buf := make([]byte, len(value.Data)) - copy(buf, value.Data) - value.Data = buf + if opts != nil { + if opts.ErrorOnInlineDuplicates { + enc.ErrorOnInlineDuplicates() + } + if opts.IntMinSize { + enc.IntMinSize() + } + if opts.NilByteSliceAsEmpty { + enc.NilByteSliceAsEmpty() + } + if opts.NilMapAsEmpty { + enc.NilMapAsEmpty() + } + if opts.NilSliceAsEmpty { + enc.NilSliceAsEmpty() + } + if opts.OmitZeroStruct { + enc.OmitZeroStruct() + } + if opts.StringifyMapKeysWithFmt { + enc.StringifyMapKeysWithFmt() + } + if opts.UseJSONStructTags { + enc.UseJSONStructTags() + } } - err = bson.RawValue{Type: value.Type, Value: value.Data}.UnmarshalWithRegistry(registry, &id) - if err != nil { - return nil, nil, err + if reg != nil { + // TODO:(GODRIVER-2719): Remove error handling. + if err := enc.SetRegistry(reg); err != nil { + return nil, err + } } - return doc, id, nil + return enc, nil } -func transformDocument(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, error) { - if doc, ok := val.(bsonx.Doc); ok { - return doc.Copy(), nil +// newEncoderFn will return a function for constructing an encoder based on the +// provided codec options. +func newEncoderFn(opts *options.BSONOptions, registry *bsoncodec.Registry) codecutil.EncoderFn { + return func(w io.Writer) (*bson.Encoder, error) { + return getEncoder(w, opts, registry) } - b, err := transformBsoncoreDocument(registry, val, true, "document") - if err != nil { - return nil, err - } - return bsonx.ReadDoc(b) } -func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}, mapAllowed bool, paramName string) (bsoncore.Document, error) { +// marshal marshals the given value as a BSON document. Byte slices are always converted to a +// bson.Raw before marshaling. +// +// If bsonOpts and registry are specified, the encoder is configured with the requested behaviors. +// If they are nil, the default behaviors are used. +func marshal( + val interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Document, error) { if registry == nil { registry = bson.DefaultRegistry } @@ -145,20 +161,78 @@ func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}, ma // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. val = bson.Raw(bs) } - if !mapAllowed { - refValue := reflect.ValueOf(val) - if refValue.Kind() == reflect.Map && refValue.Len() > 1 { - return nil, ErrMapForOrderedArgument{paramName} - } + + buf := new(bytes.Buffer) + enc, err := getEncoder(buf, bsonOpts, registry) + if err != nil { + return nil, fmt.Errorf("error configuring BSON encoder: %w", err) } - // TODO(skriptble): Use a pool of these instead. - buf := make([]byte, 0, 256) - b, err := bson.MarshalAppendWithRegistry(registry, buf[:0], val) + err = enc.Encode(val) if err != nil { return nil, MarshalError{Value: val, Err: err} } - return b, nil + + return buf.Bytes(), nil +} + +// ensureID inserts the given ObjectID as an element named "_id" at the +// beginning of the given BSON document if there is not an "_id" already. +// If the given ObjectID is primitive.NilObjectID, a new object ID will be +// generated with time.Now(). +// +// If there is already an element named "_id", the document is not modified. It +// returns the resulting document and the decoded Go value of the "_id" element. +func ensureID( + doc bsoncore.Document, + oid primitive.ObjectID, + bsonOpts *options.BSONOptions, + reg *bsoncodec.Registry, +) (bsoncore.Document, interface{}, error) { + if reg == nil { + reg = bson.DefaultRegistry + } + + // Try to find the "_id" element. If it exists, try to unmarshal just the + // "_id" field as an interface{} and return it along with the unmodified + // BSON document. + if _, err := doc.LookupErr("_id"); err == nil { + var id struct { + ID interface{} `bson:"_id"` + } + dec, err := getDecoder(doc, bsonOpts, reg) + if err != nil { + return nil, nil, fmt.Errorf("error configuring BSON decoder: %w", err) + } + err = dec.Decode(&id) + if err != nil { + return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) + } + + return doc, id.ID, nil + } + + // We couldn't find an "_id" element, so add one with the value of the + // provided ObjectID. + + olddoc := doc + + // Reserve an extra 17 bytes for the "_id" field we're about to add: + // type (1) + "_id" (3) + terminator (1) + object ID (12) + const extraSpace = 17 + doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace) + _, doc = bsoncore.ReserveLength(doc) + if oid.IsZero() { + oid = primitive.NewObjectID() + } + doc = bsoncore.AppendObjectIDElement(doc, "_id", oid) + + // Remove and re-write the BSON document length header. + const int32Len = 4 + doc = append(doc, olddoc[int32Len:]...) + doc = bsoncore.UpdateLength(doc, 0, int32(len(doc))) + + return doc, oid, nil } func ensureDollarKey(doc bsoncore.Document) error { @@ -181,7 +255,11 @@ func ensureNoDollarKey(doc bsoncore.Document) error { return nil } -func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsoncore.Document, bool, error) { +func marshalAggregatePipeline( + pipeline interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Document, bool, error) { switch t := pipeline.(type) { case bsoncodec.ValueMarshaler: btype, val, err := t.MarshalBSONValue() @@ -207,7 +285,7 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface default: val := reflect.ValueOf(t) if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) { - return nil, false, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind()) + return nil, false, fmt.Errorf("can only marshal slices and arrays into aggregation pipelines, but got %v", val.Kind()) } var hasOutputStage bool @@ -221,7 +299,7 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface return nil, false, fmt.Errorf("%T is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead", t) } - // bsoncore.Arrays do not need to be transformed. Only check validity and presence of output stage. + // bsoncore.Arrays do not need to be marshaled. Only check validity and presence of output stage. case bsoncore.Array: if err := t.Validate(); err != nil { return nil, false, err @@ -248,7 +326,7 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface aidx, arr := bsoncore.AppendArrayStart(nil) for idx := 0; idx < valLen; idx++ { - doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface(), true, fmt.Sprintf("pipeline stage :%v", idx)) + doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry) if err != nil { return nil, false, err } @@ -265,7 +343,12 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface } } -func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, dollarKeysAllowed bool) (bsoncore.Value, error) { +func marshalUpdateValue( + update interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, + dollarKeysAllowed bool, +) (bsoncore.Value, error) { documentCheckerFunc := ensureDollarKey if !dollarKeysAllowed { documentCheckerFunc = ensureNoDollarKey @@ -276,9 +359,9 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll switch t := update.(type) { case nil: return u, ErrNilDocument - case primitive.D, bsonx.Doc: + case primitive.D: u.Type = bsontype.EmbeddedDocument - u.Data, err = transformBsoncoreDocument(registry, update, true, "update") + u.Data, err = marshal(update, bsonOpts, registry) if err != nil { return u, err } @@ -316,11 +399,11 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll default: val := reflect.ValueOf(t) if !val.IsValid() { - return u, fmt.Errorf("can only transform slices and arrays into update pipelines, but got %v", val.Kind()) + return u, fmt.Errorf("can only marshal slices and arrays into update pipelines, but got %v", val.Kind()) } if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { u.Type = bsontype.EmbeddedDocument - u.Data, err = transformBsoncoreDocument(registry, update, true, "update") + u.Data, err = marshal(update, bsonOpts, registry) if err != nil { return u, err } @@ -332,7 +415,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll aidx, arr := bsoncore.AppendArrayStart(nil) valLen := val.Len() for idx := 0; idx < valLen; idx++ { - doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface(), true, "update") + doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry) if err != nil { return u, err } @@ -348,33 +431,22 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll } } -func transformValue(registry *bsoncodec.Registry, val interface{}, mapAllowed bool, paramName string) (bsoncore.Value, error) { - if registry == nil { - registry = bson.DefaultRegistry - } - if val == nil { - return bsoncore.Value{}, ErrNilValue - } - - if !mapAllowed { - refValue := reflect.ValueOf(val) - if refValue.Kind() == reflect.Map && refValue.Len() > 1 { - return bsoncore.Value{}, ErrMapForOrderedArgument{paramName} - } - } - - buf := make([]byte, 0, 256) - bsonType, bsonValue, err := bson.MarshalValueAppendWithRegistry(registry, buf[:0], val) - if err != nil { - return bsoncore.Value{}, MarshalError{Value: val, Err: err} - } - - return bsoncore.Value{Type: bsonType, Data: bsonValue}, nil +func marshalValue( + val interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Value, error) { + return codecutil.MarshalValue(val, newEncoderFn(bsonOpts, registry)) } // Build the aggregation pipeline for the CountDocument command. -func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts *options.CountOptions) (bsoncore.Document, error) { - filterDoc, err := transformBsoncoreDocument(registry, filter, true, "filter") +func countDocumentsAggregatePipeline( + filter interface{}, + encOpts *options.BSONOptions, + registry *bsoncodec.Registry, + opts *options.CountOptions, +) (bsoncore.Document, error) { + filterDoc, err := marshal(filter, encOpts, registry) if err != nil { return nil, err } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/mongocryptd.go b/vendor/go.mongodb.org/mongo-driver/mongo/mongocryptd.go index c36b1d3..2603a39 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/mongocryptd.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/mongocryptd.go @@ -28,17 +28,21 @@ const ( var defaultTimeoutArgs = []string{"--idleShutdownTimeoutSecs=60"} var databaseOpts = options.Database().SetReadConcern(readconcern.New()).SetReadPreference(readpref.Primary()) -type mcryptClient struct { +type mongocryptdClient struct { bypassSpawn bool client *Client path string spawnArgs []string } -func newMcryptClient(opts *options.AutoEncryptionOptions) (*mcryptClient, error) { +// newMongocryptdClient creates a client to mongocryptd. +// newMongocryptdClient is expected to not be called if the crypt shared library is available. +// The crypt shared library replaces all mongocryptd functionality. +func newMongocryptdClient(opts *options.AutoEncryptionOptions) (*mongocryptdClient, error) { // create mcryptClient instance and spawn process if necessary var bypassSpawn bool var bypassAutoEncryption bool + if bypass, ok := opts.ExtraOptions["mongocryptdBypassSpawn"]; ok { bypassSpawn = bypass.(bool) } @@ -46,10 +50,14 @@ func newMcryptClient(opts *options.AutoEncryptionOptions) (*mcryptClient, error) bypassAutoEncryption = *opts.BypassAutoEncryption } - mc := &mcryptClient{ - // mongocryptd should not be spawned if mongocryptdBypassSpawn is passed or if bypassAutoEncryption is - // specified because it is not used during decryption - bypassSpawn: bypassSpawn || bypassAutoEncryption, + bypassQueryAnalysis := opts.BypassQueryAnalysis != nil && *opts.BypassQueryAnalysis + + mc := &mongocryptdClient{ + // mongocryptd should not be spawned if any of these conditions are true: + // - mongocryptdBypassSpawn is passed + // - bypassAutoEncryption is true because mongocryptd is not used during decryption + // - bypassQueryAnalysis is true because mongocryptd is not used during decryption + bypassSpawn: bypassSpawn || bypassAutoEncryption || bypassQueryAnalysis, } if !mc.bypassSpawn { @@ -76,14 +84,14 @@ func newMcryptClient(opts *options.AutoEncryptionOptions) (*mcryptClient, error) } // markCommand executes the given command on mongocryptd. -func (mc *mcryptClient) markCommand(ctx context.Context, dbName string, cmd bsoncore.Document) (bsoncore.Document, error) { +func (mc *mongocryptdClient) markCommand(ctx context.Context, dbName string, cmd bsoncore.Document) (bsoncore.Document, error) { // Remove the explicit session from the context if one is set. // The explicit session will be from a different client. // If an explicit session is set, it is applied after automatic encryption. ctx = NewSessionContext(ctx, nil) db := mc.client.Database(dbName, databaseOpts) - res, err := db.RunCommand(ctx, cmd).DecodeBytes() + res, err := db.RunCommand(ctx, cmd).Raw() // propagate original result if err == nil { return bsoncore.Document(res), nil @@ -97,7 +105,7 @@ func (mc *mcryptClient) markCommand(ctx context.Context, dbName string, cmd bson if err = mc.spawnProcess(); err != nil { return nil, err } - res, err = db.RunCommand(ctx, cmd).DecodeBytes() + res, err = db.RunCommand(ctx, cmd).Raw() if err != nil { return nil, MongocryptdError{Wrapped: err} } @@ -105,16 +113,16 @@ func (mc *mcryptClient) markCommand(ctx context.Context, dbName string, cmd bson } // connect connects the underlying Client instance. This must be called before performing any mark operations. -func (mc *mcryptClient) connect(ctx context.Context) error { +func (mc *mongocryptdClient) connect(ctx context.Context) error { return mc.client.Connect(ctx) } // disconnect disconnects the underlying Client instance. This should be called after all operations have completed. -func (mc *mcryptClient) disconnect(ctx context.Context) error { +func (mc *mongocryptdClient) disconnect(ctx context.Context) error { return mc.client.Disconnect(ctx) } -func (mc *mcryptClient) spawnProcess() error { +func (mc *mongocryptdClient) spawnProcess() error { // Ignore gosec warning about subprocess launched with externally-provided path variable. /* #nosec G204 */ cmd := exec.Command(mc.path, mc.spawnArgs...) diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/mongointernal.go b/vendor/go.mongodb.org/mongo-driver/mongo/mongointernal.go new file mode 100644 index 0000000..0148756 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/mongo/mongointernal.go @@ -0,0 +1,41 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +//go:build mongointernal + +package mongo + +import ( + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/session" +) + +// NewSessionWithLSID returns a Session with the given sessionID document. The +// sessionID is a BSON document with key "id" containing a 16-byte UUID (binary +// subtype 4). +// +// Sessions returned by NewSessionWithLSID are never added to the driver's +// session pool. Calling "EndSession" or "ClientSession.SetServer" on a Session +// returned by NewSessionWithLSID will panic. +// +// NewSessionWithLSID is intended only for internal use and may be changed or +// removed at any time. +func NewSessionWithLSID(client *Client, sessionID bson.Raw) Session { + return &sessionImpl{ + clientSession: &session.Client{ + Server: &session.Server{ + SessionID: bsoncore.Document(sessionID), + LastUsed: time.Now(), + }, + ClientID: client.id, + }, + client: client, + deployment: client.deployment, + } +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go index cf0da5f..20e1c70 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go @@ -23,7 +23,7 @@ type AggregateOptions struct { // If true, writes executed as part of the operation will opt out of document-level validation on the server. This // option is valid for MongoDB versions >= 3.2 and is ignored for previous server versions. The default value is - // false. See https://docs.mongodb.com/manual/core/schema-validation/ for more information about document + // false. See https://www.mongodb.com/docs/manual/core/schema-validation/ for more information about document // validation. BypassDocumentValidation *bool @@ -34,6 +34,10 @@ type AggregateOptions struct { // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there // is no time limit for query execution. + // + // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used + // in its place to control the amount of time that a single operation can run before returning an error. MaxTime + // is ignored if Timeout is set on the client. MaxTime *time.Duration // The maximum amount of time that the server should wait for new documents to satisfy a tailable cursor query. @@ -41,7 +45,7 @@ type AggregateOptions struct { MaxAwaitTime *time.Duration // A string that will be included in server logs, profiling logs, and currentOp queries to help trace the operation. - // The default is the empty string, which means that no comment will be included in the logs. + // The default is nil, which means that no comment will be included in the logs. Comment *string // The index to use for the aggregation. This should either be the index name as a string or the index specification @@ -91,6 +95,10 @@ func (ao *AggregateOptions) SetCollation(c *Collation) *AggregateOptions { } // SetMaxTime sets the value for the MaxTime field. +// +// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout +// option may be used in its place to control the amount of time that a single operation can +// run before returning an error. MaxTime is ignored if Timeout is set on the client. func (ao *AggregateOptions) SetMaxTime(d time.Duration) *AggregateOptions { ao.MaxTime = &d return ao @@ -131,6 +139,9 @@ func (ao *AggregateOptions) SetCustom(c bson.M) *AggregateOptions { // MergeAggregateOptions combines the given AggregateOptions instances into a single AggregateOptions in a last-one-wins // fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeAggregateOptions(opts ...*AggregateOptions) *AggregateOptions { aggOpts := Aggregate() for _, ao := range opts { diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go index 89c3c05..15d5138 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go @@ -8,6 +8,9 @@ package options import ( "crypto/tls" + "net/http" + + "go.mongodb.org/mongo-driver/internal/httputil" ) // AutoEncryptionOptions represents options used to configure auto encryption/decryption behavior for a mongo.Client @@ -32,11 +35,16 @@ type AutoEncryptionOptions struct { BypassAutoEncryption *bool ExtraOptions map[string]interface{} TLSConfig map[string]*tls.Config + HTTPClient *http.Client + EncryptedFieldsMap map[string]interface{} + BypassQueryAnalysis *bool } // AutoEncryption creates a new AutoEncryptionOptions configured with default values. func AutoEncryption() *AutoEncryptionOptions { - return &AutoEncryptionOptions{} + return &AutoEncryptionOptions{ + HTTPClient: httputil.DefaultHTTPClient, + } } // SetKeyVaultClientOptions specifies options for the client used to communicate with the key vault collection. @@ -90,7 +98,35 @@ func (a *AutoEncryptionOptions) SetBypassAutoEncryption(bypass bool) *AutoEncryp return a } -// SetExtraOptions specifies a map of options to configure the mongocryptd process. +// SetExtraOptions specifies a map of options to configure the mongocryptd process or mongo_crypt shared library. +// +// # Supported Extra Options +// +// "mongocryptdURI" - The mongocryptd URI. Allows setting a custom URI used to communicate with the +// mongocryptd process. The default is "mongodb://localhost:27020", which works with the default +// mongocryptd process spawned by the Client. Must be a string. +// +// "mongocryptdBypassSpawn" - If set to true, the Client will not attempt to spawn a mongocryptd +// process. Must be a bool. +// +// "mongocryptdSpawnPath" - The path used when spawning mongocryptd. +// Defaults to empty string and spawns mongocryptd from system path. Must be a string. +// +// "mongocryptdSpawnArgs" - Command line arguments passed when spawning mongocryptd. +// Defaults to ["--idleShutdownTimeoutSecs=60"]. Must be an array of strings. +// +// "cryptSharedLibRequired" - If set to true, Client creation will return an error if the +// crypt_shared library is not loaded. If unset or set to false, Client creation will not return an +// error if the crypt_shared library is not loaded. The default is unset. Must be a bool. +// +// "cryptSharedLibPath" - The crypt_shared library override path. This must be the path to the +// crypt_shared dynamic library file (for example, a .so, .dll, or .dylib file), not the directory +// that contains it. If the override path is a relative path, it will be resolved relative to the +// working directory of the process. If the override path is a relative path and the first path +// component is the literal string "$ORIGIN", the "$ORIGIN" component will be replaced by the +// absolute path to the directory containing the linked libmongocrypt library. Setting an override +// path disables the default system library search path. If an override path is specified but the +// crypt_shared library cannot be loaded, Client creation will return an error. Must be a string. func (a *AutoEncryptionOptions) SetExtraOptions(extraOpts map[string]interface{}) *AutoEncryptionOptions { a.ExtraOptions = extraOpts return a @@ -113,7 +149,24 @@ func (a *AutoEncryptionOptions) SetTLSConfig(tlsOpts map[string]*tls.Config) *Au return a } +// SetEncryptedFieldsMap specifies a map from namespace to local EncryptedFieldsMap document. +// EncryptedFieldsMap is used for Queryable Encryption. +func (a *AutoEncryptionOptions) SetEncryptedFieldsMap(ef map[string]interface{}) *AutoEncryptionOptions { + a.EncryptedFieldsMap = ef + return a +} + +// SetBypassQueryAnalysis specifies whether or not query analysis should be used for automatic encryption. +// Use this option when using explicit encryption with Queryable Encryption. +func (a *AutoEncryptionOptions) SetBypassQueryAnalysis(bypass bool) *AutoEncryptionOptions { + a.BypassQueryAnalysis = &bypass + return a +} + // MergeAutoEncryptionOptions combines the argued AutoEncryptionOptions in a last-one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeAutoEncryptionOptions(opts ...*AutoEncryptionOptions) *AutoEncryptionOptions { aeo := AutoEncryption() for _, opt := range opts { @@ -142,6 +195,15 @@ func MergeAutoEncryptionOptions(opts ...*AutoEncryptionOptions) *AutoEncryptionO if opt.TLSConfig != nil { aeo.TLSConfig = opt.TLSConfig } + if opt.EncryptedFieldsMap != nil { + aeo.EncryptedFieldsMap = opt.EncryptedFieldsMap + } + if opt.BypassQueryAnalysis != nil { + aeo.BypassQueryAnalysis = opt.BypassQueryAnalysis + } + if opt.HTTPClient != nil { + aeo.HTTPClient = opt.HTTPClient + } } return aeo diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/bulkwriteoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/bulkwriteoptions.go index 2786ab2..49d7a0f 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/bulkwriteoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/bulkwriteoptions.go @@ -13,10 +13,14 @@ var DefaultOrdered = true type BulkWriteOptions struct { // If true, writes executed as part of the operation will opt out of document-level validation on the server. This // option is valid for MongoDB versions >= 3.2 and is ignored for previous server versions. The default value is - // false. See https://docs.mongodb.com/manual/core/schema-validation/ for more information about document + // false. See https://www.mongodb.com/docs/manual/core/schema-validation/ for more information about document // validation. BypassDocumentValidation *bool + // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace + // the operation. The default value is nil, which means that no comment will be included in the logs. + Comment interface{} + // If true, no writes will be executed after one fails. The default value is true. Ordered *bool @@ -25,6 +29,12 @@ type BulkWriteOptions struct { // parameter names to values. Values must be constant or closed expressions that do not reference document fields. // Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). Let interface{} + + // If true, the server accepts empty Timestamp as a literal rather than replacing it with the current time. + // + // Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any + // release. + BypassEmptyTsReplacement *bool } // BulkWrite creates a new *BulkWriteOptions instance. @@ -34,6 +44,12 @@ func BulkWrite() *BulkWriteOptions { } } +// SetComment sets the value for the Comment field. +func (b *BulkWriteOptions) SetComment(comment interface{}) *BulkWriteOptions { + b.Comment = comment + return b +} + // SetOrdered sets the value for the Ordered field. func (b *BulkWriteOptions) SetOrdered(ordered bool) *BulkWriteOptions { b.Ordered = &ordered @@ -57,12 +73,18 @@ func (b *BulkWriteOptions) SetLet(let interface{}) *BulkWriteOptions { // MergeBulkWriteOptions combines the given BulkWriteOptions instances into a single BulkWriteOptions in a last-one-wins // fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeBulkWriteOptions(opts ...*BulkWriteOptions) *BulkWriteOptions { b := BulkWrite() for _, opt := range opts { if opt == nil { continue } + if opt.Comment != nil { + b.Comment = opt.Comment + } if opt.Ordered != nil { b.Ordered = opt.Ordered } @@ -72,6 +94,9 @@ func MergeBulkWriteOptions(opts ...*BulkWriteOptions) *BulkWriteOptions { if opt.Let != nil { b.Let = opt.Let } + if opt.BypassEmptyTsReplacement != nil { + b.BypassEmptyTsReplacement = opt.BypassEmptyTsReplacement + } } return b diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/changestreamoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/changestreamoptions.go index eb9b064..3d06a66 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/changestreamoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/changestreamoptions.go @@ -23,11 +23,18 @@ type ChangeStreamOptions struct { // default value is nil, which means the default collation of the collection will be used. Collation *Collation - // Specifies whether the updated document should be returned in change notifications for update operations along - // with the deltas describing the changes made to the document. The default is options.Default, which means that - // the updated document will not be included in the change notification. + // A string that will be included in server logs, profiling logs, and currentOp queries to help trace the operation. + // The default is nil, which means that no comment will be included in the logs. + Comment *string + + // Specifies how the updated document should be returned in change notifications for update operations. The default + // is options.Default, which means that only partial update deltas will be included in the change notification. FullDocument *FullDocument + // Specifies how the pre-update document should be returned in change notifications for update operations. The default + // is options.Off, which means that the pre-update document will not be included in the change notification. + FullDocumentBeforeChange *FullDocument + // The maximum amount of time that the server should wait for new documents to satisfy a tailable cursor query. MaxAwaitTime *time.Duration @@ -36,6 +43,11 @@ type ChangeStreamOptions struct { // StartAfter must not be set. ResumeAfter interface{} + // ShowExpandedEvents specifies whether the server will return an expanded list of change stream events. Additional + // events include: createIndexes, dropIndexes, modify, create, shardCollection, reshardCollection and + // refineCollectionShardKey. This option is only valid for MongoDB versions >= 6.0. + ShowExpandedEvents *bool + // If specified, the change stream will only return changes that occurred at or after the given timestamp. This // option is only valid for MongoDB versions >= 4.0. If this is specified, ResumeAfter and StartAfter must not be // set. @@ -62,7 +74,6 @@ type ChangeStreamOptions struct { // ChangeStream creates a new ChangeStreamOptions instance. func ChangeStream() *ChangeStreamOptions { cso := &ChangeStreamOptions{} - cso.SetFullDocument(Default) return cso } @@ -78,12 +89,24 @@ func (cso *ChangeStreamOptions) SetCollation(c Collation) *ChangeStreamOptions { return cso } +// SetComment sets the value for the Comment field. +func (cso *ChangeStreamOptions) SetComment(comment string) *ChangeStreamOptions { + cso.Comment = &comment + return cso +} + // SetFullDocument sets the value for the FullDocument field. func (cso *ChangeStreamOptions) SetFullDocument(fd FullDocument) *ChangeStreamOptions { cso.FullDocument = &fd return cso } +// SetFullDocumentBeforeChange sets the value for the FullDocumentBeforeChange field. +func (cso *ChangeStreamOptions) SetFullDocumentBeforeChange(fdbc FullDocument) *ChangeStreamOptions { + cso.FullDocumentBeforeChange = &fdbc + return cso +} + // SetMaxAwaitTime sets the value for the MaxAwaitTime field. func (cso *ChangeStreamOptions) SetMaxAwaitTime(d time.Duration) *ChangeStreamOptions { cso.MaxAwaitTime = &d @@ -96,6 +119,12 @@ func (cso *ChangeStreamOptions) SetResumeAfter(rt interface{}) *ChangeStreamOpti return cso } +// SetShowExpandedEvents sets the value for the ShowExpandedEvents field. +func (cso *ChangeStreamOptions) SetShowExpandedEvents(see bool) *ChangeStreamOptions { + cso.ShowExpandedEvents = &see + return cso +} + // SetStartAtOperationTime sets the value for the StartAtOperationTime field. func (cso *ChangeStreamOptions) SetStartAtOperationTime(t *primitive.Timestamp) *ChangeStreamOptions { cso.StartAtOperationTime = t @@ -127,6 +156,9 @@ func (cso *ChangeStreamOptions) SetCustomPipeline(cp bson.M) *ChangeStreamOption // MergeChangeStreamOptions combines the given ChangeStreamOptions instances into a single ChangeStreamOptions in a // last-one-wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeChangeStreamOptions(opts ...*ChangeStreamOptions) *ChangeStreamOptions { csOpts := ChangeStream() for _, cso := range opts { @@ -139,15 +171,24 @@ func MergeChangeStreamOptions(opts ...*ChangeStreamOptions) *ChangeStreamOptions if cso.Collation != nil { csOpts.Collation = cso.Collation } + if cso.Comment != nil { + csOpts.Comment = cso.Comment + } if cso.FullDocument != nil { csOpts.FullDocument = cso.FullDocument } + if cso.FullDocumentBeforeChange != nil { + csOpts.FullDocumentBeforeChange = cso.FullDocumentBeforeChange + } if cso.MaxAwaitTime != nil { csOpts.MaxAwaitTime = cso.MaxAwaitTime } if cso.ResumeAfter != nil { csOpts.ResumeAfter = cso.ResumeAfter } + if cso.ShowExpandedEvents != nil { + csOpts.ShowExpandedEvents = cso.ShowExpandedEvents + } if cso.StartAtOperationTime != nil { csOpts.StartAtOperationTime = cso.StartAtOperationTime } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientencryptionoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientencryptionoptions.go index b8f6e87..2457f68 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientencryptionoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientencryptionoptions.go @@ -9,6 +9,9 @@ package options import ( "crypto/tls" "fmt" + "net/http" + + "go.mongodb.org/mongo-driver/internal/httputil" ) // ClientEncryptionOptions represents all possible options used to configure a ClientEncryption instance. @@ -16,11 +19,14 @@ type ClientEncryptionOptions struct { KeyVaultNamespace string KmsProviders map[string]map[string]interface{} TLSConfig map[string]*tls.Config + HTTPClient *http.Client } // ClientEncryption creates a new ClientEncryptionOptions instance. func ClientEncryption() *ClientEncryptionOptions { - return &ClientEncryptionOptions{} + return &ClientEncryptionOptions{ + HTTPClient: httputil.DefaultHTTPClient, + } } // SetKeyVaultNamespace specifies the namespace of the key vault collection. This is required. @@ -56,12 +62,12 @@ func (c *ClientEncryptionOptions) SetTLSConfig(tlsOpts map[string]*tls.Config) * // to the KMS provider. The input map should contain a mapping from each KMS provider to a document containing the necessary // options, as follows: // -// { -// "kmip": { -// "tlsCertificateKeyFile": "foo.pem", -// "tlsCAFile": "fooCA.pem" -// } -// } +// { +// "kmip": { +// "tlsCertificateKeyFile": "foo.pem", +// "tlsCAFile": "fooCA.pem" +// } +// } // // Currently, the following TLS options are supported: // @@ -116,6 +122,9 @@ func BuildTLSConfig(tlsOpts map[string]interface{}) (*tls.Config, error) { } // MergeClientEncryptionOptions combines the argued ClientEncryptionOptions in a last-one wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeClientEncryptionOptions(opts ...*ClientEncryptionOptions) *ClientEncryptionOptions { ceo := ClientEncryption() for _, opt := range opts { @@ -132,6 +141,9 @@ func MergeClientEncryptionOptions(opts ...*ClientEncryptionOptions) *ClientEncry if opt.TLSConfig != nil { ceo.TLSConfig = opt.TLSConfig } + if opt.HTTPClient != nil { + ceo.HTTPClient = opt.HTTPClient + } } return ceo diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go index 115cc64..b1dc0b6 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go @@ -15,23 +15,46 @@ import ( "errors" "fmt" "io/ioutil" + "math" "net" + "net/http" "strings" "time" "github.com/youmark/pkcs8" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/event" - "go.mongodb.org/mongo-driver/internal" + "go.mongodb.org/mongo-driver/internal/httputil" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) +const ( + // ServerMonitoringModeAuto indicates that the client will behave like "poll" + // mode when running on a FaaS (Function as a Service) platform, or like + // "stream" mode otherwise. The client detects its execution environment by + // following the rules for generating the "client.env" handshake metadata field + // as specified in the MongoDB Handshake specification. This is the default + // mode. + ServerMonitoringModeAuto = connstring.ServerMonitoringModeAuto + + // ServerMonitoringModePoll indicates that the client will periodically check + // the server using a hello or legacy hello command and then sleep for + // heartbeatFrequencyMS milliseconds before running another check. + ServerMonitoringModePoll = connstring.ServerMonitoringModePoll + + // ServerMonitoringModeStream indicates that the client will use a streaming + // protocol when the server supports it. The streaming protocol optimally + // reduces the time it takes for a client to discover server state changes. + ServerMonitoringModeStream = connstring.ServerMonitoringModeStream +) + // ContextDialer is an interface that can be implemented by types that can create connections. It should be used to // provide a custom dialer when configuring a Client. // @@ -45,7 +68,7 @@ type ContextDialer interface { // AuthMechanism: the mechanism to use for authentication. Supported values include "SCRAM-SHA-256", "SCRAM-SHA-1", // "MONGODB-CR", "PLAIN", "GSSAPI", "MONGODB-X509", and "MONGODB-AWS". This can also be set through the "authMechanism" // URI option. (e.g. "authMechanism=PLAIN"). For more information, see -// https://docs.mongodb.com/manual/core/authentication-mechanisms/. +// https://www.mongodb.com/docs/manual/core/authentication-mechanisms/. // // AuthMechanismProperties can be used to specify additional configuration options for certain mechanisms. They can also // be set through the "authMechanismProperites" URI option @@ -67,9 +90,9 @@ type ContextDialer interface { // The SERVICE_HOST and CANONICALIZE_HOST_NAME properties must not be used at the same time on Linux and Darwin // systems. // -// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-X509, -// GSSAPI, and PLAIN and "admin" for all other mechanisms. This can also be set through the "authSource" URI option -// (e.g. "authSource=otherDb"). +// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-AWS, +// MONGODB-OIDC, MONGODB-X509, GSSAPI, and PLAIN. It defaults to "admin" for all other auth mechanisms. This can +// also be set through the "authSource" URI option (e.g. "authSource=otherDb"). // // Username: the username for authentication. This can also be set through the URI as a username:password pair before // the first @ character. For example, a URI for user "user", password "pwd", and host "localhost:27017" would be @@ -89,6 +112,116 @@ type Credential struct { Username string Password string PasswordSet bool + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + +// OIDCCallback is the type for both Human and Machine Callback flows. +// RefreshToken will always be nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +// IDPInfo contains the information needed to perform OIDC authentication with +// an Identity Provider. +type IDPInfo struct { + Issuer string + ClientID string + RequestScopes []string +} + +// BSONOptions are optional BSON marshaling and unmarshaling behaviors. +type BSONOptions struct { + // UseJSONStructTags causes the driver to fall back to using the "json" + // struct tag if a "bson" struct tag is not specified. + UseJSONStructTags bool + + // ErrorOnInlineDuplicates causes the driver to return an error if there is + // a duplicate field in the marshaled BSON when the "inline" struct tag + // option is set. + ErrorOnInlineDuplicates bool + + // IntMinSize causes the driver to marshal Go integer values (int, int8, + // int16, int32, int64, uint, uint8, uint16, uint32, or uint64) as the + // minimum BSON int size (either 32 or 64 bits) that can represent the + // integer value. + IntMinSize bool + + // NilMapAsEmpty causes the driver to marshal nil Go maps as empty BSON + // documents instead of BSON null. + // + // Empty BSON documents take up slightly more space than BSON null, but + // preserve the ability to use document update operations like "$set" that + // do not work on BSON null. + NilMapAsEmpty bool + + // NilSliceAsEmpty causes the driver to marshal nil Go slices as empty BSON + // arrays instead of BSON null. + // + // Empty BSON arrays take up slightly more space than BSON null, but + // preserve the ability to use array update operations like "$push" or + // "$addToSet" that do not work on BSON null. + NilSliceAsEmpty bool + + // NilByteSliceAsEmpty causes the driver to marshal nil Go byte slices as + // empty BSON binary values instead of BSON null. + NilByteSliceAsEmpty bool + + // OmitZeroStruct causes the driver to consider the zero value for a struct + // (e.g. MyStruct{}) as empty and omit it from the marshaled BSON when the + // "omitempty" struct tag option is set. + OmitZeroStruct bool + + // StringifyMapKeysWithFmt causes the driver to convert Go map keys to BSON + // document field name strings using fmt.Sprint instead of the default + // string conversion logic. + StringifyMapKeysWithFmt bool + + // AllowTruncatingDoubles causes the driver to truncate the fractional part + // of BSON "double" values when attempting to unmarshal them into a Go + // integer (int, int8, int16, int32, or int64) struct field. The truncation + // logic does not apply to BSON "decimal128" values. + AllowTruncatingDoubles bool + + // BinaryAsSlice causes the driver to unmarshal BSON binary field values + // that are the "Generic" or "Old" BSON binary subtype as a Go byte slice + // instead of a primitive.Binary. + BinaryAsSlice bool + + // DefaultDocumentD causes the driver to always unmarshal documents into the + // primitive.D type. This behavior is restricted to data typed as + // "interface{}" or "map[string]interface{}". + DefaultDocumentD bool + + // DefaultDocumentM causes the driver to always unmarshal documents into the + // primitive.M type. This behavior is restricted to data typed as + // "interface{}" or "map[string]interface{}". + DefaultDocumentM bool + + // UseLocalTimeZone causes the driver to unmarshal time.Time values in the + // local timezone instead of the UTC timezone. + UseLocalTimeZone bool + + // ZeroMaps causes the driver to delete any existing values from Go maps in + // the destination value before unmarshaling BSON documents into them. + ZeroMaps bool + + // ZeroStructs causes the driver to delete any existing values from Go + // structs in the destination value before unmarshaling BSON documents into + // them. + ZeroStructs bool } // ClientOptions contains options to configure a Client instance. Each option can be set through setter functions. See @@ -104,8 +237,10 @@ type ClientOptions struct { DisableOCSPEndpointCheck *bool HeartbeatInterval *time.Duration Hosts []string + HTTPClient *http.Client LoadBalanced *bool LocalThreshold *time.Duration + LoggerOptions *LoggerOptions MaxConnIdleTime *time.Duration MaxPoolSize *uint64 MinPoolSize *uint64 @@ -115,22 +250,23 @@ type ClientOptions struct { ServerMonitor *event.ServerMonitor ReadConcern *readconcern.ReadConcern ReadPreference *readpref.ReadPref + BSONOptions *BSONOptions Registry *bsoncodec.Registry ReplicaSet *string RetryReads *bool RetryWrites *bool ServerAPIOptions *ServerAPIOptions + ServerMonitoringMode *string ServerSelectionTimeout *time.Duration - SocketTimeout *time.Duration SRVMaxHosts *int SRVServiceName *string + Timeout *time.Duration TLSConfig *tls.Config WriteConcern *writeconcern.WriteConcern ZlibLevel *int ZstdLevel *int err error - uri string cs *connstring.ConnString // AuthenticateToAnything skips server type checks when deciding if authentication is possible. @@ -151,77 +287,126 @@ type ClientOptions struct { // Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any // release. Deployment driver.Deployment + + // SocketTimeout specifies the timeout to be used for the Client's socket reads and writes. + // + // NOTE(benjirewis): SocketTimeout will be deprecated in a future release. The more general Timeout option + // may be used in its place to control the amount of time that a single operation can run before returning + // an error. Setting SocketTimeout and Timeout on a single client will result in undefined behavior. + SocketTimeout *time.Duration } // Client creates a new ClientOptions instance. func Client() *ClientOptions { - return new(ClientOptions) + return &ClientOptions{ + HTTPClient: httputil.DefaultHTTPClient, + } } // Validate validates the client options. This method will return the first error found. func (c *ClientOptions) Validate() error { - c.validateAndSetError() - return c.err -} - -func (c *ClientOptions) validateAndSetError() { if c.err != nil { - return + return c.err } + c.err = c.validate() + return c.err +} +func (c *ClientOptions) validate() error { // Direct connections cannot be made if multiple hosts are specified or an SRV URI is used. if c.Direct != nil && *c.Direct { if len(c.Hosts) > 1 { - c.err = errors.New("a direct connection cannot be made if multiple hosts are specified") - return + return errors.New("a direct connection cannot be made if multiple hosts are specified") } if c.cs != nil && c.cs.Scheme == connstring.SchemeMongoDBSRV { - c.err = errors.New("a direct connection cannot be made if an SRV URI is used") - return + return errors.New("a direct connection cannot be made if an SRV URI is used") } } + if c.MaxPoolSize != nil && c.MinPoolSize != nil && *c.MaxPoolSize != 0 && *c.MinPoolSize > *c.MaxPoolSize { + return fmt.Errorf("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=%d maxPoolSize=%d", *c.MinPoolSize, *c.MaxPoolSize) + } + // verify server API version if ServerAPIOptions are passed in. if c.ServerAPIOptions != nil { - c.err = c.ServerAPIOptions.ServerAPIVersion.Validate() + if err := c.ServerAPIOptions.ServerAPIVersion.Validate(); err != nil { + return err + } } // Validation for load-balanced mode. if c.LoadBalanced != nil && *c.LoadBalanced { if len(c.Hosts) > 1 { - c.err = internal.ErrLoadBalancedWithMultipleHosts - return + return connstring.ErrLoadBalancedWithMultipleHosts } if c.ReplicaSet != nil { - c.err = internal.ErrLoadBalancedWithReplicaSet - return + return connstring.ErrLoadBalancedWithReplicaSet } - if c.Direct != nil { - c.err = internal.ErrLoadBalancedWithDirectConnection - return + if c.Direct != nil && *c.Direct { + return connstring.ErrLoadBalancedWithDirectConnection } } // Validation for srvMaxHosts. if c.SRVMaxHosts != nil && *c.SRVMaxHosts > 0 { if c.ReplicaSet != nil { - c.err = internal.ErrSRVMaxHostsWithReplicaSet + return connstring.ErrSRVMaxHostsWithReplicaSet } if c.LoadBalanced != nil && *c.LoadBalanced { - c.err = internal.ErrSRVMaxHostsWithLoadBalanced + return connstring.ErrSRVMaxHostsWithLoadBalanced + } + } + + if mode := c.ServerMonitoringMode; mode != nil && !connstring.IsValidServerMonitoringMode(*mode) { + return fmt.Errorf("invalid server monitoring mode: %q", *mode) + } + + // OIDC Validation + if c.Auth != nil && c.Auth.AuthMechanism == auth.MongoDBOIDC { + if c.Auth.Password != "" { + return fmt.Errorf("password must not be set for the %s auth mechanism", auth.MongoDBOIDC) + } + if c.Auth.OIDCMachineCallback != nil && c.Auth.OIDCHumanCallback != nil { + return fmt.Errorf("cannot set both OIDCMachineCallback and OIDCHumanCallback, only one may be specified") + } + if c.Auth.OIDCHumanCallback == nil && c.Auth.AuthMechanismProperties[auth.AllowedHostsProp] != "" { + return fmt.Errorf("Cannot specify ALLOWED_HOSTS without an OIDCHumanCallback") + } + if env, ok := c.Auth.AuthMechanismProperties[auth.EnvironmentProp]; ok { + switch env { + case auth.GCPEnvironmentValue, auth.AzureEnvironmentValue: + if c.Auth.OIDCMachineCallback != nil { + return fmt.Errorf("OIDCMachineCallback cannot be specified with the %s %q", env, auth.EnvironmentProp) + } + if c.Auth.OIDCHumanCallback != nil { + return fmt.Errorf("OIDCHumanCallback cannot be specified with the %s %q", env, auth.EnvironmentProp) + } + if c.Auth.AuthMechanismProperties[auth.ResourceProp] == "" { + return fmt.Errorf("%q must be set for the %s %q", auth.ResourceProp, env, auth.EnvironmentProp) + } + default: + if c.Auth.AuthMechanismProperties[auth.ResourceProp] != "" { + return fmt.Errorf("%q must not be set for the %s %q", auth.ResourceProp, env, auth.EnvironmentProp) + } + } } } + + return nil } // GetURI returns the original URI used to configure the ClientOptions instance. If ApplyURI was not called during // construction, this returns "". func (c *ClientOptions) GetURI() string { - return c.uri + if c.cs == nil { + return "" + } + return c.cs.Original } // ApplyURI parses the given URI and sets options accordingly. The URI can contain host names, IPv4/IPv6 literals, or // an SRV record that will be resolved when the Client is created. When using an SRV record, TLS support is -// implictly enabled. Specify the "tls=false" URI option to override this. +// implicitly enabled. Specify the "tls=false" URI option to override this. // // If the connection string contains any options that have previously been set, it will overwrite them. Options that // correspond to multiple URI parameters, such as WriteConcern, will be completely overwritten if any of the query @@ -231,20 +416,19 @@ func (c *ClientOptions) GetURI() string { // If the URI format is incorrect or there are conflicting options specified in the URI an error will be recorded and // can be retrieved by calling Validate. // -// For more information about the URI format, see https://docs.mongodb.com/manual/reference/connection-string/. See +// For more information about the URI format, see https://www.mongodb.com/docs/manual/reference/connection-string/. See // mongo.Connect documentation for examples of using URIs for different Client configurations. func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { if c.err != nil { return c } - c.uri = uri cs, err := connstring.ParseAndValidate(uri) if err != nil { c.err = err return c } - c.cs = &cs + c.cs = cs if cs.AppName != "" { c.AppName = &cs.AppName @@ -445,6 +629,10 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions { c.DisableOCSPEndpointCheck = &cs.SSLDisableOCSPEndpointCheck } + if cs.TimeoutSet { + c.Timeout = &cs.Timeout + } + return c } @@ -470,12 +658,12 @@ func (c *ClientOptions) SetAuth(auth Credential) *ClientOptions { // // 2. "zlib" - requires server version >= 3.6 // -// 3. "zstd" - requires server version >= 4.2, and driver version >= 1.2.0 with cgo support enabled or driver version >= 1.3.0 -// without cgo +// 3. "zstd" - requires server version >= 4.2, and driver version >= 1.2.0 with cgo support enabled or driver +// version >= 1.3.0 without cgo. // -// If this option is specified, the driver will perform a negotiation with the server to determine a common list of of +// If this option is specified, the driver will perform a negotiation with the server to determine a common list of // compressors and will use the first one in that list when performing operations. See -// https://docs.mongodb.com/manual/reference/program/mongod/#cmdoption-mongod-networkmessagecompressors for more +// https://www.mongodb.com/docs/manual/reference/program/mongod/#cmdoption-mongod-networkmessagecompressors for more // information about configuring compression on the server and the server-side defaults. // // This can also be set through the "compressors" URI option (e.g. "compressors=zstd,zlib,snappy"). The default is @@ -486,18 +674,17 @@ func (c *ClientOptions) SetCompressors(comps []string) *ClientOptions { return c } -// SetConnectTimeout specifies a timeout that is used for creating connections to the server. If a custom Dialer is -// specified through SetDialer, this option must not be used. This can be set through ApplyURI with the -// "connectTimeoutMS" (e.g "connectTimeoutMS=30") option. If set to 0, no timeout will be used. The default is 30 -// seconds. +// SetConnectTimeout specifies a timeout that is used for creating connections to the server. This can be set through +// ApplyURI with the "connectTimeoutMS" (e.g "connectTimeoutMS=30") option. If set to 0, no timeout will be used. The +// default is 30 seconds. func (c *ClientOptions) SetConnectTimeout(d time.Duration) *ClientOptions { c.ConnectTimeout = &d return c } -// SetDialer specifies a custom ContextDialer to be used to create new connections to the server. The default is a -// net.Dialer with the Timeout field set to ConnectTimeout. See https://golang.org/pkg/net/#Dialer for more information -// about the net.Dialer type. +// SetDialer specifies a custom ContextDialer to be used to create new connections to the server. This method overrides +// the default net.Dialer, so dialer options such as Timeout, KeepAlive, Resolver, etc can be set. +// See https://golang.org/pkg/net/#Dialer for more information about the net.Dialer type. func (c *ClientOptions) SetDialer(d ContextDialer) *ClientOptions { c.Dialer = d return c @@ -564,6 +751,14 @@ func (c *ClientOptions) SetLocalThreshold(d time.Duration) *ClientOptions { return c } +// SetLoggerOptions specifies a LoggerOptions containing options for +// configuring a logger. +func (c *ClientOptions) SetLoggerOptions(opts *LoggerOptions) *ClientOptions { + c.LoggerOptions = opts + + return c +} + // SetMaxConnIdleTime specifies the maximum amount of time that a connection will remain idle in a connection pool // before it is removed from the pool and closed. This can also be set through the "maxIdleTimeMS" URI option (e.g. // "maxIdleTimeMS=10000"). The default is 0, meaning a connection can remain unused indefinitely. @@ -636,7 +831,7 @@ func (c *ClientOptions) SetReadConcern(rc *readconcern.ReadConcern) *ClientOptio // 3. "maxStalenessSeconds" (or "maxStaleness"): Specify a maximum replication lag for reads from secondaries in a // replica set (e.g. "maxStalenessSeconds=10"). // -// The default is readpref.Primary(). See https://docs.mongodb.com/manual/core/read-preference/#read-preference for +// The default is readpref.Primary(). See https://www.mongodb.com/docs/manual/core/read-preference/#read-preference for // more information about read preferences. func (c *ClientOptions) SetReadPreference(rp *readpref.ReadPref) *ClientOptions { c.ReadPreference = rp @@ -644,6 +839,12 @@ func (c *ClientOptions) SetReadPreference(rp *readpref.ReadPref) *ClientOptions return c } +// SetBSONOptions configures optional BSON marshaling and unmarshaling behavior. +func (c *ClientOptions) SetBSONOptions(opts *BSONOptions) *ClientOptions { + c.BSONOptions = opts + return c +} + // SetRegistry specifies the BSON registry to use for BSON marshalling/unmarshalling operations. The default is // bson.DefaultRegistry. func (c *ClientOptions) SetRegistry(registry *bsoncodec.Registry) *ClientOptions { @@ -702,11 +903,31 @@ func (c *ClientOptions) SetServerSelectionTimeout(d time.Duration) *ClientOption // SetSocketTimeout specifies how long the driver will wait for a socket read or write to return before returning a // network error. This can also be set through the "socketTimeoutMS" URI option (e.g. "socketTimeoutMS=1000"). The // default value is 0, meaning no timeout is used and socket operations can block indefinitely. +// +// NOTE(benjirewis): SocketTimeout will be deprecated in a future release. The more general Timeout option may be used +// in its place to control the amount of time that a single operation can run before returning an error. Setting +// SocketTimeout and Timeout on a single client will result in undefined behavior. func (c *ClientOptions) SetSocketTimeout(d time.Duration) *ClientOptions { c.SocketTimeout = &d return c } +// SetTimeout specifies the amount of time that a single operation run on this Client can execute before returning an error. +// The deadline of any operation run through the Client will be honored above any Timeout set on the Client; Timeout will only +// be honored if there is no deadline on the operation Context. Timeout can also be set through the "timeoutMS" URI option +// (e.g. "timeoutMS=1000"). The default value is nil, meaning operations do not inherit a timeout from the Client. +// +// If any Timeout is set (even 0) on the Client, the values of MaxTime on operation options, TransactionOptions.MaxCommitTime and +// SessionOptions.DefaultMaxCommitTime will be ignored. Setting Timeout and SocketTimeout or WriteConcern.wTimeout will result +// in undefined behavior. +// +// NOTE(benjirewis): SetTimeout represents unstable, provisional API. The behavior of the driver when a Timeout is specified is +// subject to change. +func (c *ClientOptions) SetTimeout(d time.Duration) *ClientOptions { + c.Timeout = &d + return c +} + // SetTLSConfig specifies a tls.Config instance to use use to configure TLS on all connections created to the cluster. // This can also be set through the following URI options: // @@ -716,7 +937,8 @@ func (c *ClientOptions) SetSocketTimeout(d time.Duration) *ClientOptions { // "tlsPrivateKeyFile". The "tlsCertificateKeyFile" option specifies a path to the client certificate and private key, // which must be concatenated into one file. The "tlsCertificateFile" and "tlsPrivateKey" combination specifies separate // paths to the client certificate and private key, respectively. Note that if "tlsCertificateKeyFile" is used, the -// other two options must not be specified. +// other two options must not be specified. Only the subject name of the first certificate is honored as the username +// for X509 auth in a file with multiple certs. // // 3. "tlsCertificateKeyFilePassword" (or "sslClientCertificateKeyPassword"): Specify the password to decrypt the client // private key file (e.g. "tlsCertificateKeyFilePassword=password"). @@ -735,6 +957,14 @@ func (c *ClientOptions) SetTLSConfig(cfg *tls.Config) *ClientOptions { return c } +// SetHTTPClient specifies the http.Client to be used for any HTTP requests. +// +// This should only be used to set custom HTTP client configurations. By default, the connection will use an httputil.DefaultHTTPClient. +func (c *ClientOptions) SetHTTPClient(client *http.Client) *ClientOptions { + c.HTTPClient = client + return c +} + // SetWriteConcern specifies the write concern to use to for write operations. This can also be set through the following // URI options: // @@ -803,6 +1033,16 @@ func (c *ClientOptions) SetServerAPIOptions(opts *ServerAPIOptions) *ClientOptio return c } +// SetServerMonitoringMode specifies the server monitoring protocol to use. See +// the helper constants ServerMonitoringModeAuto, ServerMonitoringModePoll, and +// ServerMonitoringModeStream for more information about valid server +// monitoring modes. +func (c *ClientOptions) SetServerMonitoringMode(mode string) *ClientOptions { + c.ServerMonitoringMode = &mode + + return c +} + // SetSRVMaxHosts specifies the maximum number of SRV results to randomly select during polling. To limit the number // of hosts selected in SRV discovery, this function must be called before ApplyURI. This can also be set through // the "srvMaxHosts" URI option. @@ -857,6 +1097,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if len(opt.Hosts) > 0 { c.Hosts = opt.Hosts } + if opt.HTTPClient != nil { + c.HTTPClient = opt.HTTPClient + } if opt.LoadBalanced != nil { c.LoadBalanced = opt.LoadBalanced } @@ -893,6 +1136,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.ReadPreference != nil { c.ReadPreference = opt.ReadPreference } + if opt.BSONOptions != nil { + c.BSONOptions = opt.BSONOptions + } if opt.Registry != nil { c.Registry = opt.Registry } @@ -920,6 +1166,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.SRVServiceName != nil { c.SRVServiceName = opt.SRVServiceName } + if opt.Timeout != nil { + c.Timeout = opt.Timeout + } if opt.TLSConfig != nil { c.TLSConfig = opt.TLSConfig } @@ -944,12 +1193,15 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.err != nil { c.err = opt.err } - if opt.uri != "" { - c.uri = opt.uri - } if opt.cs != nil { c.cs = opt.cs } + if opt.LoggerOptions != nil { + c.LoggerOptions = opt.LoggerOptions + } + if opt.ServerMonitoringMode != nil { + c.ServerMonitoringMode = opt.ServerMonitoringMode + } } return c @@ -983,7 +1235,21 @@ func addClientCertFromSeparateFiles(cfg *tls.Config, keyFile, certFile, keyPassw return "", err } - data := append(keyData, '\n') + keySize := len(keyData) + if keySize > 64*1024*1024 { + return "", errors.New("X.509 key must be less than 64 MiB") + } + certSize := len(certData) + if certSize > 64*1024*1024 { + return "", errors.New("X.509 certificate must be less than 64 MiB") + } + dataSize := keySize + certSize + 1 + if dataSize > math.MaxInt { + return "", errors.New("size overflow") + } + data := make([]byte, 0, dataSize) + data = append(data, keyData...) + data = append(data, '\n') data = append(data, certData...) return addClientCertFromBytes(cfg, data, keyPassword) } @@ -997,8 +1263,8 @@ func addClientCertFromConcatenatedFile(cfg *tls.Config, certKeyFile, keyPassword return addClientCertFromBytes(cfg, data, keyPassword) } -// addClientCertFromBytes adds a client certificate to the configuration given a path to the -// containing file and returns the certificate's subject name. +// addClientCertFromBytes adds client certificates to the configuration given a path to the +// containing file and returns the subject name in the first certificate. func addClientCertFromBytes(cfg *tls.Config, data []byte, keyPasswd string) (string, error) { var currentBlock *pem.Block var certDecodedBlock []byte @@ -1015,7 +1281,11 @@ func addClientCertFromBytes(cfg *tls.Config, data []byte, keyPasswd string) (str if currentBlock.Type == "CERTIFICATE" { certBlock := data[start : len(data)-len(remaining)] certBlocks = append(certBlocks, certBlock) - certDecodedBlock = currentBlock.Bytes + // Assign the certDecodedBlock when it is never set, + // so only the first certificate is honored in a file with multiple certs. + if certDecodedBlock == nil { + certDecodedBlock = currentBlock.Bytes + } start += len(certBlock) } else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") { isEncrypted := x509.IsEncryptedPEMBlock(currentBlock) || strings.Contains(currentBlock.Type, "ENCRYPTED PRIVATE KEY") @@ -1045,7 +1315,10 @@ func addClientCertFromBytes(cfg *tls.Config, data []byte, keyPasswd string) (str } } var encoded bytes.Buffer - pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: keyBytes}) + err = pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: keyBytes}) + if err != nil { + return "", fmt.Errorf("error encoding private key as PEM: %w", err) + } keyBlock := encoded.Bytes() keyBlocks = append(keyBlocks, keyBlock) start = len(data) - len(remaining) diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/collectionoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/collectionoptions.go index 5c81114..7904dbd 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/collectionoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/collectionoptions.go @@ -15,20 +15,24 @@ import ( // CollectionOptions represents options that can be used to configure a Collection. type CollectionOptions struct { - // The read concern to use for operations executed on the Collection. The default value is nil, which means that - // the read concern of the database used to configure the Collection will be used. + // ReadConcern is the read concern to use for operations executed on the Collection. The default value is nil, which means that + // the read concern of the Database used to configure the Collection will be used. ReadConcern *readconcern.ReadConcern - // The write concern to use for operations executed on the Collection. The default value is nil, which means that - // the write concern of the database used to configure the Collection will be used. + // WriteConcern is the write concern to use for operations executed on the Collection. The default value is nil, which means that + // the write concern of the Database used to configure the Collection will be used. WriteConcern *writeconcern.WriteConcern - // The read preference to use for operations executed on the Collection. The default value is nil, which means that - // the read preference of the database used to configure the Collection will be used. + // ReadPreference is the read preference to use for operations executed on the Collection. The default value is nil, which means that + // the read preference of the Database used to configure the Collection will be used. ReadPreference *readpref.ReadPref - // The BSON registry to marshal and unmarshal documents for operations executed on the Collection. The default value - // is nil, which means that the registry of the database used to configure the Collection will be used. + // BSONOptions configures optional BSON marshaling and unmarshaling + // behavior. + BSONOptions *BSONOptions + + // Registry is the BSON registry to marshal and unmarshal documents for operations executed on the Collection. The default value + // is nil, which means that the registry of the Database used to configure the Collection will be used. Registry *bsoncodec.Registry } @@ -55,6 +59,12 @@ func (c *CollectionOptions) SetReadPreference(rp *readpref.ReadPref) *Collection return c } +// SetBSONOptions configures optional BSON marshaling and unmarshaling behavior. +func (c *CollectionOptions) SetBSONOptions(opts *BSONOptions) *CollectionOptions { + c.BSONOptions = opts + return c +} + // SetRegistry sets the value for the Registry field. func (c *CollectionOptions) SetRegistry(r *bsoncodec.Registry) *CollectionOptions { c.Registry = r @@ -63,6 +73,9 @@ func (c *CollectionOptions) SetRegistry(r *bsoncodec.Registry) *CollectionOption // MergeCollectionOptions combines the given CollectionOptions instances into a single *CollectionOptions in a // last-one-wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeCollectionOptions(opts ...*CollectionOptions) *CollectionOptions { c := Collection() @@ -82,6 +95,9 @@ func MergeCollectionOptions(opts ...*CollectionOptions) *CollectionOptions { if opt.Registry != nil { c.Registry = opt.Registry } + if opt.BSONOptions != nil { + c.BSONOptions = opt.BSONOptions + } } return c diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go index 094524c..bb765d9 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go @@ -15,6 +15,13 @@ type CountOptions struct { // default value is nil, which means the default collation of the collection will be used. Collation *Collation + // TODO(GODRIVER-2386): CountOptions executor uses aggregation under the hood, which means this type has to be + // TODO a string for now. This can be replaced with `Comment interface{}` once 2386 is implemented. + + // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace + // the operation. The default is nil, which means that no comment will be included in the logs. + Comment *string + // The index to use for the aggregation. This should either be the index name as a string or the index specification // as a document. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, // which means that no hint will be sent. @@ -26,6 +33,10 @@ type CountOptions struct { // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there is // no time limit for query execution. + // + // NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout option may be used in + // its place to control the amount of time that a single operation can run before returning an error. MaxTime is + // ignored if Timeout is set on the client. MaxTime *time.Duration // The number of documents to skip before counting. The default value is 0. @@ -43,6 +54,12 @@ func (co *CountOptions) SetCollation(c *Collation) *CountOptions { return co } +// SetComment sets the value for the Comment field. +func (co *CountOptions) SetComment(c string) *CountOptions { + co.Comment = &c + return co +} + // SetHint sets the value for the Hint field. func (co *CountOptions) SetHint(h interface{}) *CountOptions { co.Hint = h @@ -56,6 +73,10 @@ func (co *CountOptions) SetLimit(i int64) *CountOptions { } // SetMaxTime sets the value for the MaxTime field. +// +// NOTE(benjirewis): MaxTime will be deprecated in a future release. The more general Timeout +// option may be used in its place to control the amount of time that a single operation can +// run before returning an error. MaxTime is ignored if Timeout is set on the client. func (co *CountOptions) SetMaxTime(d time.Duration) *CountOptions { co.MaxTime = &d return co @@ -68,6 +89,9 @@ func (co *CountOptions) SetSkip(i int64) *CountOptions { } // MergeCountOptions combines the given CountOptions instances into a single CountOptions in a last-one-wins fashion. +// +// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a +// single options struct instead. func MergeCountOptions(opts ...*CountOptions) *CountOptions { countOpts := Count() for _, co := range opts { @@ -77,6 +101,9 @@ func MergeCountOptions(opts ...*CountOptions) *CountOptions { if co.Collation != nil { countOpts.Collation = co.Collation } + if co.Comment != nil { + countOpts.Comment = co.Comment + } if co.Hint != nil { countOpts.Hint = co.Hint } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/createcollectionoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/createcollectionoptions.go index 130c8e7..d8ffaaf 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/createcollectionoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/createcollectionoptions.go @@ -6,6 +6,8 @@ package options +import "time" + // DefaultIndexOptions represents the default options for a collection to apply on new indexes. This type can be used // when creating a new collection through the CreateCollectionOptions.SetDefaultIndexOptions method. type DefaultIndexOptions struct { @@ -28,18 +30,30 @@ func (d *DefaultIndexOptions) SetStorageEngine(storageEngine interface{}) *Defau // TimeSeriesOptions specifies options on a time-series collection. type TimeSeriesOptions struct { - // Name of the top-level field to be used for time. Inserted documents must have this field, + // TimeField is the top-level field to be used for time. Inserted documents must have this field, // and the field must be of the BSON UTC datetime type (0x9). TimeField string - // Optional name of the top-level field describing the series. This field is used to group + // MetaField is the name of the top-level field describing the series. This field is used to group // related data and may be of any BSON type, except for array. This name may not be the same - // as the TimeField or _id. + // as the TimeField or _id. This field is optional. MetaField *string - // Optional string specifying granularity of time-series data. Allowed granularity options are - // "seconds", "minutes" and "hours". + // Granularity is the granularity of time-series data. Allowed granularity options are + // "seconds", "minutes" and "hours". This field is optional. Granularity *string + + // BucketMaxSpan is the maximum range of time values for a bucket. The + // time.Duration is rounded down to the nearest second and applied as + // the command option: "bucketRoundingSeconds". This field is optional. + BucketMaxSpan *time.Duration + + // BucketRounding is used to determine the minimum time boundary when + // opening a new bucket by rounding the first timestamp down to the next + // multiple of this value. The time.Duration is rounded down to the + // nearest second and applied as the command option: + // "bucketRoundingSeconds". This field is optional. + BucketRounding *time.Duration } // TimeSeries creates a new TimeSeriesOptions instance. @@ -65,9 +79,23 @@ func (tso *TimeSeriesOptions) SetGranularity(granularity string) *TimeSeriesOpti return tso } +// SetBucketMaxSpan sets the value for BucketMaxSpan. +func (tso *TimeSeriesOptions) SetBucketMaxSpan(dur time.Duration) *TimeSeriesOptions { + tso.BucketMaxSpan = &dur + + return tso +} + +// SetBucketRounding sets the value for BucketRounding. +func (tso *TimeSeriesOptions) SetBucketRounding(dur time.Duration) *TimeSeriesOptions { + tso.BucketRounding = &dur + + return tso +} + // CreateCollectionOptions represents options that can be used to configure a CreateCollection operation. type CreateCollectionOptions struct { - // Specifies if the collection is capped (see https://docs.mongodb.com/manual/core/capped-collections/). If true, + // Specifies if the collection is capped (see https://www.mongodb.com/docs/manual/core/capped-collections/). If true, // the SizeInBytes option must also be specified. The default value is false. Capped *bool @@ -75,6 +103,12 @@ type CreateCollectionOptions struct { // For previous server versions, the driver will return an error if this option is used. The default value is nil. Collation *Collation + // Specifies how change streams opened against the collection can return pre- and post-images of updated + // documents. The value must be a document in the form {