From b77347ee6f7b849d68b02f3234df232ff961a06b Mon Sep 17 00:00:00 2001 From: Sam Warters Date: Wed, 18 Jun 2025 15:34:00 -0500 Subject: [PATCH] Updates to support MongoDB 5.1+ In MongoDB 3.6, the `OP_MSG` and `OP_COMPRESSED` opcodes were added to the wired message protocol. `OP_MSG` was added to standardize the MongoDB message format, and `OP_COMPRESSED` takes things a step further by compressing the message to increase network efficiency. As of MongoDB 5.1, the old opcodes were removed, and all messages are sent with either `OP_MSG` or `OP_COMPRESSED`. Because this driver did not have the ability to send messages with the `OP_MSG` opcode, it couldn't be used with any versions of MongoDB greater than 5.1. To bring this driver into the modern era, this PR ports over the changes made in the [emqx fork](https://github.com/emqx/mongodb-erlang/) of this driver to support the `OP_MSG` opcode. Additionally, there are some slightly unrelated, but relevant changes in this PR as well such as fixing some dialyzer errors and updating the GitHub actions workflow definitions to get the tests up and running again. --- .github/workflows/dialyzer.yml | 6 +- .github/workflows/test.yml | 18 +-- .github/workflows/test_coverage.yml | 14 +-- .gitignore | 6 +- Makefile | 10 +- README.md | 23 +++- include/mongo_protocol.hrl | 23 +++- include/mongo_types.hrl | 17 ++- include/mongoc.hrl | 4 +- src/api/mc_worker_api.erl | 169 +++++++++++++++++++++++---- src/api/mongo_api.erl | 13 ++- src/api/mongoc.erl | 52 +++++++-- src/connection/mc_auth_logic.erl | 9 +- src/connection/mc_connection_man.erl | 130 +++++++++++++++++++-- src/connection/mc_cursor.erl | 51 +++++--- src/connection/mc_worker.erl | 60 ++++++++-- src/connection/mc_worker_logic.erl | 15 ++- src/connection/mongo_protocol.erl | 76 +++++++++++- src/main/mc_super_sup.erl | 3 +- src/main/mc_worker_pid_info.erl | 136 +++++++++++++++++++++ src/mongoc/mc_monitor.erl | 7 +- src/mongoc/mc_topology.erl | 12 +- src/support/mc_utils.erl | 19 ++- test/mc_test_utils.erl | 7 +- test/mc_worker_api_SUITE.erl | 11 +- test/mongo_api_SUITE.erl | 46 +++++++- test/switch_db_SUITE.erl | 2 +- 27 files changed, 818 insertions(+), 121 deletions(-) create mode 100644 src/main/mc_worker_pid_info.erl diff --git a/.github/workflows/dialyzer.yml b/.github/workflows/dialyzer.yml index c821eee0..8049b0d2 100644 --- a/.github/workflows/dialyzer.yml +++ b/.github/workflows/dialyzer.yml @@ -7,14 +7,14 @@ on: jobs: dialyzer: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 container: image: erlang:24-slim steps: - - uses: actions/checkout@v2.0.0 + - uses: actions/checkout@v4 - name: Cache PLTs id: cache-plts - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ~/.cache/rebar3/ key: ${{ runner.os }}-erlang-${{ hashFiles(format('{0}{1}', github.workspace, '/rebar.lock')) }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ff030ac5..785e9abf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,37 +7,37 @@ on: jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: - erlang: [22, 23] + erlang: [24, 25] mongodb: ["4.4.8", "5.0.2"] container: image: erlang:${{ matrix.erlang }} steps: - - uses: actions/checkout@v2.0.0 + - uses: actions/checkout@v4 - run: ./scripts/install_mongo_debian.sh ${{ matrix.mongodb }} - run: ./scripts/start_mongo_single_node.sh - run: ./scripts/start_mongo_cluster.sh - run: ./rebar3 eunit - run: ./rebar3 ct - name: Archive Replica Set Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: - name: mongodb_replica_set_logs + name: erlang-${{ matrix.erlang }}-mongodb-${{ matrix.mongodb }}-mongodb_replica_set_logs path: rs0-logs retention-days: 1 - name: Archive Single Node Log - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: - name: single_node.log + name: erlang-${{ matrix.erlang }}-mongodb-${{ matrix.mongodb }}-single_node.log path: single_node.log retention-days: 1 - name: CT Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: - name: ct_logs + name: erlang-${{ matrix.erlang }}-mongodb-${{ matrix.mongodb }}-ct_logs path: _build/test/logs/ retention-days: 5 diff --git a/.github/workflows/test_coverage.yml b/.github/workflows/test_coverage.yml index 76ee41d0..449a7b92 100644 --- a/.github/workflows/test_coverage.yml +++ b/.github/workflows/test_coverage.yml @@ -7,11 +7,11 @@ on: jobs: test_coverage: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 container: - image: erlang:23 + image: erlang:24 steps: - - uses: actions/checkout@v2.0.0 + - uses: actions/checkout@v4 - run: ./scripts/install_mongo_debian.sh 5.0.2 - run: ./scripts/start_mongo_single_node.sh - run: ./scripts/start_mongo_cluster.sh @@ -19,27 +19,27 @@ jobs: - run: ./rebar3 ct --cover --cover_export_name ct.coverdata - run: rebar3 cover --verbose - name: Archive Replica Set Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: name: mongodb_replica_set_logs path: rs0-logs retention-days: 1 - name: Archive Single Node Log - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: name: single_node.log path: single_node.log retention-days: 1 - name: Coverage Report - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: Coverage Report path: _build/test/cover/ retention-days: 5 - name: CT Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: ct_logs path: _build/test/logs/ diff --git a/.gitignore b/.gitignore index 77179f06..95249589 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,8 @@ variables-ct* *.iml data _build -.idea \ No newline at end of file +.idea +*.log +rebar.lock +*.crashdump +*.plt diff --git a/Makefile b/Makefile index d62b6a1a..14ce2ce0 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ report-rebar3-version: # Dialyzer. .$(PROJECT).plt: - @$(DIALYZER) --build_plt --output_plt .$(PROJECT).plt -r deps \ + @$(DIALYZER) --build_plt --output_plt .$(PROJECT).plt -r _build/default/lib/ \ --apps erts kernel stdlib sasl inets crypto public_key ssl mnesia syntax_tools asn1 clean-plt: @@ -48,8 +48,8 @@ clean-plt: build-plt: clean-plt .$(PROJECT).plt -dialyze: .$(PROJECT).plt - @$(DIALYZER) -I include -I deps --src -r src --plt .$(PROJECT).plt --no_native \ - -Werror_handling -Wrace_conditions -Wunmatched_returns +dialyzer: .$(PROJECT).plt + @$(DIALYZER) -I include -I _build/default/lib/ --src -r src --plt .$(PROJECT).plt --no_native \ + -Werror_handling -Wrace_conditions -Wunmatched_returns --get_warnings -.PHONY: deps clean-plt build-plt dialyze report-rebar3-version +.PHONY: deps clean-plt build-plt dialyzer report-rebar3-version diff --git a/README.md b/README.md index cc705573..4154f82d 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,30 @@ If you are choosing between using [mongos](https://docs.mongodb.com/manual/reference/program/mongos/) and using mongo shard with `mongo_api` - prefer mongos and use `mc_worker_api`. +In MongoDB 3.6, the `OP_MSG` and `OP_COMPRESSED` opcodes were added to the wired +message protocol. `OP_MSG` was added to standardize the MongoDB message format, +and `OP_COMPRESSED` takes things a step further by compressing the message to increase +network efficiency. As of MongoDB 5.1, the old opcodes were deprecated, and all +messages are sent with either `OP_MSG` or `OP_COMPRESSED`. + +By default, this driver tries to automatically detect which MongoDB messaging protocol +to use - the legacy one that existed before the `OP_MSG` opcode was introduced, or the +modern messaging protocol based on `OP_MSG`. This is accomplished by setting the default +value for the `use_legacy_protocol` env variable to `auto`. One can force the driver to +use the legacy opcodes or the `OP_MSG`opcode by setting the application env +`use_legacy_protocol` to `true` or `false` (for example by calling +`application:set_env(mongodb, use_legacy_protocol, false)`). + +It's also possible to define the usage of legacy protocol on a per-connection basis. +Simple pass `{use_legacy_protocol, false | true}` to mc_worker start options. + +As the `OP_MSG` opcode has existed in MongoDB since 3.6, and the driver defaults to auto- +detecting which set of opcodes to use, ALL of the messages will be sent using the `OP_MSG` +opcode if you are using a version for MongoDB greater 3.6 unless you explicitly set the +`use_legacy_protocol` variable to `true`. + mc_worker_api -- direct connection client --------------------------------- - ### Connecting To connect to a database `test` on mongodb server listening on `localhost:27017` (or any address & port of your choosing) diff --git a/include/mongo_protocol.hrl b/include/mongo_protocol.hrl index 45be0124..e4499893 100644 --- a/include/mongo_protocol.hrl +++ b/include/mongo_protocol.hrl @@ -8,6 +8,7 @@ -type colldb() :: collection() | {database(), collection()}. -type collection() :: binary() | atom(). % without db prefix -type database() :: binary() | atom(). +-type command() :: insert | update | delete. %% write @@ -45,6 +46,24 @@ projector = #{} :: mc_worker_api:projector() }). +-record(op_msg_write_op, { + command :: command(), + collection :: colldb(), + database :: undefined | mc_worker_api:database(), + extra_fields = [] :: bson:document() | nonempty_list({binary(),any()}), + documents_name = <<"documents">> :: bson:utf8(), + documents = [] :: any() +}). + +-record(op_msg_response, { + response_doc :: map() +}). + +-record(op_msg_command, { + database :: undefined | mc_worker_api:database(), + command_doc :: bson:document() | nonempty_list({binary(),any()}) +}). + -record(getmore, { collection :: colldb(), batchsize = 0 :: mc_worker_api:batchsize(), @@ -73,9 +92,9 @@ -record(reply, { cursornotfound :: boolean(), queryerror :: boolean(), - awaitcapable = false :: boolean(), + awaitcapable = false :: boolean() | undefined, cursorid :: mc_worker_api:cursorid(), - startingfrom = 0 :: integer(), + startingfrom = 0 :: integer() | undefined, documents :: [map()] }). -endif. diff --git a/include/mongo_types.hrl b/include/mongo_types.hrl index e2cab5ea..ffe0b361 100644 --- a/include/mongo_types.hrl +++ b/include/mongo_types.hrl @@ -31,11 +31,24 @@ | {ssl, boolean()} | {ssl_opts, proplists:proplist()} | {register, atom() | fun()}. +-type socket() :: gen_tcp:socket() | ssl:sslsocket(). -type write_mode() :: unsafe | safe | {safe, bson:document()}. -type read_mode() :: master | slave_ok. -type service() :: {Host :: inet:hostname() | inet:ip_address(), Post :: 0..65535}. -type options() :: [option()]. -type option() :: {timeout, timeout()} | {ssl, boolean()} | ssl | {database, database()} | {read_mode, read_mode()} | {write_mode, write_mode()}. -type cursor() :: pid(). --type query() :: #query{}. --endif. \ No newline at end of file +-type query() :: #'query'{}. +-type op_msg_command() :: #op_msg_command{}. +-type op_msg_write_op() :: #op_msg_write_op{}. +-type op_msg_response() :: #op_msg_response{}. +-type request() :: query() +| op_msg_command() +| op_msg_write_op() +| #killcursor{} +| #insert{} +| #update{} +| #delete{} +| #getmore{} +| #ensure_index{}. +-endif. diff --git a/include/mongoc.hrl b/include/mongoc.hrl index 148ad2b9..63b669f9 100644 --- a/include/mongoc.hrl +++ b/include/mongoc.hrl @@ -76,8 +76,8 @@ | {password, binary()} | {w_mode, mc_worker_api:write_mode()}. -type readprefs() :: [readpref()]. --type readpref() :: #{rp_mode => readmode()} +-type readpref() :: #{rp_mode => readmode()} | #{mode => binary()} | #{mode => binary(), tags => tuple()} |{rp_tags, [tuple()]}. -type reason() :: atom(). --endif. \ No newline at end of file +-endif. diff --git a/src/api/mc_worker_api.erl b/src/api/mc_worker_api.erl index 74443956..16922520 100644 --- a/src/api/mc_worker_api.erl +++ b/src/api/mc_worker_api.erl @@ -107,10 +107,23 @@ insert(Connection, Coll, Doc, WC, DB) when is_tuple(Doc); is_map(Doc) -> {Res, [UDoc | _]} = insert(Connection, Coll, [Doc], WC, DB), {Res, UDoc}; insert(Connection, Coll, Docs, WC, DB) -> - Converted = prepare(Docs, fun assign_id/1), - {command(DB, Connection, {<<"insert">>, Coll, <<"documents">>, Converted, <<"writeConcern">>, WC}), Converted}. - -%% @doc Insert one document or multiple documents into a colleciton. + ConvertedDocs = prepare(Docs, fun assign_id/1), + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + insert(UseLegacyProtocol, Connection, Coll, WC, DB, ConvertedDocs). + +insert(true, Connection, Coll, WriteConcern, DB, ConvertedDocs) -> + Command = + command(DB, Connection, {<<"insert">>, Coll, <<"documents">>, ConvertedDocs, <<"writeConcern">>, WriteConcern}), + {Command, ConvertedDocs}; +insert(false, Connection, Coll, WriteConcern, _DB, ConvertedDocs) -> + Msg = #op_msg_write_op{ + command = insert, + collection = Coll, + extra_fields = [{<<"writeConcern">>, WriteConcern}], + documents = ConvertedDocs}, + {mc_connection_man:op_msg(Connection, Msg), ConvertedDocs}. + +%% @doc Insert one document or multiple documents into a collection. %% params: %% connection - mc_worker pid %% collection - collection() @@ -132,9 +145,7 @@ update(Connection, Coll, Selector, Doc) -> %% @doc Replace the document matching criteria entirely with the new Document. -spec update(pid(), collection(), selector(), map() | bson:document(), boolean(), boolean()) -> {boolean(), map()}. update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate) -> - Converted = prepare(Doc, fun(D) -> D end), - command(Connection, {<<"update">>, Coll, <<"updates">>, - [#{<<"q">> => Selector, <<"u">> => Converted, <<"upsert">> => Upsert, <<"multi">> => MultiUpdate}]}). + update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, {<<"w">>, 1}). %% @deprecated %% @doc Replace the document matching criteria entirely with the new Document. @@ -145,11 +156,26 @@ update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, WC) -> %% @deprecated %% @doc Replace the document matching criteria entirely with the new Document. -spec update(pid(), collection(), selector(), map() | bson:document(), boolean(), boolean(), bson:document(), database()) -> {boolean(), map()}. -update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, WC, DB) -> - Converted = prepare(Doc, fun(D) -> D end), - command(DB, Connection, {<<"update">>, Coll, <<"updates">>, - [#{<<"q">> => Selector, <<"u">> => Converted, <<"upsert">> => Upsert, <<"multi">> => MultiUpdate}], - <<"writeConcern">>, WC}). +update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, WriteConcern, DB) -> + ConvertedDocs = prepare(Doc, fun(D) -> D end), + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + update(UseLegacyProtocol, Connection, Coll, Selector, ConvertedDocs, Upsert, MultiUpdate, WriteConcern, DB). + +update(true, Connection, Coll, Selector, ConvertedDocs, Upsert, MultiUpdate, WriteConcern, DB) -> + command(DB, Connection, {<<"update">>, Coll, <<"updates">>, [#{<<"q">> => Selector, <<"u">> => ConvertedDocs, + <<"upsert">> => Upsert, <<"multi">> => MultiUpdate}], <<"writeConcern">>, WriteConcern}); +update(false, Connection, Coll, Selector, ConvertedDocs, Upsert, MultiUpdate, WriteConcern, _DB) -> + Msg = #op_msg_write_op{ + command = update, + collection = Coll, + extra_fields = [{<<"writeConcern">>, WriteConcern}], + documents_name = <<"updates">>, + documents = [#{ + <<"q">> => Selector, + <<"u">> => ConvertedDocs, + <<"upsert">> => Upsert, + <<"multi">> => MultiUpdate}]}, + mc_connection_man:op_msg(Connection, Msg). %% @doc Replace the document matching criteria entirely with the new Document. %% params: @@ -183,22 +209,35 @@ delete_one(Connection, Coll, Selector) -> %% @deprecated %% @doc Delete selected documents -spec delete_limit(pid(), collection(), selector(), integer()) -> {boolean(), map()}. -delete_limit(Connection, Coll, Selector, N) -> - command(Connection, {<<"delete">>, Coll, <<"deletes">>, - [#{<<"q">> => Selector, <<"limit">> => N}]}). +delete_limit(Connection, Coll, Selector, Limit) -> + delete_limit(Connection, Coll, Selector, Limit, undefined). %% @deprecated %% @doc Delete selected documents --spec delete_limit(pid(), collection(), selector(), integer(), bson:document()) -> {boolean(), map()}. -delete_limit(Connection, Coll, Selector, N, WC) -> - delete_limit(Connection, Coll, Selector, N, WC, undefined). +-spec delete_limit(pid(), collection(), selector(), integer(), bson:document() | undefined) -> {boolean(), map()}. +delete_limit(Connection, Coll, Selector, Limit, undefined) -> + delete_limit(Connection, Coll, Selector, Limit, {<<"w">>, 1}, undefined); +delete_limit(Connection, Coll, Selector, Limit, WriteConcern) -> + delete_limit(Connection, Coll, Selector, Limit, WriteConcern, undefined). %% @deprecated %% @doc Delete selected documents -spec delete_limit(pid(), collection(), selector(), integer(), bson:document(), database()) -> {boolean(), map()}. -delete_limit(Connection, Coll, Selector, N, WC, DB) -> +delete_limit(Connection, Coll, Selector, Limit, WriteConcern, DB) -> + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + delete_limit(UseLegacyProtocol, Connection, Coll, Selector, Limit, WriteConcern, DB). + +delete_limit(true, Connection, Coll, Selector, Limit, WriteConcern, DB) -> command(DB, Connection, {<<"delete">>, Coll, <<"deletes">>, - [#{<<"q">> => Selector, <<"limit">> => N}], <<"writeConcern">>, WC}). + [#{<<"q">> => Selector, <<"limit">> => Limit}], <<"writeConcern">>, WriteConcern}); +delete_limit(false, Connection, Coll, Selector, Limit, WriteConcern, _DB) -> + Msg = #op_msg_write_op{command = delete, + collection = Coll, + extra_fields = [{<<"writeConcern">>, WriteConcern}], + documents_name = <<"deletes">>, + documents = [#{<<"q">> => Selector, + <<"limit">> => Limit}]}, + mc_connection_man:op_msg(Connection, Msg). %% @doc Delete selected documents %% params: @@ -230,6 +269,10 @@ find_one(Connection, Coll, Selector, Args, Db) -> Projector = maps:get(projector, Args, #{}), Skip = maps:get(skip, Args, 0), ReadPref = maps:get(readopts, Args, #{<<"mode">> => <<"primary">>}), + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + find_one(UseLegacyProtocol, Connection, Coll, Selector, Projector, Skip, ReadPref, Db). + +find_one(true, Connection, Coll, Selector, Projector, Skip, ReadPref, Db) -> find_one(Connection, #'query'{ database = Db, @@ -237,6 +280,21 @@ find_one(Connection, Coll, Selector, Args, Db) -> selector = mongoc:append_read_preference(Selector, ReadPref), projector = Projector, skip = Skip + }); +find_one(false, Connection, Coll, Selector, Projector, Skip, ReadPref, _Db) -> + CommandDoc = [ + {<<"find">>, Coll}, + {<<"$readPreference">>, ReadPref}, + {<<"filter">>, Selector}, + {<<"projection">>, Projector}, + {<<"skip">>, Skip}, + {<<"batchSize">>, 1}, + {<<"limit">>, 1}, + {<<"singleBatch">>, true} %% Close cursor after first batch + ], + mc_connection_man:op_msg_read_one(Connection, + #'op_msg_command'{ + command_doc = CommandDoc }). %% @doc Return projection of selected documents. @@ -254,7 +312,21 @@ find_one(Cmd = #{connection := Connection, collection := Collection, selector := -spec find_one(pid() | atom(), query()) -> map() | undefined. find_one(Connection, Query) when is_record(Query, query) -> - mc_connection_man:read_one(Connection, Query). + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + find_one_with_query_record(UseLegacyProtocol, Connection, Query). + +find_one_with_query_record(true, Connection, Query) -> + mc_connection_man:read_one(Connection, Query); +find_one_with_query_record(false, Connection, Query) -> + #'query'{collection = Coll, + skip = Skip, + selector = Selector, + projector = Projector} = Query, + {RP, NewSelector, _} = mongoc:extract_read_preference(Selector), + Args = #{projector => Projector, + skip => Skip, + readopts => RP}, +find_one(Connection, Coll, NewSelector, Args). %% @deprecated %% @doc Return selected documents. @@ -274,7 +346,13 @@ find(Connection, Coll, Selector, Args) -> find(Connection, Coll, Selector, Args, Db) -> Projector = maps:get(projector, Args, #{}), Skip = maps:get(skip, Args, 0), - BatchSize = maps:get(batchsize, Args, 0), + BatchSize = + case mc_utils:use_legacy_protocol(Connection) of + true -> + maps:get(batchsize, Args, 0); + false -> + maps:get(batchsize, Args, 101) + end, ReadPref = maps:get(readopts, Args, #{<<"mode">> => <<"primary">>}), find(Connection, #'query'{ @@ -304,12 +382,49 @@ find(Cmd = #{connection := Connection, collection := Collection, selector := Sel -spec find(pid() | atom(), query()) -> {ok, cursor()} | []. find(Connection, Query) when is_record(Query, query) -> - case mc_connection_man:read(Connection, Query) of + FixedQuery = fixed_query(mc_utils:use_legacy_protocol(Connection), Query), + case mc_connection_man:read(Connection, FixedQuery) of [] -> []; {ok, Cursor} when is_pid(Cursor) -> {ok, Cursor} end. +fixed_query(true, Query) -> + Query; +fixed_query(false, Query) -> + #'query'{collection = Coll, + skip = Skip, + selector = Selector, + batchsize = BatchSize, + projector = Projector} = Query, + {ReadPref, NewSelector, OrderBy} = mongoc:extract_read_preference(Selector), + %% We might need to do some transformations: + %% See: https://github.com/mongodb/specifications/blob/master/source/find_getmore_killcursors_commands.rst#mapping-op-query-behavior-to-the-find-command-limit-and-batchsize-fields + SingleBatch = BatchSize < 0, + AbsBatchSize = erlang:abs(BatchSize), + BatchSizeField = batch_size(AbsBatchSize =:= 0, AbsBatchSize), + SingleBatchField = single_batch(SingleBatch), + SortField = sort_field(OrderBy), + CommandDoc = [ + {<<"find">>, Coll}, + {<<"$readPreference">>, ReadPref}, + {<<"filter">>, NewSelector}, + {<<"projection">>, Projector}, + {<<"skip">>, Skip} + ] ++ SortField + ++ BatchSizeField + ++ SingleBatchField, + #op_msg_command{command_doc = CommandDoc}. + +batch_size(true, _BatchSize) -> []; +batch_size(false, BatchSize) -> [{<<"batchSize">>, BatchSize}]. + +single_batch(true) -> []; +single_batch(false = SingleBatch) -> [{<<"singleBatch">>, SingleBatch}]. + +sort_field(OrderBy) when is_map(OrderBy), map_size(OrderBy) =:= 0 -> []; +sort_field(OrderBy) -> [{<<"sort">>, OrderBy}]. + %% @deprecated %% @doc Count selected documents -spec count(pid(), collection(), selector()) -> integer(). @@ -358,7 +473,12 @@ ensure_index(Connection, Coll, IndexSpec) -> -spec ensure_index(pid(), colldb(), bson:document(), database()) -> ok | {error, any()}. ensure_index(Connection, Coll, IndexSpec, DB) -> - mc_connection_man:request_worker(Connection, #ensure_index{database = DB, collection = Coll, index_spec = IndexSpec}). + ensure_index(mc_utils:use_legacy_protocol(Connection), Connection, Coll, IndexSpec, DB). + +ensure_index(true, Connection, Coll, IndexSpec, DB) -> + mc_connection_man:request_worker(Connection, #ensure_index{database = DB, collection = Coll, index_spec = IndexSpec}); +ensure_index(false, Connection, Coll, IndexSpec, DB) -> + command(DB, Connection, {<<"createIndexes">>, Coll, <<"indexes">>, IndexSpec}). %% @doc Execute given MongoDB command and return its result. -spec command(pid(), selector()) -> {boolean(), map()} | {ok, cursor()}. @@ -374,7 +494,6 @@ command(Db, Connection, Command) -> command(Db, Connection, Command, IsSlaveOk) -> mc_connection_man:database_command(Connection, Db, Command, IsSlaveOk). - %% @private -spec prepare(tuple() | list() | map(), fun()) -> list(). prepare(Docs, AssignFun) when is_tuple(Docs) -> %bson diff --git a/src/api/mongo_api.erl b/src/api/mongo_api.erl index e6072e2c..71ede345 100644 --- a/src/api/mongo_api.erl +++ b/src/api/mongo_api.erl @@ -13,6 +13,8 @@ -include("mongoc.hrl"). -include("mongo_protocol.hrl"). +-type transaction_result(T) :: T | {error, term()}. + %% API -export([ connect/4, @@ -23,6 +25,7 @@ delete/3, count/4, ensure_index/3, + command/3, disconnect/1]). -spec connect(atom(), list(), proplists:proplist(), proplists:proplist()) -> {ok, pid()}. @@ -104,6 +107,14 @@ ensure_index(Topology, Coll, IndexSpec) -> mc_worker_api:ensure_index(Worker, Coll, IndexSpec) end, #{}). +-spec command(atom() | pid(), selector(), timeout()) -> transaction_result(integer()). +command(Topology, Command, Timeout) -> + mongoc:transaction_query(Topology, + fun(Conf = #{pool := Worker}) -> + Query = mongoc:command_query(Conf, Command), + mc_worker_api:command(Worker, Query) + end, #{}, Timeout). + -spec disconnect(atom() | pid()) -> ok. disconnect(Topology) -> - mongoc:disconnect(Topology). \ No newline at end of file + mongoc:disconnect(Topology). diff --git a/src/api/mongoc.erl b/src/api/mongoc.erl index 0d0c1a7c..28f9de77 100644 --- a/src/api/mongoc.erl +++ b/src/api/mongoc.erl @@ -21,8 +21,10 @@ transaction/4, status/1, append_read_preference/2, + extract_read_preference/1, find_query/6, count_query/4, + command_query/2, find_one_query/5]). @@ -60,19 +62,23 @@ transaction(Topology, Transaction, Options, Timeout) -> catch error:not_master -> mc_topology:update_topology(Topology), + error_logger:error_msg("transaction error:~p reason:~p~n", [error, not_master]), {error, not_master}; error:{bad_query, {not_master, _}} -> mc_topology:update_topology(Topology), + error_logger:error_msg("transaction error:~p reason:~p~n", [error, not_master]), {error, not_master}; - _:R -> + E:R -> mc_topology:update_topology(Topology), + error_logger:error_msg("transaction error:~p reason:~p~n", [E, R]), {error, R} end; Error -> + error_logger:error_msg("mc_topology get_pool error:~p~n", [Error]), Error end. -%% @doc Get worker from pool and run transaction with additioanl query options on it. Suitable for read transactions +%% @doc Get worker from pool and run transaction with additional query options on it. Suitable for read transactions -spec transaction_query(pid() | atom(), fun()) -> any(). transaction_query(Topology, Transaction) -> transaction_query(Topology, Transaction, #{}). @@ -83,12 +89,7 @@ transaction_query(Topology, Transaction, Options) -> -spec transaction_query(pid() | atom(), fun(), map(), integer() | infinity) -> any(). transaction_query(Topology, Transaction, Options, Timeout) -> - case mc_topology:get_pool(Topology, Options) of - {ok, Pool = #{pool := C}} -> - poolboy:transaction(C, fun(Worker) -> Transaction(Pool#{pool => Worker}) end, Timeout); - Error -> - Error - end. + transaction(Topology, Transaction, Options, Timeout). -spec find_one_query(map(), collection(), selector(), projector(), integer()) -> query(). find_one_query(#{server_type := ServerType, read_preference := RPrefs}, Coll, Selector, Projector, Skip) -> @@ -129,7 +130,15 @@ count_query(#{server_type := ServerType, read_preference := RPrefs}, Coll, Selec }, mongos_query_transform(ServerType, Q, RPrefs). --spec append_read_preference(selector(), map()) -> selector(). +-spec command_query(map(), selector()) -> query(). +command_query(#{server_type := ServerType, read_preference := RPrefs}, Command) -> + Q = #'query'{ + collection = <<"$cmd">>, + selector = Command + }, + mongos_query_transform(ServerType, Q, RPrefs). + +-spec append_read_preference(selector(), readpref()) -> selector(). append_read_preference(Selector = #{<<"$query">> := _}, RP) -> Selector#{<<"$readPreference">> => RP}; append_read_preference(Selector, RP) when is_tuple(Selector) andalso element(1, Selector) =:= <<"count">> -> @@ -140,6 +149,31 @@ append_read_preference(Selector, RP) -> #{<<"$query">> => Selector, <<"$readPreference">> => RP}. +extract_read_preference(#{<<"$readPreference">> := RP} = Selector) -> + {RP, + maps:get(<<"$query">>, Selector, #{}), + maps:get(<<"$orderby">>, Selector, #{})}; +extract_read_preference(Selector) when is_map(Selector) -> + {#{<<"mode">> => <<"primary">>}, + maps:get(<<"$query">>, Selector, Selector), + maps:get(<<"$orderby">>, Selector, #{})};%TODO also extract orderby and what else might be inside (strange but needed to pass test) +extract_read_preference(Selector) when is_tuple(Selector) -> + Fields = bson:fields(Selector), + Query = case lists:keyfind(<<"$query">>, 1, Fields) of + {_, Q} -> Q; + false -> Selector + end, + OrderBy = case lists:keyfind(<<"$orderby">>, 1, Fields) of + {_, OB} -> OB; + false -> #{} + end, + case lists:keyfind(<<"$readPreference">>, 1, Fields) of + {_, RP} -> + {RP, Query, OrderBy}; + false -> + {#{<<"mode">> => <<"primary">>}, Query, OrderBy} + end. + %%%=================================================================== %%% Internal functions %%%=================================================================== diff --git a/src/connection/mc_auth_logic.erl b/src/connection/mc_auth_logic.erl index 92eed8e8..d4440dcd 100644 --- a/src/connection/mc_auth_logic.erl +++ b/src/connection/mc_auth_logic.erl @@ -36,7 +36,7 @@ auth(Connection, _, Database, Login, Password) -> %old authorisation %% @private --spec mongodb_cr_auth(pid(), binary(), binary(), binary()) -> boolean(). +-spec mongodb_cr_auth(pid(), binary(), binary(), binary()) -> boolean() | no_return(). mongodb_cr_auth(Connection, Database, Login, Password) -> {true, Res} = mc_connection_man:database_command(Connection, Database, {<<"getnonce">>, 1}), Nonce = maps:get(<<"nonce">>, Res), @@ -46,7 +46,7 @@ mongodb_cr_auth(Connection, Database, Login, Password) -> end. %% @private --spec scram_sha_1_auth(port(), binary(), binary(), binary()) -> boolean(). +-spec scram_sha_1_auth(pid(), database(), binary(), binary()) -> boolean(). scram_sha_1_auth(Connection, Database, Login, Password) -> try scram_first_step(Connection, Database, Login, Password) @@ -124,9 +124,14 @@ generate_sig(SaltedPassword, AuthMessage) -> mc_utils:hmac(ServerKey, AuthMessage). %% @private +-if(?OTP_RELEASE >= 24). +hi(Password, Salt, Iterations) -> + crypto:pbkdf2_hmac(sha, Password, Salt, Iterations, 20). +-else. hi(Password, Salt, Iterations) -> {ok, Key} = pbkdf2:pbkdf2(sha, Password, Salt, Iterations, 20), Key. +-endif. %% @private xorKeys(<<>>, _, Res) -> Res; diff --git a/src/connection/mc_connection_man.erl b/src/connection/mc_connection_man.erl index 61e0e6a4..afdd580e 100644 --- a/src/connection/mc_connection_man.erl +++ b/src/connection/mc_connection_man.erl @@ -12,28 +12,42 @@ -include("mongo_types.hrl"). -include("mongo_protocol.hrl"). +-dialyzer({no_fail_call, query_to_op_msg_cmd/2}). + -define(NOT_MASTER_ERROR, 13435). -define(UNAUTHORIZED_ERROR(C), C =:= 10057; C =:= 16550). %% API --export([request_worker/2]). +-export([request_worker/2, process_reply/2]). -export([read/2, read/3, read_one/2]). --export([command/2, command/3, database_command/3, database_command/4]). +-export([op_msg/2, op_msg_read_one/2, op_msg_raw_result/2]). +-export([command/2, command/3, database_command/3, database_command/4, request_raw_no_parse/4]). -spec read(pid() | atom(), query()) -> [] | {ok, pid()}. read(Connection, Request) -> read(Connection, Request, undefined). --spec read(pid() | atom(), query(), undefined | mc_worker_api:batchsize()) -> [] | {ok, pid()}. +-spec read(pid() | atom(), query() | op_msg_command(), undefined | mc_worker_api:batchsize()) -> [] | {ok, pid()}. read(Connection, Request = #'query'{collection = Collection, batchsize = BatchSize, database = DB}, CmdBatchSize) -> + read(Connection, Request, Collection, select_batchsize(CmdBatchSize, BatchSize), DB); +read(Connection, Request = #'op_msg_command'{database = DB, command_doc = ([{_, Collection} | _ ] = Fields)}, + _CmdBatchSize) -> + BatchSize = case lists:keyfind(<<"batchSize">>, 1, Fields) of + {_, Size} -> Size; + false -> 101 + end, + read(Connection, Request, Collection, BatchSize, DB). + +read(Connection, Request, Collection, BatchSize, DB) -> case request_worker(Connection, Request) of {_, []} -> []; {Cursor, Batch} -> - mc_cursor:start_link(Connection, Collection, Cursor, select_batchsize(CmdBatchSize, BatchSize), Batch, DB) + mc_cursor:start_link(Connection, Collection, Cursor, BatchSize, Batch, DB); + X -> + erlang:error({error_unexpected_response, X}) end. - --spec read_one(pid() | atom(), query()) -> undefined | map(). +-spec read_one(pid() | atom(), request()) -> undefined | map(). read_one(Connection, Request) -> {0, Docs} = request_worker(Connection, Request#'query'{batchsize = -1}), case Docs of @@ -41,25 +55,52 @@ read_one(Connection, Request) -> [Doc | _] -> Doc end. +-spec command(pid(), mc_worker_api:selector()) -> {boolean(), map()}. command(Connection, Query = #query{selector = Cmd}) -> + QueryOrOpMsg = query_to_op_msg_cmd(mc_utils:use_legacy_protocol(Connection), Query), case determine_cursor(Cmd) of false -> - Doc = read_one(Connection, Query), - process_reply(Doc, Query); + legacy_command(mc_utils:use_legacy_protocol(Connection), Connection, QueryOrOpMsg); BatchSize -> - case read(Connection, Query#query{batchsize = -1}, BatchSize) of + case read(Connection, QueryOrOpMsg, BatchSize) of [] -> []; {ok, Cursor} when is_pid(Cursor) -> {ok, Cursor} end end; -command(Connection, Command) when not is_record(Command, query)-> +command(Connection, Command) when not is_record(Command, query) -> command(Connection, #'query'{ collection = <<"$cmd">>, selector = Command }). +legacy_command(true, Connection, Query) -> + Doc = read_one(Connection, Query), + process_reply(Doc, Query); +legacy_command(false, Connection, OpMsg) -> + {true, mc_connection_man:op_msg_raw_result(Connection, OpMsg)}. + +-spec query_to_op_msg_cmd(boolean(), mc_worker_api:selector() ) -> query() | op_msg_command(). +query_to_op_msg_cmd(true, Query) -> + Query#query{batchsize = -1}; +query_to_op_msg_cmd(false, Query) -> + #query{database = DB, slaveok = SlaveOk, selector = Selector} = Query, + Fields = bson:fields(Selector), + NewSelector = + case {lists:keyfind(<<"$readPreference">>, 1, Fields), SlaveOk} of + {{<<"$readPreference">>, _}, _} -> Selector; + {false, true} -> + bson:merge(Fields, {<<"$readPreference">>, #{<<"mode">> => <<"primaryPreferred">>}}); + {false, false} -> + %% primary is the default mode so we do not need to change anything + Fields + end, + #'op_msg_command'{ + database = DB, + command_doc = NewSelector + }. + command(Connection, Command, _IsSlaveOk = true) -> command(Connection, #'query'{ @@ -90,7 +131,7 @@ database_command(Connection, Database, Command, IsSlaveOk) -> }, IsSlaveOk). --spec request_worker(pid(), mongo_protocol:message()) -> ok | {non_neg_integer(), [map()]}. +-spec request_worker(pid(), mongo_protocol:message()) -> ok | {non_neg_integer(), [map()]} | map(). request_worker(Connection, Request) -> %request to worker Timeout = mc_utils:get_timeout(), reply(gen_server:call(Connection, Request, Timeout)). @@ -101,6 +142,64 @@ process_reply(Doc = #{<<"ok">> := N}, _) when is_number(N) -> %command succeed process_reply(Doc, Command) -> %unknown result erlang:error({bad_command, Doc}, [Command]). +op_msg(Connection, OpMsg) -> + Doc = request_worker(Connection, OpMsg), + process_reply(Doc, OpMsg). + +op_msg_read_one(Connection, OpMsg) -> + Timeout = mc_utils:get_timeout(), + Response = gen_server:call(Connection, OpMsg, Timeout), + case Response of + #op_msg_response{response_doc = + #{<<"ok">> := 1.0, + <<"cursor">>:= + #{<<"firstBatch">>:=[Doc], + <<"id">>:=0} + }} -> + Doc; + #op_msg_response{response_doc = + #{<<"ok">> := 1.0}} -> + undefined; + #op_msg_response{response_doc = Doc} -> + erlang:error({error, Doc}); + _ -> + erlang:error({error_unexpected_response, Response}) + end. + +op_msg_raw_result(Connection, OpMsg) -> + Timeout = mc_utils:get_timeout(), + FromServer = gen_server:call(Connection, OpMsg, Timeout), + case FromServer of + #op_msg_response{response_doc = + (#{<<"ok">> := 1.0} = Res)} -> + Res; + _ -> + erlang:error({error, FromServer}) + end. + +request_raw_no_parse(Socket, Database, Request, NetModule) -> + Timeout = mc_utils:get_timeout(), + ok = set_opts(Socket, NetModule, false), + {ok, _, _} = mc_worker_logic:make_request(Socket, NetModule, Database, Request), + Result = recv_all(Socket, Timeout, NetModule), + ok = set_opts(Socket, NetModule, true), + Result. + +%% @private +set_opts(Socket, ssl, Value) -> + ssl:setopts(Socket, [{active, Value}]); +set_opts(Socket, gen_tcp, Value) -> + inet:setopts(Socket, [{active, Value}]). + +%% @private +recv_all(Socket, Timeout, NetModule) -> + recv_all(Socket, Timeout, NetModule, <<>>). +recv_all(Socket, Timeout, NetModule, Rest) -> + {ok, Packet} = NetModule:recv(Socket, 0, Timeout), + case mc_worker_logic:decode_responses(<>) of + {[], Unfinished} -> recv_all(Socket, Timeout, NetModule, Unfinished); + {Responses, _} -> Responses + end. %% @private reply(ok) -> ok; @@ -112,7 +211,14 @@ reply(#reply{cursornotfound = false, queryerror = true} = Reply) -> reply(#reply{cursornotfound = true, queryerror = false} = Reply) -> erlang:error({bad_cursor, Reply#reply.cursorid}); reply({error, Error}) -> - process_error(error, Error). + process_error(error, Error); +reply(#op_msg_response{response_doc = (#{<<"cursor">> := #{<<"firstBatch">> := Batch, <<"id">> := Id}} = Doc)}) when + map_get(<<"ok">>, Doc) == 1 -> + {Id, Batch}; +reply(#op_msg_response{response_doc = Document}) when map_get(<<"ok">>, Document) == 1 -> + Document; +reply(Resp) -> + erlang:error({error_cannot_parse_response, Resp}). %% @private -spec process_error(atom() | integer(), term()) -> no_return(). diff --git a/src/connection/mc_cursor.erl b/src/connection/mc_cursor.erl index 8fecf0ab..4b48e23d 100644 --- a/src/connection/mc_cursor.erl +++ b/src/connection/mc_cursor.erl @@ -29,6 +29,8 @@ code_change/3 ]). +-dialyzer({no_match, handle_call/3}). + -record(state, { connection :: mc_worker:connection(), collection :: atom(), @@ -40,7 +42,7 @@ }). --spec next(pid()) -> error | {bson:document()}. +-spec next(pid()) -> error | {} | {bson:document()}. next(Cursor) -> next(Cursor, cursor_default_timeout()). @@ -51,33 +53,33 @@ next(Cursor, Timeout) -> exit:{noproc, _} -> error end. --spec next_batch(pid()) -> error | {bson:document()}. +-spec next_batch(pid()) -> error | {} | {bson:document()}. next_batch(Cursor) -> next_batch(Cursor, cursor_default_timeout()). --spec next_batch(pid(), timeout()) -> error | {} | {bson:document()}. +-spec next_batch(pid(),timeout()) -> error | {} | {bson:document()}. next_batch(Cursor, Timeout) -> try gen_server:call(Cursor, {next_batch, Timeout}, Timeout) catch exit:{noproc, _} -> error end. --spec rest(pid()) -> [bson:document()] | error. +-spec rest(pid()) -> [bson:document()] | {} | error. rest(Cursor) -> rest(Cursor, cursor_default_timeout()). --spec rest(pid(), timeout()) -> [bson:document()] | error. +-spec rest(pid(), timeout()) -> [bson:document()] | {} | error. rest(Cursor, Timeout) -> try gen_server:call(Cursor, {rest, infinity, Timeout}, Timeout) catch exit:{noproc, _} -> error end. --spec take(pid(), non_neg_integer()) -> [bson:document()] | error. +-spec take(pid(), non_neg_integer()) -> [bson:document()] | {} | error. take(Cursor, Limit) -> take(Cursor, Limit, cursor_default_timeout()). --spec take(pid(), non_neg_integer(), timeout()) -> [bson:document()] | error. +-spec take(pid(), non_neg_integer(), timeout()) -> [bson:document()] | {} | error. take(Cursor, Limit, Timeout) -> try gen_server:call(Cursor, {rest, Limit, Timeout}, Timeout) catch @@ -121,6 +123,7 @@ start_link(Connection, Collection, Cursor, BatchSize, Batch, DB) -> %% @hidden init([Owner, Connection, Collection, Cursor, BatchSize, Batch,DB]) -> + process_flag(trap_exit, true), Monitor = erlang:monitor(process, Owner), {ok, #state{ connection = Connection, @@ -169,8 +172,16 @@ handle_info(_, State) -> %% @hidden terminate(_, #state{cursor = 0}) -> ok; -terminate(_, State) -> - gen_server:call(State#state.connection, #killcursor{cursorids = [State#state.cursor]}). +terminate(_, #state{collection = Collection, connection = Connection, cursor = Cursor}) -> + terminate(mc_utils:use_legacy_protocol(Connection), Connection, Collection, Cursor). + +terminate(true, Connection, _Collection, Cursor) -> + gen_server:call(Connection, #killcursor{cursorids = [Cursor]}); +terminate(false, Connection, Collection, Cursor) -> + KillCursorCommand = + #op_msg_command{command_doc = [{<<"killCursors">>, Collection}, + {<<"cursors">>, [Cursor]}]}, + mc_connection_man:request_worker(Connection, KillCursorCommand). %% @hidden code_change(_Old, State, _Extra) -> @@ -181,19 +192,33 @@ next_i(#state{batch = [Doc | Rest]} = State, _Timeout) -> {{Doc}, State#state{batch = Rest}}; next_i(#state{batch = [], cursor = 0} = State, _Timeout) -> {{}, State}; -next_i(#state{batch = []} = State, Timeout) -> +next_i(#state{batch = [], connection = Connection} = State, Timeout) -> + next_i(mc_utils:use_legacy_protocol(Connection), State, Timeout). + +next_i(true, State, Timeout) -> Reply = gen_server:call( State#state.connection, #getmore{ collection = State#state.collection, batchsize = State#state.batchsize, - cursorid = State#state.cursor, - database = State#state.database + cursorid = State#state.cursor }, Timeout), Cursor = Reply#reply.cursorid, Batch = Reply#reply.documents, - next_i(State#state{cursor = Cursor, batch = Batch}, Timeout). + next_i(State#state{cursor = Cursor, batch = Batch}, Timeout); +next_i(false, State, Timeout) -> + GetMoreCommand = + #op_msg_command{command_doc = [{<<"getMore">>, State#state.cursor}, + {<<"collection">>, State#state.collection}, + {<<"batchSize">>, State#state.batchsize}]}, + Result = mc_connection_man:request_worker(State#state.connection, GetMoreCommand), + case Result of + #{<<"cursor">>:=#{<<"id">>:=NewCursorId,<<"nextBatch">>:=Batch},<<"ok">>:=1.0} -> + next_i(State#state{cursor = NewCursorId, batch = Batch}, Timeout); + _ -> + erlang:error({error_unexpected_cursor_result, Result}) + end. %% @private rest_i(State, infinity, Timeout) -> diff --git a/src/connection/mc_worker.erl b/src/connection/mc_worker.erl index 4bee3ce5..66286dd4 100644 --- a/src/connection/mc_worker.erl +++ b/src/connection/mc_worker.erl @@ -5,6 +5,7 @@ -define(WRITE(Req), is_record(Req, insert); is_record(Req, update); is_record(Req, delete)). -define(READ(Req), is_record(Request, 'query'); is_record(Request, getmore)). +-define(OP_MSG(Req), is_record(Request, 'op_msg_command'); is_record(Request, 'op_msg_write_op')). -define(LOG_DEFAULT_DB, fun() -> error_logger:info_msg("Using default 'test' database"), <<"test">> end). -export([start_link/1, database/2, disconnect/1, hibernate/1]). @@ -16,12 +17,14 @@ terminate/2, code_change/3]). +-dialyzer({no_fail_call, get_op_msg_write_concern/1}). + -record(state, { socket :: gen_tcp:socket() | ssl:sslsocket(), request_storage = #{} :: map(), buffer = <<>> :: binary(), conn_state :: conn_state(), - hibernate_timer :: reference() | undefined, + hibernate_timer :: timer:tref() | reference() | undefined, next_req_fun :: fun(), net_module :: ssl | gen_tcp }). @@ -53,12 +56,20 @@ init(Options) -> try_register(Options), NetModule = get_set_opts_module(Options), NextReqFun = mc_utils:get_value(next_req_fun, Options, fun() -> ok end), - proc_lib:init_ack({ok, self()}), - gen_server:enter_loop(?MODULE, [], - #state{socket = Socket, - conn_state = ConnState, - net_module = NetModule, - next_req_fun = NextReqFun}); + DefaultUseLegacyProtocol = application:get_env(mongodb, use_legacy_protocol, auto), + UseLegacyProtocol = mc_utils:get_value(use_legacy_protocol, Options, DefaultUseLegacyProtocol), + ProtoOpts = #{use_legacy_protocol => UseLegacyProtocol}, + case mc_worker_pid_info:install_mc_worker_info(Options, NetModule, ConnState#conn_state.database, ProtoOpts) of + ok -> + proc_lib:init_ack({ok, self()}), + gen_server:enter_loop(?MODULE, [], + #state{socket = Socket, + conn_state = ConnState, + net_module = NetModule, + next_req_fun = NextReqFun}); + {error, _} = Error -> + proc_lib:init_ack(Error) + end; Error -> proc_lib:init_ack(Error) end. @@ -76,6 +87,8 @@ handle_call(CMD = #ensure_index{collection = Coll, index_spec = IndexSpec}, _, S ConnState#conn_state.database, #insert{collection = mc_worker_logic:update_dbcoll(Coll, <<"system.indexes">>), documents = [Index]}), {reply, ok, State}; +handle_call(Request, From, State) when ?OP_MSG(Request) -> % MongoDB OpMsg request + process_op_msg_request(Request, From, State); handle_call(Request, From, State) when ?WRITE(Request) -> % write requests (deprecated) process_write_request(Request, From, State); handle_call(Request, From, State) when ?READ(Request) -> % read requests (and all through command) @@ -122,6 +135,39 @@ terminate(_, State = #state{net_module = NetModule}) -> code_change(_Old, State, _Extra) -> {ok, State}. +process_op_msg_request(Request, From, State) -> + #state{socket = Socket, + request_storage = RequestStorage, + conn_state = CS, + net_module = NetModule, + next_req_fun = Next} = State, + Database = CS#conn_state.database, + {ok, PacketSize, Id} = mc_worker_logic:make_request(Socket, NetModule, Database, Request), + UState = need_hibernate(PacketSize, State), + case get_op_msg_write_concern(Request) of + {_, {<<"w">>, 0}} -> %no concern request + Next(), + {reply, + #op_msg_response{response_doc = #{<<"ok">> => 1.0}}, + UState}; + _ -> %ordinary request with response + Next(), + RespFun = mc_worker_logic:get_resp_fun(Request, From), % save function, which will be called on response + URStorage = RequestStorage#{Id => RespFun}, + {noreply, UState#state{request_storage = URStorage}} + end. + +get_op_msg_write_concern(#op_msg_write_op{extra_fields = ExtraFields}) -> + case lists:keyfind(<<"writeConcern">>, 1, ExtraFields) of + {_, WC} -> WC; + _ -> not_found + end; +get_op_msg_write_concern(#op_msg_command{command_doc = DocList}) -> + case lists:keyfind(<<"writeConcern">>, 1, DocList) of + {_, WC} -> WC; + _ -> not_found + end. + %% @private process_read_request(Request, From, State) -> #state{socket = Socket, diff --git a/src/connection/mc_worker_logic.erl b/src/connection/mc_worker_logic.erl index 73447367..166e7ebd 100644 --- a/src/connection/mc_worker_logic.erl +++ b/src/connection/mc_worker_logic.erl @@ -15,6 +15,8 @@ -export([decode_responses/1, process_responses/2, connect/1]). -export([make_request/4, get_resp_fun/2, update_dbcoll/2, collection/1, ensure_index/3]). +-dialyzer({no_fail_call, ensure_index/3}). + %% Make connection to database and return socket -spec connect(proplists:proplist()) -> {ok, port()} | {error, inet:posix()}. connect(Conf) -> @@ -30,6 +32,7 @@ connect(Conf) -> -spec make_request(gen_tcp:socket() | ssl:sslsocket(), atom(), mc_worker_api:database(), mongo_protocol:message() | list(mongo_protocol:message())) -> {ok | {error, any()}, integer(), pos_integer()}. + make_request(Socket, NetModule, Database, Request) -> {Packet, Id} = encode_request(Database, Request), {NetModule:send(Socket, Packet), iolist_size(Packet), Id}. @@ -37,11 +40,14 @@ make_request(Socket, NetModule, Database, Request) -> decode_responses(Data) -> decode_responses(Data, []). --spec get_resp_fun(#query{} | #getmore{} | #insert{} | #update{} | #delete{}, pid()) -> fun(). +-spec get_resp_fun(#query{} | #getmore{} | #insert{} | #update{} | #delete{} | #op_msg_command{} | #op_msg_write_op{}, + pid()) -> fun(). get_resp_fun(Read, From) when is_record(Read, query); is_record(Read, getmore) -> fun(Response) -> gen_server:reply(From, Response) end; get_resp_fun(Write, From) when is_record(Write, insert); is_record(Write, update); is_record(Write, delete) -> - process_write_response(From). + process_write_response(From); +get_resp_fun(OpMsg, From) when is_record(OpMsg, op_msg_write_op); is_record(OpMsg, op_msg_command) -> + process_op_msg_response(From). -spec process_responses(Responses :: list(), RequestStorage :: map()) -> UpdStorage :: map(). process_responses(Responses, RequestStorage) -> @@ -125,6 +131,11 @@ process_write_response(From) -> end end. +process_op_msg_response(From) -> + fun(#op_msg_response{} = OpMsg) -> + gen_server:reply(From, OpMsg) + end. + %% @private do_connect(Host, Port, Timeout, true, Opts) -> {ok, _} = application:ensure_all_started(ssl), diff --git a/src/connection/mongo_protocol.erl b/src/connection/mongo_protocol.erl index 7c0d35e0..3ee70f5d 100644 --- a/src/connection/mongo_protocol.erl +++ b/src/connection/mongo_protocol.erl @@ -19,11 +19,12 @@ -type reply() :: #reply{}. % message id -type requestid() :: integer(). --type message() :: notice() | request(). +-type message() :: notice() | request() | #op_msg_command{} | #op_msg_write_op{}. % RequestId expected to be in scope at call site -define(put_header(Opcode), ?put_int32(_RequestId), ?put_int32(0), ?put_int32(Opcode)). -define(get_header(Opcode, ResponseTo), ?get_int32(_RequestId), ?get_int32(ResponseTo), ?get_int32(Opcode)). +-define(get_header_ignore_req_id(Opcode, ResponseTo), ?get_int32(_), ?get_int32(ResponseTo), ?get_int32(Opcode)). -define(ReplyOpcode, 1). -define(UpdateOpcode, 2001). @@ -32,7 +33,14 @@ -define(GetmoreOpcode, 2005). -define(DeleteOpcode, 2006). -define(KillcursorOpcode, 2007). +-define(OpMsgOpcode, 2013). +-define(OpMsgDbFieldIndex, 4). + +-define (put_uint32 (N), (N):32/unsigned-little). +-define (put_uint8 (N), (N):8/unsigned-little). +-define (get_uint32 (N), N:32/unsigned-little). +-define (get_uint8 (N), N:8/unsigned-little). -spec dbcoll(database(), colldb()) -> bson:utf8(). @@ -82,12 +90,56 @@ put_message(Db, #getmore{collection = Coll, batchsize = Batch, cursorid = Cid}, ?put_int32(0), (bson_binary:put_cstring(dbcoll(Db, Coll)))/binary, ?put_int32(Batch), - ?put_int64(Cid)>>. + ?put_int64(Cid)>>; +put_message(Db, #op_msg_write_op{} = OpMsg, _RequestId) -> + << + ?put_header(?OpMsgOpcode), + ?put_uint32(0), % Flags + (put_section_type_zero(OpMsg#op_msg_write_op{database = make_bin(Db)}))/binary + >>; +put_message(Db, #op_msg_command{} = OpMsg, _RequestId) -> + << + ?put_header(?OpMsgOpcode), + ?put_uint32(0), % Flags + (put_section_type_zero(OpMsg#op_msg_command{database = make_bin(Db)}))/binary + >>. + +make_bin(Atom) when is_atom(Atom) -> + erlang:atom_to_binary(Atom, utf8); +make_bin(Bin) -> + Bin. +put_section_type_zero(#op_msg_command{ + command_doc = Doc, + database = Database +}) -> + << + ?put_uint8(0), + (bson_binary:put_document(bson:merge(Doc, {<<"$db">>, Database})))/binary + >>; +put_section_type_zero(#op_msg_write_op{ + command = Command, + collection = Collection, + database = Database, + extra_fields = ExtraFields, + documents_name = DocumentsName, + documents = Documents +}) -> + Msg = [ + {erlang:atom_to_binary(Command), Collection}, + {<<"$db">>, Database} + ] ++ + ExtraFields + ++ + [{DocumentsName, Documents}], + << + ?put_uint8(0), + (bson_binary:put_document(bson:document(Msg)))/binary + >>. -spec get_reply(binary()) -> {requestid(), reply(), binary()}. -get_reply(Message) -> - <> = Message) -> + < startingfrom = StartingFrom, documents = Docs }, - {ResponseTo, Reply, BinRest}. + {ResponseTo, Reply, BinRest}; +get_reply(<> = Message) -> + <> = Message, + %% For now assume the sequence type is zero + <> = Bin1, + {Doc, Rest} = bson_binary:get_map(Bin2), + Reply = #op_msg_response{ + response_doc = Doc + }, + {ResponseTo, Reply, Rest}. -spec binarize(binary() | atom()) -> binary(). %@doc Ensures the given term is converted to a UTF-8 binary. diff --git a/src/main/mc_super_sup.erl b/src/main/mc_super_sup.erl index 91f3bce0..485e0491 100644 --- a/src/main/mc_super_sup.erl +++ b/src/main/mc_super_sup.erl @@ -17,4 +17,5 @@ start_link() -> init(app) -> MongoIdServer = ?CHILD(mongo_id_server, worker), McPoolSup = ?CHILD(mc_pool_sup, supervisor), - {ok, {{one_for_one, 1000, 3600}, [MongoIdServer, McPoolSup]}}. + PIDInfoServer = ?CHILD(mc_worker_pid_info, worker), + {ok, {{one_for_one, 1000, 3600}, [MongoIdServer, McPoolSup, PIDInfoServer]}}. diff --git a/src/main/mc_worker_pid_info.erl b/src/main/mc_worker_pid_info.erl new file mode 100644 index 00000000..91f21bfd --- /dev/null +++ b/src/main/mc_worker_pid_info.erl @@ -0,0 +1,136 @@ +-module(mc_worker_pid_info). + +%% This module is used to get information (currently only the protocol type) +%% ragaring mc_worker processes. This is useful so we can encode messages in +%% the right way befor sending them to an mc_worker process. + +-behaviour(gen_server). + +-include("mongo_protocol.hrl"). + +-export([start_link/0, + init/1, + terminate/2, + handle_cast/2, + handle_call/3, + get_info/1, + set_info/2, + discard_info/1, + get_protocol_type/1, + handle_info/2, + install_mc_worker_info/4]). + +-dialyzer({no_fail_call, detect_protocol_type/4}). + +-define(CLEAN_TABLE_PERIOD_MINS, 30). +-define(CLEAN_TABLE_MESSAGE, clean_table). +-define(MC_WORKER_PID_INFO_TAB_NAME, mc_worker_pid_info_tab). + +-spec start_link() -> {ok, pid()}. +start_link() -> + gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). + +init(_) -> + ets:new(?MC_WORKER_PID_INFO_TAB_NAME, [public, named_table, {read_concurrency, true}]), + {ok, start_cleanup_timer_update_state(#{})}. + +start_cleanup_timer_update_state(State) -> + State#{ + cleanup_timer_ref => + erlang:start_timer(timer:minutes(?CLEAN_TABLE_PERIOD_MINS), + self(), + ?CLEAN_TABLE_MESSAGE, + []) + }. + +terminate(_,_) -> + ets:delete(?MC_WORKER_PID_INFO_TAB_NAME). + +%% These functions does not do anything as this server is just a holder of an ETS table +handle_cast(_Request, State) -> {noreply, State}. +handle_call(_Request, _From, State) -> {reply, {error, ignore}, State}. + +handle_info({timeout, + TimerRef, + ?CLEAN_TABLE_MESSAGE}, + #{cleanup_timer_ref := TimerRef} = State) -> + PidInfos = ets:tab2list(?MC_WORKER_PID_INFO_TAB_NAME), + [delete_pid_if_dead(Pid) || {Pid, _} <- PidInfos], + {noreply, start_cleanup_timer_update_state(State)}; +handle_info(_, State) -> + {noreply, State}. + +delete_pid_if_dead(Pid) -> + case erlang:is_process_alive(Pid) of + true -> ok; + false -> ets:delete(?MC_WORKER_PID_INFO_TAB_NAME, Pid) + end. + +get_info(MCWorkerPID) -> + try + case ets:lookup(?MC_WORKER_PID_INFO_TAB_NAME, MCWorkerPID) of + [{MCWorkerPID, InfoMap}] -> + {ok, InfoMap}; + [] -> + not_found + end + catch + _:_:_ -> + not_found + end. + + +set_info(MCWorkerPID, InfoMap) -> + ets:insert(?MC_WORKER_PID_INFO_TAB_NAME, {MCWorkerPID, InfoMap}). + +discard_info(MCWorkerPID) -> + ets:delete(?MC_WORKER_PID_INFO_TAB_NAME, MCWorkerPID). + +get_protocol_type(MCWorkerPID) -> + case get_info(MCWorkerPID) of + {ok, #{protocol_type := ProtocolType}} -> + ProtocolType; + _ -> + %% Not found means that this library has been hot upgraded and the + %% mc_worker process was created before the hot_upgrade so we use + %% the legacy protocol as this was what existed before + legacy + end. + +%% This process should be called from mc_worker processes to install their info +%% in the ?MC_WORKER_PID_INFO_TAB_NAME ETS table +install_mc_worker_info(Options, NetModule, Database, Opts) -> + UseLegacyProtocol = maps:get(use_legacy_protocol, Opts), + try + ProtocolType = detect_protocol_type(UseLegacyProtocol, Options, NetModule, Database), + mc_worker_pid_info:set_info(self(), #{protocol_type => ProtocolType}), + ok + catch + What:Reason -> + {error, {What, Reason}} + end. + +%% true - use the legacy message protocol +%% false - use the modern message protocol based on the op_msg opcode +%% auto - detect which message protocol to use based on the database +detect_protocol_type(true, _ConnectionOpts, _NetModule, _Database) -> legacy; +detect_protocol_type(false, _ConnectionOpts, _NetModule, _Database) -> op_msg; +detect_protocol_type(auto, ConnectionOpts, NetModule, Database) -> + %% Automatically detect which protocol to use. We send a `hello' command using + %% the new protocol. If we have our connection closed by the server, it means + %% it doesn't support the new protocol. + %% See also: + %% * https://github.com/mongodb/mongo/blob/5e494138af456f42381ad08748cc7fbc4ace7a60/src/mongo/base/error_codes.yml + %% * https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-dbcommand-dbcmd.hello + %% * https://www.mongodb.com/docs/manual/reference/mongodb-wire-protocol/#std-label-wire-op-msg + {ok, Socket} = mc_worker_logic:connect(ConnectionOpts), + Command = bson:fields({hello, 1}), + Request = #op_msg_command{command_doc = Command, database = Database}, + try mc_connection_man:request_raw_no_parse(Socket, Database, Request, NetModule) of + [{_, #op_msg_response{response_doc = #{<<"ok">> := _}}} | _] -> op_msg; + _ErrorResponse -> legacy + catch + _:_ -> legacy + after + NetModule:close(Socket) + end. diff --git a/src/mongoc/mc_monitor.erl b/src/mongoc/mc_monitor.erl index 5324e59b..0b535323 100644 --- a/src/mongoc/mc_monitor.erl +++ b/src/mongoc/mc_monitor.erl @@ -153,11 +153,16 @@ maybe_recheck(_, Topology, Server, ConnectArgs, HB_MS, MinHB_MS) -> check(ConnectArgs, Server) -> Start = os:timestamp(), {ok, Conn} = mc_worker_api:connect(ConnectArgs), - {true, IsMaster} = mc_worker_api:command(Conn, {isMaster, 1}), + {true, IsMaster} = is_master_check(mc_utils:use_legacy_protocol(Conn), Conn), Finish = os:timestamp(), mc_worker_api:disconnect(Conn), {monitor_ismaster, Server, IsMaster, timer:now_diff(Finish, Start)}. +is_master_check(true, Connection) -> + mc_worker_api:command(Connection, {isMaster, 1}); +is_master_check(false, Connection) -> + mc_worker_api:command(Connection, {hello, 1}). + %% @private do_timeout(Pid, TO) when TO > 0 -> receive diff --git a/src/mongoc/mc_topology.erl b/src/mongoc/mc_topology.erl index 561a3510..5f52b1c6 100644 --- a/src/mongoc/mc_topology.erl +++ b/src/mongoc/mc_topology.erl @@ -218,13 +218,17 @@ parse_ismaster(Server, IsMaster, RTT, State = #topology_state{servers = Tab}) -> arbiters = maps:get(<<"arbiters">>, IsMaster, []), electionId = maps:get(<<"electionId">>, IsMaster, undefined), primary = maps:get(<<"primary">>, IsMaster, undefined), - ismaster = maps:get(<<"ismaster">>, IsMaster, undefined), + ismaster = is_master(IsMaster), secondary = maps:get(<<"secondary">>, IsMaster, undefined) }, ets:insert(Tab, ToUpdate), mc_server:update_ismaster(ToUpdate#mc_server.pid, {SType, ToUpdate}), mc_topology_logics:update_topology_state(ToUpdate, State). +is_master(#{<<"ismaster">> := V} = _IsMaster) -> V; +is_master(#{<<"isWritablePrimary">> := V} = _IsMaster) -> V; +is_master(_) -> undefined. + %% @private parse_rtt(_, undefined, RTT) -> {RTT, RTT}; parse_rtt(OldRTT, CurRTT, RTT) -> @@ -234,8 +238,12 @@ parse_rtt(OldRTT, CurRTT, RTT) -> %% @private server_type(#{<<"ismaster">> := true, <<"secondary">> := false, <<"setName">> := _}) -> rsPrimary; +server_type(#{<<"isWritablePrimary">> := true, <<"secondary">> := false, <<"setName">> := _}) -> + rsPrimary; server_type(#{<<"ismaster">> := false, <<"secondary">> := true, <<"setName">> := _}) -> rsSecondary; +server_type(#{<<"isWritablePrimary">> := false, <<"secondary">> := true, <<"setName">> := _}) -> + rsSecondary; server_type(#{<<"arbiterOnly">> := true, <<"setName">> := _}) -> rsArbiter; server_type(#{<<"hidden">> := true, <<"setName">> := _}) -> @@ -246,6 +254,8 @@ server_type(#{<<"msg">> := <<"isdbgrid">>}) -> mongos; server_type(#{<<"isreplicaset">> := true}) -> rsGhost; +server_type(#{<<"isWritablePrimary">> := true}) -> + standalone; server_type(#{<<"ok">> := _}) -> unknown; server_type(_) -> diff --git a/src/support/mc_utils.erl b/src/support/mc_utils.erl index f3993489..6fad93ec 100644 --- a/src/support/mc_utils.erl +++ b/src/support/mc_utils.erl @@ -31,7 +31,9 @@ hmac/2, is_proplist/1, to_binary/1, - get_srv_seeds/1]). + get_srv_seeds/1, + use_legacy_protocol/1, + get_connection_pid/1]). get_value(Key, List) -> get_value(Key, List, undefined). @@ -122,3 +124,18 @@ valid_endpoint(Host, Srv) -> [_ | HostBaseDomain] = string:split(Host, "."), [_ | SrvBaseDomain] = string:split(Srv, "."), HostBaseDomain == SrvBaseDomain. + +use_legacy_protocol(Connection) -> + %% Latest MongoDB version that supported the non-op-msg based opcodes was + %% 5.0.x (at the time of writing 5.0.14). The non-op-msg based opcodes were + %% removed in MongoDB version 5.1.0. See + %% https://www.mongodb.com/docs/manual/legacy-opcodes/ + case mc_worker_pid_info:get_protocol_type(Connection) of + legacy -> true; + op_msg -> false + end. + +get_connection_pid(Connection) when is_pid(Connection) -> + Connection; +get_connection_pid(#{connection_pid := Pid}) -> + Pid. diff --git a/test/mc_test_utils.erl b/test/mc_test_utils.erl index dce30b91..fcd311f2 100644 --- a/test/mc_test_utils.erl +++ b/test/mc_test_utils.erl @@ -10,14 +10,15 @@ -author("tihon"). %% API --export([collection/1]). +-export([collection/2]). -collection(Case) -> +collection(Mod, Case) -> Now = now_to_seconds(os:timestamp()), <<(atom_to_binary(?MODULE, utf8))/binary, $-, + (atom_to_binary(Mod, utf8))/binary, $-, (atom_to_binary(Case, utf8))/binary, $-, (list_to_binary(integer_to_list(Now)))/binary>>. %% @private now_to_seconds({Mega, Sec, _}) -> - (Mega * 1000000) + Sec. \ No newline at end of file + (Mega * 1000000) + Sec. diff --git a/test/mc_worker_api_SUITE.erl b/test/mc_worker_api_SUITE.erl index 4e607900..0befd43f 100644 --- a/test/mc_worker_api_SUITE.erl +++ b/test/mc_worker_api_SUITE.erl @@ -6,7 +6,7 @@ -include("mongo_protocol.hrl"). --compile(export_all). +-compile([export_all, nowarn_export_all]). all() -> [ @@ -31,7 +31,7 @@ end_per_suite(_Config) -> init_per_testcase(Case, Config) -> {ok, Connection} = mc_worker_api:connect([{database, ?config(database, Config)}, {login, <<"user">>}, {password, <<"test">>}, {w_mode, safe}]), - [{connection, Connection}, {collection, mc_test_utils:collection(Case)} | Config]. + [{connection, Connection}, {collection, mc_test_utils:collection(?MODULE, Case)} | Config]. end_per_testcase(_Case, Config) -> Connection = ?config(connection, Config), @@ -65,7 +65,6 @@ insert_and_find(Config) -> 4 = mc_worker_api:count(Connection, Collection, #{}), {ok, TeamsCur} = mc_worker_api:find(Connection, Collection, #{}), TeamsFound = mc_cursor:rest(TeamsCur), - undefined = process_info(TeamsCur), ?assertEqual(Teams, TeamsFound), {ok, NationalTeamsCur} = mc_worker_api:find( @@ -134,6 +133,12 @@ insert_and_delete(Config) -> mc_worker_api:delete_one(Connection, Collection, #{}), 3 = mc_worker_api:count(Connection, Collection, #{}), + + mc_worker_api:delete_limit(Connection, Collection, #{}, 1), + 2 = mc_worker_api:count(Connection, Collection, #{}), + + mc_worker_api:delete(Connection, Collection, #{}), + 0 = mc_worker_api:count(Connection, Collection, #{}), Config. insert_map(Config) -> diff --git a/test/mongo_api_SUITE.erl b/test/mongo_api_SUITE.erl index 7bc77ca8..d257ae5d 100644 --- a/test/mongo_api_SUITE.erl +++ b/test/mongo_api_SUITE.erl @@ -4,11 +4,12 @@ -include_lib("common_test/include/ct.hrl"). -include_lib("eunit/include/eunit.hrl"). --compile(export_all). +-compile([export_all, nowarn_export_all]). all() -> [ ensure_index_test, + per_connection_protocol_type_test, count_test, find_one_test, find_test, @@ -25,7 +26,7 @@ end_per_suite(_Config) -> init_per_testcase(Case, Config) -> {ok, Pid} = mongo_api:connect(single, ["localhost:27017"], [{pool_size, 1}, {max_overflow, 0}], [{database, ?config(database, Config)}, {login, <<"user">>}, {password, <<"test">>}]), - [{connection, Pid}, {collection, mc_test_utils:collection(Case)} | Config]. + [{connection, Pid}, {collection, mc_test_utils:collection(?MODULE, Case)} | Config]. end_per_testcase(_Case, Config) -> Connection = ?config(connection, Config), @@ -33,6 +34,15 @@ end_per_testcase(_Case, Config) -> mongo_api:delete(Connection, Collection, #{}), mongo_api:disconnect(Connection). +parse_mongo_version(StrVersion) -> + Segments = string:split(StrVersion, ".", all), + case lists:map(fun list_to_integer/1, Segments) of + [Major, Minor] -> + {Major, Minor, 0}; + [Major, Minor, Patch] -> + {Major, Minor, Patch} + end. + %% Tests ensure_index_test(Config) -> Pid = ?config(connection, Config), @@ -41,6 +51,38 @@ ensure_index_test(Config) -> ok = mongo_api:ensure_index(Pid, Collection, {<<"key">>, {<<"z_first">>, 1, <<"a_last">>, 1}}), Config. +%% regardless of application env, we can set the protocol type per connection +per_connection_protocol_type_test(Config) -> + {ok, MCWorkerConnection0} = + mc_worker:start_link([{database, ?config(database, Config)} + , {w_mode, safe} + , {use_legacy_protocol, true} + ]), + ?assert(mc_utils:use_legacy_protocol(MCWorkerConnection0)), + MRef0 = monitor(process, MCWorkerConnection0), + mc_worker:disconnect(MCWorkerConnection0), + receive + {'DOWN', MRef0, process, MCWorkerConnection0, _} -> + ok + after + 1_000 -> ct:fail("worker didn't halt") + end, + {ok, MCWorkerConnection1} = + mc_worker:start_link([{database, ?config(database, Config)} + , {w_mode, safe} + , {use_legacy_protocol, false} + ]), + ?assertNot(mc_utils:use_legacy_protocol(MCWorkerConnection1)), + MRef1 = monitor(process, MCWorkerConnection1), + mc_worker:disconnect(MCWorkerConnection1), + receive + {'DOWN', MRef1, process, MCWorkerConnection1, _} -> + ok + after + 1_000 -> ct:fail("worker didn't halt") + end, + ok. + count_test(Config) -> Collection = ?config(collection, Config), Pid = ?config(connection, Config), diff --git a/test/switch_db_SUITE.erl b/test/switch_db_SUITE.erl index db0d2abb..9f30ed29 100644 --- a/test/switch_db_SUITE.erl +++ b/test/switch_db_SUITE.erl @@ -33,7 +33,7 @@ end_per_suite(_Config) -> init_per_testcase(Case, Config) -> {ok, Connection} = mc_worker_api:connect([{database, ?config(database, Config)}, {login, <<"user">>}, {password, <<"test">>}, {w_mode, safe}]), - [{connection, Connection}, {collection, mc_test_utils:collection(Case)} | Config]. + [{connection, Connection}, {collection, mc_test_utils:collection(?MODULE, Case)} | Config]. end_per_testcase(_Case, Config) -> Connection = ?config(connection, Config),