diff --git a/.github/workflows/dialyzer.yml b/.github/workflows/dialyzer.yml index c821eee0..8049b0d2 100644 --- a/.github/workflows/dialyzer.yml +++ b/.github/workflows/dialyzer.yml @@ -7,14 +7,14 @@ on: jobs: dialyzer: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 container: image: erlang:24-slim steps: - - uses: actions/checkout@v2.0.0 + - uses: actions/checkout@v4 - name: Cache PLTs id: cache-plts - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ~/.cache/rebar3/ key: ${{ runner.os }}-erlang-${{ hashFiles(format('{0}{1}', github.workspace, '/rebar.lock')) }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ff030ac5..785e9abf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,37 +7,37 @@ on: jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: - erlang: [22, 23] + erlang: [24, 25] mongodb: ["4.4.8", "5.0.2"] container: image: erlang:${{ matrix.erlang }} steps: - - uses: actions/checkout@v2.0.0 + - uses: actions/checkout@v4 - run: ./scripts/install_mongo_debian.sh ${{ matrix.mongodb }} - run: ./scripts/start_mongo_single_node.sh - run: ./scripts/start_mongo_cluster.sh - run: ./rebar3 eunit - run: ./rebar3 ct - name: Archive Replica Set Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: - name: mongodb_replica_set_logs + name: erlang-${{ matrix.erlang }}-mongodb-${{ matrix.mongodb }}-mongodb_replica_set_logs path: rs0-logs retention-days: 1 - name: Archive Single Node Log - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: - name: single_node.log + name: erlang-${{ matrix.erlang }}-mongodb-${{ matrix.mongodb }}-single_node.log path: single_node.log retention-days: 1 - name: CT Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: - name: ct_logs + name: erlang-${{ matrix.erlang }}-mongodb-${{ matrix.mongodb }}-ct_logs path: _build/test/logs/ retention-days: 5 diff --git a/.github/workflows/test_coverage.yml b/.github/workflows/test_coverage.yml index 76ee41d0..449a7b92 100644 --- a/.github/workflows/test_coverage.yml +++ b/.github/workflows/test_coverage.yml @@ -7,11 +7,11 @@ on: jobs: test_coverage: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 container: - image: erlang:23 + image: erlang:24 steps: - - uses: actions/checkout@v2.0.0 + - uses: actions/checkout@v4 - run: ./scripts/install_mongo_debian.sh 5.0.2 - run: ./scripts/start_mongo_single_node.sh - run: ./scripts/start_mongo_cluster.sh @@ -19,27 +19,27 @@ jobs: - run: ./rebar3 ct --cover --cover_export_name ct.coverdata - run: rebar3 cover --verbose - name: Archive Replica Set Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: name: mongodb_replica_set_logs path: rs0-logs retention-days: 1 - name: Archive Single Node Log - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: failure() with: name: single_node.log path: single_node.log retention-days: 1 - name: Coverage Report - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: Coverage Report path: _build/test/cover/ retention-days: 5 - name: CT Logs - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: ct_logs path: _build/test/logs/ diff --git a/.gitignore b/.gitignore index 77179f06..95249589 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,8 @@ variables-ct* *.iml data _build -.idea \ No newline at end of file +.idea +*.log +rebar.lock +*.crashdump +*.plt diff --git a/Makefile b/Makefile index d62b6a1a..14ce2ce0 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ report-rebar3-version: # Dialyzer. .$(PROJECT).plt: - @$(DIALYZER) --build_plt --output_plt .$(PROJECT).plt -r deps \ + @$(DIALYZER) --build_plt --output_plt .$(PROJECT).plt -r _build/default/lib/ \ --apps erts kernel stdlib sasl inets crypto public_key ssl mnesia syntax_tools asn1 clean-plt: @@ -48,8 +48,8 @@ clean-plt: build-plt: clean-plt .$(PROJECT).plt -dialyze: .$(PROJECT).plt - @$(DIALYZER) -I include -I deps --src -r src --plt .$(PROJECT).plt --no_native \ - -Werror_handling -Wrace_conditions -Wunmatched_returns +dialyzer: .$(PROJECT).plt + @$(DIALYZER) -I include -I _build/default/lib/ --src -r src --plt .$(PROJECT).plt --no_native \ + -Werror_handling -Wrace_conditions -Wunmatched_returns --get_warnings -.PHONY: deps clean-plt build-plt dialyze report-rebar3-version +.PHONY: deps clean-plt build-plt dialyzer report-rebar3-version diff --git a/README.md b/README.md index cc705573..4154f82d 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,30 @@ If you are choosing between using [mongos](https://docs.mongodb.com/manual/reference/program/mongos/) and using mongo shard with `mongo_api` - prefer mongos and use `mc_worker_api`. +In MongoDB 3.6, the `OP_MSG` and `OP_COMPRESSED` opcodes were added to the wired +message protocol. `OP_MSG` was added to standardize the MongoDB message format, +and `OP_COMPRESSED` takes things a step further by compressing the message to increase +network efficiency. As of MongoDB 5.1, the old opcodes were deprecated, and all +messages are sent with either `OP_MSG` or `OP_COMPRESSED`. + +By default, this driver tries to automatically detect which MongoDB messaging protocol +to use - the legacy one that existed before the `OP_MSG` opcode was introduced, or the +modern messaging protocol based on `OP_MSG`. This is accomplished by setting the default +value for the `use_legacy_protocol` env variable to `auto`. One can force the driver to +use the legacy opcodes or the `OP_MSG`opcode by setting the application env +`use_legacy_protocol` to `true` or `false` (for example by calling +`application:set_env(mongodb, use_legacy_protocol, false)`). + +It's also possible to define the usage of legacy protocol on a per-connection basis. +Simple pass `{use_legacy_protocol, false | true}` to mc_worker start options. + +As the `OP_MSG` opcode has existed in MongoDB since 3.6, and the driver defaults to auto- +detecting which set of opcodes to use, ALL of the messages will be sent using the `OP_MSG` +opcode if you are using a version for MongoDB greater 3.6 unless you explicitly set the +`use_legacy_protocol` variable to `true`. + mc_worker_api -- direct connection client --------------------------------- - ### Connecting To connect to a database `test` on mongodb server listening on `localhost:27017` (or any address & port of your choosing) diff --git a/include/mongo_protocol.hrl b/include/mongo_protocol.hrl index 45be0124..e4499893 100644 --- a/include/mongo_protocol.hrl +++ b/include/mongo_protocol.hrl @@ -8,6 +8,7 @@ -type colldb() :: collection() | {database(), collection()}. -type collection() :: binary() | atom(). % without db prefix -type database() :: binary() | atom(). +-type command() :: insert | update | delete. %% write @@ -45,6 +46,24 @@ projector = #{} :: mc_worker_api:projector() }). +-record(op_msg_write_op, { + command :: command(), + collection :: colldb(), + database :: undefined | mc_worker_api:database(), + extra_fields = [] :: bson:document() | nonempty_list({binary(),any()}), + documents_name = <<"documents">> :: bson:utf8(), + documents = [] :: any() +}). + +-record(op_msg_response, { + response_doc :: map() +}). + +-record(op_msg_command, { + database :: undefined | mc_worker_api:database(), + command_doc :: bson:document() | nonempty_list({binary(),any()}) +}). + -record(getmore, { collection :: colldb(), batchsize = 0 :: mc_worker_api:batchsize(), @@ -73,9 +92,9 @@ -record(reply, { cursornotfound :: boolean(), queryerror :: boolean(), - awaitcapable = false :: boolean(), + awaitcapable = false :: boolean() | undefined, cursorid :: mc_worker_api:cursorid(), - startingfrom = 0 :: integer(), + startingfrom = 0 :: integer() | undefined, documents :: [map()] }). -endif. diff --git a/include/mongo_types.hrl b/include/mongo_types.hrl index e2cab5ea..ffe0b361 100644 --- a/include/mongo_types.hrl +++ b/include/mongo_types.hrl @@ -31,11 +31,24 @@ | {ssl, boolean()} | {ssl_opts, proplists:proplist()} | {register, atom() | fun()}. +-type socket() :: gen_tcp:socket() | ssl:sslsocket(). -type write_mode() :: unsafe | safe | {safe, bson:document()}. -type read_mode() :: master | slave_ok. -type service() :: {Host :: inet:hostname() | inet:ip_address(), Post :: 0..65535}. -type options() :: [option()]. -type option() :: {timeout, timeout()} | {ssl, boolean()} | ssl | {database, database()} | {read_mode, read_mode()} | {write_mode, write_mode()}. -type cursor() :: pid(). --type query() :: #query{}. --endif. \ No newline at end of file +-type query() :: #'query'{}. +-type op_msg_command() :: #op_msg_command{}. +-type op_msg_write_op() :: #op_msg_write_op{}. +-type op_msg_response() :: #op_msg_response{}. +-type request() :: query() +| op_msg_command() +| op_msg_write_op() +| #killcursor{} +| #insert{} +| #update{} +| #delete{} +| #getmore{} +| #ensure_index{}. +-endif. diff --git a/include/mongoc.hrl b/include/mongoc.hrl index 148ad2b9..63b669f9 100644 --- a/include/mongoc.hrl +++ b/include/mongoc.hrl @@ -76,8 +76,8 @@ | {password, binary()} | {w_mode, mc_worker_api:write_mode()}. -type readprefs() :: [readpref()]. --type readpref() :: #{rp_mode => readmode()} +-type readpref() :: #{rp_mode => readmode()} | #{mode => binary()} | #{mode => binary(), tags => tuple()} |{rp_tags, [tuple()]}. -type reason() :: atom(). --endif. \ No newline at end of file +-endif. diff --git a/src/api/mc_worker_api.erl b/src/api/mc_worker_api.erl index 74443956..16922520 100644 --- a/src/api/mc_worker_api.erl +++ b/src/api/mc_worker_api.erl @@ -107,10 +107,23 @@ insert(Connection, Coll, Doc, WC, DB) when is_tuple(Doc); is_map(Doc) -> {Res, [UDoc | _]} = insert(Connection, Coll, [Doc], WC, DB), {Res, UDoc}; insert(Connection, Coll, Docs, WC, DB) -> - Converted = prepare(Docs, fun assign_id/1), - {command(DB, Connection, {<<"insert">>, Coll, <<"documents">>, Converted, <<"writeConcern">>, WC}), Converted}. - -%% @doc Insert one document or multiple documents into a colleciton. + ConvertedDocs = prepare(Docs, fun assign_id/1), + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + insert(UseLegacyProtocol, Connection, Coll, WC, DB, ConvertedDocs). + +insert(true, Connection, Coll, WriteConcern, DB, ConvertedDocs) -> + Command = + command(DB, Connection, {<<"insert">>, Coll, <<"documents">>, ConvertedDocs, <<"writeConcern">>, WriteConcern}), + {Command, ConvertedDocs}; +insert(false, Connection, Coll, WriteConcern, _DB, ConvertedDocs) -> + Msg = #op_msg_write_op{ + command = insert, + collection = Coll, + extra_fields = [{<<"writeConcern">>, WriteConcern}], + documents = ConvertedDocs}, + {mc_connection_man:op_msg(Connection, Msg), ConvertedDocs}. + +%% @doc Insert one document or multiple documents into a collection. %% params: %% connection - mc_worker pid %% collection - collection() @@ -132,9 +145,7 @@ update(Connection, Coll, Selector, Doc) -> %% @doc Replace the document matching criteria entirely with the new Document. -spec update(pid(), collection(), selector(), map() | bson:document(), boolean(), boolean()) -> {boolean(), map()}. update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate) -> - Converted = prepare(Doc, fun(D) -> D end), - command(Connection, {<<"update">>, Coll, <<"updates">>, - [#{<<"q">> => Selector, <<"u">> => Converted, <<"upsert">> => Upsert, <<"multi">> => MultiUpdate}]}). + update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, {<<"w">>, 1}). %% @deprecated %% @doc Replace the document matching criteria entirely with the new Document. @@ -145,11 +156,26 @@ update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, WC) -> %% @deprecated %% @doc Replace the document matching criteria entirely with the new Document. -spec update(pid(), collection(), selector(), map() | bson:document(), boolean(), boolean(), bson:document(), database()) -> {boolean(), map()}. -update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, WC, DB) -> - Converted = prepare(Doc, fun(D) -> D end), - command(DB, Connection, {<<"update">>, Coll, <<"updates">>, - [#{<<"q">> => Selector, <<"u">> => Converted, <<"upsert">> => Upsert, <<"multi">> => MultiUpdate}], - <<"writeConcern">>, WC}). +update(Connection, Coll, Selector, Doc, Upsert, MultiUpdate, WriteConcern, DB) -> + ConvertedDocs = prepare(Doc, fun(D) -> D end), + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + update(UseLegacyProtocol, Connection, Coll, Selector, ConvertedDocs, Upsert, MultiUpdate, WriteConcern, DB). + +update(true, Connection, Coll, Selector, ConvertedDocs, Upsert, MultiUpdate, WriteConcern, DB) -> + command(DB, Connection, {<<"update">>, Coll, <<"updates">>, [#{<<"q">> => Selector, <<"u">> => ConvertedDocs, + <<"upsert">> => Upsert, <<"multi">> => MultiUpdate}], <<"writeConcern">>, WriteConcern}); +update(false, Connection, Coll, Selector, ConvertedDocs, Upsert, MultiUpdate, WriteConcern, _DB) -> + Msg = #op_msg_write_op{ + command = update, + collection = Coll, + extra_fields = [{<<"writeConcern">>, WriteConcern}], + documents_name = <<"updates">>, + documents = [#{ + <<"q">> => Selector, + <<"u">> => ConvertedDocs, + <<"upsert">> => Upsert, + <<"multi">> => MultiUpdate}]}, + mc_connection_man:op_msg(Connection, Msg). %% @doc Replace the document matching criteria entirely with the new Document. %% params: @@ -183,22 +209,35 @@ delete_one(Connection, Coll, Selector) -> %% @deprecated %% @doc Delete selected documents -spec delete_limit(pid(), collection(), selector(), integer()) -> {boolean(), map()}. -delete_limit(Connection, Coll, Selector, N) -> - command(Connection, {<<"delete">>, Coll, <<"deletes">>, - [#{<<"q">> => Selector, <<"limit">> => N}]}). +delete_limit(Connection, Coll, Selector, Limit) -> + delete_limit(Connection, Coll, Selector, Limit, undefined). %% @deprecated %% @doc Delete selected documents --spec delete_limit(pid(), collection(), selector(), integer(), bson:document()) -> {boolean(), map()}. -delete_limit(Connection, Coll, Selector, N, WC) -> - delete_limit(Connection, Coll, Selector, N, WC, undefined). +-spec delete_limit(pid(), collection(), selector(), integer(), bson:document() | undefined) -> {boolean(), map()}. +delete_limit(Connection, Coll, Selector, Limit, undefined) -> + delete_limit(Connection, Coll, Selector, Limit, {<<"w">>, 1}, undefined); +delete_limit(Connection, Coll, Selector, Limit, WriteConcern) -> + delete_limit(Connection, Coll, Selector, Limit, WriteConcern, undefined). %% @deprecated %% @doc Delete selected documents -spec delete_limit(pid(), collection(), selector(), integer(), bson:document(), database()) -> {boolean(), map()}. -delete_limit(Connection, Coll, Selector, N, WC, DB) -> +delete_limit(Connection, Coll, Selector, Limit, WriteConcern, DB) -> + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + delete_limit(UseLegacyProtocol, Connection, Coll, Selector, Limit, WriteConcern, DB). + +delete_limit(true, Connection, Coll, Selector, Limit, WriteConcern, DB) -> command(DB, Connection, {<<"delete">>, Coll, <<"deletes">>, - [#{<<"q">> => Selector, <<"limit">> => N}], <<"writeConcern">>, WC}). + [#{<<"q">> => Selector, <<"limit">> => Limit}], <<"writeConcern">>, WriteConcern}); +delete_limit(false, Connection, Coll, Selector, Limit, WriteConcern, _DB) -> + Msg = #op_msg_write_op{command = delete, + collection = Coll, + extra_fields = [{<<"writeConcern">>, WriteConcern}], + documents_name = <<"deletes">>, + documents = [#{<<"q">> => Selector, + <<"limit">> => Limit}]}, + mc_connection_man:op_msg(Connection, Msg). %% @doc Delete selected documents %% params: @@ -230,6 +269,10 @@ find_one(Connection, Coll, Selector, Args, Db) -> Projector = maps:get(projector, Args, #{}), Skip = maps:get(skip, Args, 0), ReadPref = maps:get(readopts, Args, #{<<"mode">> => <<"primary">>}), + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + find_one(UseLegacyProtocol, Connection, Coll, Selector, Projector, Skip, ReadPref, Db). + +find_one(true, Connection, Coll, Selector, Projector, Skip, ReadPref, Db) -> find_one(Connection, #'query'{ database = Db, @@ -237,6 +280,21 @@ find_one(Connection, Coll, Selector, Args, Db) -> selector = mongoc:append_read_preference(Selector, ReadPref), projector = Projector, skip = Skip + }); +find_one(false, Connection, Coll, Selector, Projector, Skip, ReadPref, _Db) -> + CommandDoc = [ + {<<"find">>, Coll}, + {<<"$readPreference">>, ReadPref}, + {<<"filter">>, Selector}, + {<<"projection">>, Projector}, + {<<"skip">>, Skip}, + {<<"batchSize">>, 1}, + {<<"limit">>, 1}, + {<<"singleBatch">>, true} %% Close cursor after first batch + ], + mc_connection_man:op_msg_read_one(Connection, + #'op_msg_command'{ + command_doc = CommandDoc }). %% @doc Return projection of selected documents. @@ -254,7 +312,21 @@ find_one(Cmd = #{connection := Connection, collection := Collection, selector := -spec find_one(pid() | atom(), query()) -> map() | undefined. find_one(Connection, Query) when is_record(Query, query) -> - mc_connection_man:read_one(Connection, Query). + UseLegacyProtocol = mc_utils:use_legacy_protocol(Connection), + find_one_with_query_record(UseLegacyProtocol, Connection, Query). + +find_one_with_query_record(true, Connection, Query) -> + mc_connection_man:read_one(Connection, Query); +find_one_with_query_record(false, Connection, Query) -> + #'query'{collection = Coll, + skip = Skip, + selector = Selector, + projector = Projector} = Query, + {RP, NewSelector, _} = mongoc:extract_read_preference(Selector), + Args = #{projector => Projector, + skip => Skip, + readopts => RP}, +find_one(Connection, Coll, NewSelector, Args). %% @deprecated %% @doc Return selected documents. @@ -274,7 +346,13 @@ find(Connection, Coll, Selector, Args) -> find(Connection, Coll, Selector, Args, Db) -> Projector = maps:get(projector, Args, #{}), Skip = maps:get(skip, Args, 0), - BatchSize = maps:get(batchsize, Args, 0), + BatchSize = + case mc_utils:use_legacy_protocol(Connection) of + true -> + maps:get(batchsize, Args, 0); + false -> + maps:get(batchsize, Args, 101) + end, ReadPref = maps:get(readopts, Args, #{<<"mode">> => <<"primary">>}), find(Connection, #'query'{ @@ -304,12 +382,49 @@ find(Cmd = #{connection := Connection, collection := Collection, selector := Sel -spec find(pid() | atom(), query()) -> {ok, cursor()} | []. find(Connection, Query) when is_record(Query, query) -> - case mc_connection_man:read(Connection, Query) of + FixedQuery = fixed_query(mc_utils:use_legacy_protocol(Connection), Query), + case mc_connection_man:read(Connection, FixedQuery) of [] -> []; {ok, Cursor} when is_pid(Cursor) -> {ok, Cursor} end. +fixed_query(true, Query) -> + Query; +fixed_query(false, Query) -> + #'query'{collection = Coll, + skip = Skip, + selector = Selector, + batchsize = BatchSize, + projector = Projector} = Query, + {ReadPref, NewSelector, OrderBy} = mongoc:extract_read_preference(Selector), + %% We might need to do some transformations: + %% See: https://github.com/mongodb/specifications/blob/master/source/find_getmore_killcursors_commands.rst#mapping-op-query-behavior-to-the-find-command-limit-and-batchsize-fields + SingleBatch = BatchSize < 0, + AbsBatchSize = erlang:abs(BatchSize), + BatchSizeField = batch_size(AbsBatchSize =:= 0, AbsBatchSize), + SingleBatchField = single_batch(SingleBatch), + SortField = sort_field(OrderBy), + CommandDoc = [ + {<<"find">>, Coll}, + {<<"$readPreference">>, ReadPref}, + {<<"filter">>, NewSelector}, + {<<"projection">>, Projector}, + {<<"skip">>, Skip} + ] ++ SortField + ++ BatchSizeField + ++ SingleBatchField, + #op_msg_command{command_doc = CommandDoc}. + +batch_size(true, _BatchSize) -> []; +batch_size(false, BatchSize) -> [{<<"batchSize">>, BatchSize}]. + +single_batch(true) -> []; +single_batch(false = SingleBatch) -> [{<<"singleBatch">>, SingleBatch}]. + +sort_field(OrderBy) when is_map(OrderBy), map_size(OrderBy) =:= 0 -> []; +sort_field(OrderBy) -> [{<<"sort">>, OrderBy}]. + %% @deprecated %% @doc Count selected documents -spec count(pid(), collection(), selector()) -> integer(). @@ -358,7 +473,12 @@ ensure_index(Connection, Coll, IndexSpec) -> -spec ensure_index(pid(), colldb(), bson:document(), database()) -> ok | {error, any()}. ensure_index(Connection, Coll, IndexSpec, DB) -> - mc_connection_man:request_worker(Connection, #ensure_index{database = DB, collection = Coll, index_spec = IndexSpec}). + ensure_index(mc_utils:use_legacy_protocol(Connection), Connection, Coll, IndexSpec, DB). + +ensure_index(true, Connection, Coll, IndexSpec, DB) -> + mc_connection_man:request_worker(Connection, #ensure_index{database = DB, collection = Coll, index_spec = IndexSpec}); +ensure_index(false, Connection, Coll, IndexSpec, DB) -> + command(DB, Connection, {<<"createIndexes">>, Coll, <<"indexes">>, IndexSpec}). %% @doc Execute given MongoDB command and return its result. -spec command(pid(), selector()) -> {boolean(), map()} | {ok, cursor()}. @@ -374,7 +494,6 @@ command(Db, Connection, Command) -> command(Db, Connection, Command, IsSlaveOk) -> mc_connection_man:database_command(Connection, Db, Command, IsSlaveOk). - %% @private -spec prepare(tuple() | list() | map(), fun()) -> list(). prepare(Docs, AssignFun) when is_tuple(Docs) -> %bson diff --git a/src/api/mongo_api.erl b/src/api/mongo_api.erl index e6072e2c..71ede345 100644 --- a/src/api/mongo_api.erl +++ b/src/api/mongo_api.erl @@ -13,6 +13,8 @@ -include("mongoc.hrl"). -include("mongo_protocol.hrl"). +-type transaction_result(T) :: T | {error, term()}. + %% API -export([ connect/4, @@ -23,6 +25,7 @@ delete/3, count/4, ensure_index/3, + command/3, disconnect/1]). -spec connect(atom(), list(), proplists:proplist(), proplists:proplist()) -> {ok, pid()}. @@ -104,6 +107,14 @@ ensure_index(Topology, Coll, IndexSpec) -> mc_worker_api:ensure_index(Worker, Coll, IndexSpec) end, #{}). +-spec command(atom() | pid(), selector(), timeout()) -> transaction_result(integer()). +command(Topology, Command, Timeout) -> + mongoc:transaction_query(Topology, + fun(Conf = #{pool := Worker}) -> + Query = mongoc:command_query(Conf, Command), + mc_worker_api:command(Worker, Query) + end, #{}, Timeout). + -spec disconnect(atom() | pid()) -> ok. disconnect(Topology) -> - mongoc:disconnect(Topology). \ No newline at end of file + mongoc:disconnect(Topology). diff --git a/src/api/mongoc.erl b/src/api/mongoc.erl index 0d0c1a7c..28f9de77 100644 --- a/src/api/mongoc.erl +++ b/src/api/mongoc.erl @@ -21,8 +21,10 @@ transaction/4, status/1, append_read_preference/2, + extract_read_preference/1, find_query/6, count_query/4, + command_query/2, find_one_query/5]). @@ -60,19 +62,23 @@ transaction(Topology, Transaction, Options, Timeout) -> catch error:not_master -> mc_topology:update_topology(Topology), + error_logger:error_msg("transaction error:~p reason:~p~n", [error, not_master]), {error, not_master}; error:{bad_query, {not_master, _}} -> mc_topology:update_topology(Topology), + error_logger:error_msg("transaction error:~p reason:~p~n", [error, not_master]), {error, not_master}; - _:R -> + E:R -> mc_topology:update_topology(Topology), + error_logger:error_msg("transaction error:~p reason:~p~n", [E, R]), {error, R} end; Error -> + error_logger:error_msg("mc_topology get_pool error:~p~n", [Error]), Error end. -%% @doc Get worker from pool and run transaction with additioanl query options on it. Suitable for read transactions +%% @doc Get worker from pool and run transaction with additional query options on it. Suitable for read transactions -spec transaction_query(pid() | atom(), fun()) -> any(). transaction_query(Topology, Transaction) -> transaction_query(Topology, Transaction, #{}). @@ -83,12 +89,7 @@ transaction_query(Topology, Transaction, Options) -> -spec transaction_query(pid() | atom(), fun(), map(), integer() | infinity) -> any(). transaction_query(Topology, Transaction, Options, Timeout) -> - case mc_topology:get_pool(Topology, Options) of - {ok, Pool = #{pool := C}} -> - poolboy:transaction(C, fun(Worker) -> Transaction(Pool#{pool => Worker}) end, Timeout); - Error -> - Error - end. + transaction(Topology, Transaction, Options, Timeout). -spec find_one_query(map(), collection(), selector(), projector(), integer()) -> query(). find_one_query(#{server_type := ServerType, read_preference := RPrefs}, Coll, Selector, Projector, Skip) -> @@ -129,7 +130,15 @@ count_query(#{server_type := ServerType, read_preference := RPrefs}, Coll, Selec }, mongos_query_transform(ServerType, Q, RPrefs). --spec append_read_preference(selector(), map()) -> selector(). +-spec command_query(map(), selector()) -> query(). +command_query(#{server_type := ServerType, read_preference := RPrefs}, Command) -> + Q = #'query'{ + collection = <<"$cmd">>, + selector = Command + }, + mongos_query_transform(ServerType, Q, RPrefs). + +-spec append_read_preference(selector(), readpref()) -> selector(). append_read_preference(Selector = #{<<"$query">> := _}, RP) -> Selector#{<<"$readPreference">> => RP}; append_read_preference(Selector, RP) when is_tuple(Selector) andalso element(1, Selector) =:= <<"count">> -> @@ -140,6 +149,31 @@ append_read_preference(Selector, RP) -> #{<<"$query">> => Selector, <<"$readPreference">> => RP}. +extract_read_preference(#{<<"$readPreference">> := RP} = Selector) -> + {RP, + maps:get(<<"$query">>, Selector, #{}), + maps:get(<<"$orderby">>, Selector, #{})}; +extract_read_preference(Selector) when is_map(Selector) -> + {#{<<"mode">> => <<"primary">>}, + maps:get(<<"$query">>, Selector, Selector), + maps:get(<<"$orderby">>, Selector, #{})};%TODO also extract orderby and what else might be inside (strange but needed to pass test) +extract_read_preference(Selector) when is_tuple(Selector) -> + Fields = bson:fields(Selector), + Query = case lists:keyfind(<<"$query">>, 1, Fields) of + {_, Q} -> Q; + false -> Selector + end, + OrderBy = case lists:keyfind(<<"$orderby">>, 1, Fields) of + {_, OB} -> OB; + false -> #{} + end, + case lists:keyfind(<<"$readPreference">>, 1, Fields) of + {_, RP} -> + {RP, Query, OrderBy}; + false -> + {#{<<"mode">> => <<"primary">>}, Query, OrderBy} + end. + %%%=================================================================== %%% Internal functions %%%=================================================================== diff --git a/src/connection/mc_auth_logic.erl b/src/connection/mc_auth_logic.erl index 92eed8e8..d4440dcd 100644 --- a/src/connection/mc_auth_logic.erl +++ b/src/connection/mc_auth_logic.erl @@ -36,7 +36,7 @@ auth(Connection, _, Database, Login, Password) -> %old authorisation %% @private --spec mongodb_cr_auth(pid(), binary(), binary(), binary()) -> boolean(). +-spec mongodb_cr_auth(pid(), binary(), binary(), binary()) -> boolean() | no_return(). mongodb_cr_auth(Connection, Database, Login, Password) -> {true, Res} = mc_connection_man:database_command(Connection, Database, {<<"getnonce">>, 1}), Nonce = maps:get(<<"nonce">>, Res), @@ -46,7 +46,7 @@ mongodb_cr_auth(Connection, Database, Login, Password) -> end. %% @private --spec scram_sha_1_auth(port(), binary(), binary(), binary()) -> boolean(). +-spec scram_sha_1_auth(pid(), database(), binary(), binary()) -> boolean(). scram_sha_1_auth(Connection, Database, Login, Password) -> try scram_first_step(Connection, Database, Login, Password) @@ -124,9 +124,14 @@ generate_sig(SaltedPassword, AuthMessage) -> mc_utils:hmac(ServerKey, AuthMessage). %% @private +-if(?OTP_RELEASE >= 24). +hi(Password, Salt, Iterations) -> + crypto:pbkdf2_hmac(sha, Password, Salt, Iterations, 20). +-else. hi(Password, Salt, Iterations) -> {ok, Key} = pbkdf2:pbkdf2(sha, Password, Salt, Iterations, 20), Key. +-endif. %% @private xorKeys(<<>>, _, Res) -> Res; diff --git a/src/connection/mc_connection_man.erl b/src/connection/mc_connection_man.erl index 61e0e6a4..afdd580e 100644 --- a/src/connection/mc_connection_man.erl +++ b/src/connection/mc_connection_man.erl @@ -12,28 +12,42 @@ -include("mongo_types.hrl"). -include("mongo_protocol.hrl"). +-dialyzer({no_fail_call, query_to_op_msg_cmd/2}). + -define(NOT_MASTER_ERROR, 13435). -define(UNAUTHORIZED_ERROR(C), C =:= 10057; C =:= 16550). %% API --export([request_worker/2]). +-export([request_worker/2, process_reply/2]). -export([read/2, read/3, read_one/2]). --export([command/2, command/3, database_command/3, database_command/4]). +-export([op_msg/2, op_msg_read_one/2, op_msg_raw_result/2]). +-export([command/2, command/3, database_command/3, database_command/4, request_raw_no_parse/4]). -spec read(pid() | atom(), query()) -> [] | {ok, pid()}. read(Connection, Request) -> read(Connection, Request, undefined). --spec read(pid() | atom(), query(), undefined | mc_worker_api:batchsize()) -> [] | {ok, pid()}. +-spec read(pid() | atom(), query() | op_msg_command(), undefined | mc_worker_api:batchsize()) -> [] | {ok, pid()}. read(Connection, Request = #'query'{collection = Collection, batchsize = BatchSize, database = DB}, CmdBatchSize) -> + read(Connection, Request, Collection, select_batchsize(CmdBatchSize, BatchSize), DB); +read(Connection, Request = #'op_msg_command'{database = DB, command_doc = ([{_, Collection} | _ ] = Fields)}, + _CmdBatchSize) -> + BatchSize = case lists:keyfind(<<"batchSize">>, 1, Fields) of + {_, Size} -> Size; + false -> 101 + end, + read(Connection, Request, Collection, BatchSize, DB). + +read(Connection, Request, Collection, BatchSize, DB) -> case request_worker(Connection, Request) of {_, []} -> []; {Cursor, Batch} -> - mc_cursor:start_link(Connection, Collection, Cursor, select_batchsize(CmdBatchSize, BatchSize), Batch, DB) + mc_cursor:start_link(Connection, Collection, Cursor, BatchSize, Batch, DB); + X -> + erlang:error({error_unexpected_response, X}) end. - --spec read_one(pid() | atom(), query()) -> undefined | map(). +-spec read_one(pid() | atom(), request()) -> undefined | map(). read_one(Connection, Request) -> {0, Docs} = request_worker(Connection, Request#'query'{batchsize = -1}), case Docs of @@ -41,25 +55,52 @@ read_one(Connection, Request) -> [Doc | _] -> Doc end. +-spec command(pid(), mc_worker_api:selector()) -> {boolean(), map()}. command(Connection, Query = #query{selector = Cmd}) -> + QueryOrOpMsg = query_to_op_msg_cmd(mc_utils:use_legacy_protocol(Connection), Query), case determine_cursor(Cmd) of false -> - Doc = read_one(Connection, Query), - process_reply(Doc, Query); + legacy_command(mc_utils:use_legacy_protocol(Connection), Connection, QueryOrOpMsg); BatchSize -> - case read(Connection, Query#query{batchsize = -1}, BatchSize) of + case read(Connection, QueryOrOpMsg, BatchSize) of [] -> []; {ok, Cursor} when is_pid(Cursor) -> {ok, Cursor} end end; -command(Connection, Command) when not is_record(Command, query)-> +command(Connection, Command) when not is_record(Command, query) -> command(Connection, #'query'{ collection = <<"$cmd">>, selector = Command }). +legacy_command(true, Connection, Query) -> + Doc = read_one(Connection, Query), + process_reply(Doc, Query); +legacy_command(false, Connection, OpMsg) -> + {true, mc_connection_man:op_msg_raw_result(Connection, OpMsg)}. + +-spec query_to_op_msg_cmd(boolean(), mc_worker_api:selector() ) -> query() | op_msg_command(). +query_to_op_msg_cmd(true, Query) -> + Query#query{batchsize = -1}; +query_to_op_msg_cmd(false, Query) -> + #query{database = DB, slaveok = SlaveOk, selector = Selector} = Query, + Fields = bson:fields(Selector), + NewSelector = + case {lists:keyfind(<<"$readPreference">>, 1, Fields), SlaveOk} of + {{<<"$readPreference">>, _}, _} -> Selector; + {false, true} -> + bson:merge(Fields, {<<"$readPreference">>, #{<<"mode">> => <<"primaryPreferred">>}}); + {false, false} -> + %% primary is the default mode so we do not need to change anything + Fields + end, + #'op_msg_command'{ + database = DB, + command_doc = NewSelector + }. + command(Connection, Command, _IsSlaveOk = true) -> command(Connection, #'query'{ @@ -90,7 +131,7 @@ database_command(Connection, Database, Command, IsSlaveOk) -> }, IsSlaveOk). --spec request_worker(pid(), mongo_protocol:message()) -> ok | {non_neg_integer(), [map()]}. +-spec request_worker(pid(), mongo_protocol:message()) -> ok | {non_neg_integer(), [map()]} | map(). request_worker(Connection, Request) -> %request to worker Timeout = mc_utils:get_timeout(), reply(gen_server:call(Connection, Request, Timeout)). @@ -101,6 +142,64 @@ process_reply(Doc = #{<<"ok">> := N}, _) when is_number(N) -> %command succeed process_reply(Doc, Command) -> %unknown result erlang:error({bad_command, Doc}, [Command]). +op_msg(Connection, OpMsg) -> + Doc = request_worker(Connection, OpMsg), + process_reply(Doc, OpMsg). + +op_msg_read_one(Connection, OpMsg) -> + Timeout = mc_utils:get_timeout(), + Response = gen_server:call(Connection, OpMsg, Timeout), + case Response of + #op_msg_response{response_doc = + #{<<"ok">> := 1.0, + <<"cursor">>:= + #{<<"firstBatch">>:=[Doc], + <<"id">>:=0} + }} -> + Doc; + #op_msg_response{response_doc = + #{<<"ok">> := 1.0}} -> + undefined; + #op_msg_response{response_doc = Doc} -> + erlang:error({error, Doc}); + _ -> + erlang:error({error_unexpected_response, Response}) + end. + +op_msg_raw_result(Connection, OpMsg) -> + Timeout = mc_utils:get_timeout(), + FromServer = gen_server:call(Connection, OpMsg, Timeout), + case FromServer of + #op_msg_response{response_doc = + (#{<<"ok">> := 1.0} = Res)} -> + Res; + _ -> + erlang:error({error, FromServer}) + end. + +request_raw_no_parse(Socket, Database, Request, NetModule) -> + Timeout = mc_utils:get_timeout(), + ok = set_opts(Socket, NetModule, false), + {ok, _, _} = mc_worker_logic:make_request(Socket, NetModule, Database, Request), + Result = recv_all(Socket, Timeout, NetModule), + ok = set_opts(Socket, NetModule, true), + Result. + +%% @private +set_opts(Socket, ssl, Value) -> + ssl:setopts(Socket, [{active, Value}]); +set_opts(Socket, gen_tcp, Value) -> + inet:setopts(Socket, [{active, Value}]). + +%% @private +recv_all(Socket, Timeout, NetModule) -> + recv_all(Socket, Timeout, NetModule, <<>>). +recv_all(Socket, Timeout, NetModule, Rest) -> + {ok, Packet} = NetModule:recv(Socket, 0, Timeout), + case mc_worker_logic:decode_responses(<>) of + {[], Unfinished} -> recv_all(Socket, Timeout, NetModule, Unfinished); + {Responses, _} -> Responses + end. %% @private reply(ok) -> ok; @@ -112,7 +211,14 @@ reply(#reply{cursornotfound = false, queryerror = true} = Reply) -> reply(#reply{cursornotfound = true, queryerror = false} = Reply) -> erlang:error({bad_cursor, Reply#reply.cursorid}); reply({error, Error}) -> - process_error(error, Error). + process_error(error, Error); +reply(#op_msg_response{response_doc = (#{<<"cursor">> := #{<<"firstBatch">> := Batch, <<"id">> := Id}} = Doc)}) when + map_get(<<"ok">>, Doc) == 1 -> + {Id, Batch}; +reply(#op_msg_response{response_doc = Document}) when map_get(<<"ok">>, Document) == 1 -> + Document; +reply(Resp) -> + erlang:error({error_cannot_parse_response, Resp}). %% @private -spec process_error(atom() | integer(), term()) -> no_return(). diff --git a/src/connection/mc_cursor.erl b/src/connection/mc_cursor.erl index 8fecf0ab..4b48e23d 100644 --- a/src/connection/mc_cursor.erl +++ b/src/connection/mc_cursor.erl @@ -29,6 +29,8 @@ code_change/3 ]). +-dialyzer({no_match, handle_call/3}). + -record(state, { connection :: mc_worker:connection(), collection :: atom(), @@ -40,7 +42,7 @@ }). --spec next(pid()) -> error | {bson:document()}. +-spec next(pid()) -> error | {} | {bson:document()}. next(Cursor) -> next(Cursor, cursor_default_timeout()). @@ -51,33 +53,33 @@ next(Cursor, Timeout) -> exit:{noproc, _} -> error end. --spec next_batch(pid()) -> error | {bson:document()}. +-spec next_batch(pid()) -> error | {} | {bson:document()}. next_batch(Cursor) -> next_batch(Cursor, cursor_default_timeout()). --spec next_batch(pid(), timeout()) -> error | {} | {bson:document()}. +-spec next_batch(pid(),timeout()) -> error | {} | {bson:document()}. next_batch(Cursor, Timeout) -> try gen_server:call(Cursor, {next_batch, Timeout}, Timeout) catch exit:{noproc, _} -> error end. --spec rest(pid()) -> [bson:document()] | error. +-spec rest(pid()) -> [bson:document()] | {} | error. rest(Cursor) -> rest(Cursor, cursor_default_timeout()). --spec rest(pid(), timeout()) -> [bson:document()] | error. +-spec rest(pid(), timeout()) -> [bson:document()] | {} | error. rest(Cursor, Timeout) -> try gen_server:call(Cursor, {rest, infinity, Timeout}, Timeout) catch exit:{noproc, _} -> error end. --spec take(pid(), non_neg_integer()) -> [bson:document()] | error. +-spec take(pid(), non_neg_integer()) -> [bson:document()] | {} | error. take(Cursor, Limit) -> take(Cursor, Limit, cursor_default_timeout()). --spec take(pid(), non_neg_integer(), timeout()) -> [bson:document()] | error. +-spec take(pid(), non_neg_integer(), timeout()) -> [bson:document()] | {} | error. take(Cursor, Limit, Timeout) -> try gen_server:call(Cursor, {rest, Limit, Timeout}, Timeout) catch @@ -121,6 +123,7 @@ start_link(Connection, Collection, Cursor, BatchSize, Batch, DB) -> %% @hidden init([Owner, Connection, Collection, Cursor, BatchSize, Batch,DB]) -> + process_flag(trap_exit, true), Monitor = erlang:monitor(process, Owner), {ok, #state{ connection = Connection, @@ -169,8 +172,16 @@ handle_info(_, State) -> %% @hidden terminate(_, #state{cursor = 0}) -> ok; -terminate(_, State) -> - gen_server:call(State#state.connection, #killcursor{cursorids = [State#state.cursor]}). +terminate(_, #state{collection = Collection, connection = Connection, cursor = Cursor}) -> + terminate(mc_utils:use_legacy_protocol(Connection), Connection, Collection, Cursor). + +terminate(true, Connection, _Collection, Cursor) -> + gen_server:call(Connection, #killcursor{cursorids = [Cursor]}); +terminate(false, Connection, Collection, Cursor) -> + KillCursorCommand = + #op_msg_command{command_doc = [{<<"killCursors">>, Collection}, + {<<"cursors">>, [Cursor]}]}, + mc_connection_man:request_worker(Connection, KillCursorCommand). %% @hidden code_change(_Old, State, _Extra) -> @@ -181,19 +192,33 @@ next_i(#state{batch = [Doc | Rest]} = State, _Timeout) -> {{Doc}, State#state{batch = Rest}}; next_i(#state{batch = [], cursor = 0} = State, _Timeout) -> {{}, State}; -next_i(#state{batch = []} = State, Timeout) -> +next_i(#state{batch = [], connection = Connection} = State, Timeout) -> + next_i(mc_utils:use_legacy_protocol(Connection), State, Timeout). + +next_i(true, State, Timeout) -> Reply = gen_server:call( State#state.connection, #getmore{ collection = State#state.collection, batchsize = State#state.batchsize, - cursorid = State#state.cursor, - database = State#state.database + cursorid = State#state.cursor }, Timeout), Cursor = Reply#reply.cursorid, Batch = Reply#reply.documents, - next_i(State#state{cursor = Cursor, batch = Batch}, Timeout). + next_i(State#state{cursor = Cursor, batch = Batch}, Timeout); +next_i(false, State, Timeout) -> + GetMoreCommand = + #op_msg_command{command_doc = [{<<"getMore">>, State#state.cursor}, + {<<"collection">>, State#state.collection}, + {<<"batchSize">>, State#state.batchsize}]}, + Result = mc_connection_man:request_worker(State#state.connection, GetMoreCommand), + case Result of + #{<<"cursor">>:=#{<<"id">>:=NewCursorId,<<"nextBatch">>:=Batch},<<"ok">>:=1.0} -> + next_i(State#state{cursor = NewCursorId, batch = Batch}, Timeout); + _ -> + erlang:error({error_unexpected_cursor_result, Result}) + end. %% @private rest_i(State, infinity, Timeout) -> diff --git a/src/connection/mc_worker.erl b/src/connection/mc_worker.erl index 4bee3ce5..66286dd4 100644 --- a/src/connection/mc_worker.erl +++ b/src/connection/mc_worker.erl @@ -5,6 +5,7 @@ -define(WRITE(Req), is_record(Req, insert); is_record(Req, update); is_record(Req, delete)). -define(READ(Req), is_record(Request, 'query'); is_record(Request, getmore)). +-define(OP_MSG(Req), is_record(Request, 'op_msg_command'); is_record(Request, 'op_msg_write_op')). -define(LOG_DEFAULT_DB, fun() -> error_logger:info_msg("Using default 'test' database"), <<"test">> end). -export([start_link/1, database/2, disconnect/1, hibernate/1]). @@ -16,12 +17,14 @@ terminate/2, code_change/3]). +-dialyzer({no_fail_call, get_op_msg_write_concern/1}). + -record(state, { socket :: gen_tcp:socket() | ssl:sslsocket(), request_storage = #{} :: map(), buffer = <<>> :: binary(), conn_state :: conn_state(), - hibernate_timer :: reference() | undefined, + hibernate_timer :: timer:tref() | reference() | undefined, next_req_fun :: fun(), net_module :: ssl | gen_tcp }). @@ -53,12 +56,20 @@ init(Options) -> try_register(Options), NetModule = get_set_opts_module(Options), NextReqFun = mc_utils:get_value(next_req_fun, Options, fun() -> ok end), - proc_lib:init_ack({ok, self()}), - gen_server:enter_loop(?MODULE, [], - #state{socket = Socket, - conn_state = ConnState, - net_module = NetModule, - next_req_fun = NextReqFun}); + DefaultUseLegacyProtocol = application:get_env(mongodb, use_legacy_protocol, auto), + UseLegacyProtocol = mc_utils:get_value(use_legacy_protocol, Options, DefaultUseLegacyProtocol), + ProtoOpts = #{use_legacy_protocol => UseLegacyProtocol}, + case mc_worker_pid_info:install_mc_worker_info(Options, NetModule, ConnState#conn_state.database, ProtoOpts) of + ok -> + proc_lib:init_ack({ok, self()}), + gen_server:enter_loop(?MODULE, [], + #state{socket = Socket, + conn_state = ConnState, + net_module = NetModule, + next_req_fun = NextReqFun}); + {error, _} = Error -> + proc_lib:init_ack(Error) + end; Error -> proc_lib:init_ack(Error) end. @@ -76,6 +87,8 @@ handle_call(CMD = #ensure_index{collection = Coll, index_spec = IndexSpec}, _, S ConnState#conn_state.database, #insert{collection = mc_worker_logic:update_dbcoll(Coll, <<"system.indexes">>), documents = [Index]}), {reply, ok, State}; +handle_call(Request, From, State) when ?OP_MSG(Request) -> % MongoDB OpMsg request + process_op_msg_request(Request, From, State); handle_call(Request, From, State) when ?WRITE(Request) -> % write requests (deprecated) process_write_request(Request, From, State); handle_call(Request, From, State) when ?READ(Request) -> % read requests (and all through command) @@ -122,6 +135,39 @@ terminate(_, State = #state{net_module = NetModule}) -> code_change(_Old, State, _Extra) -> {ok, State}. +process_op_msg_request(Request, From, State) -> + #state{socket = Socket, + request_storage = RequestStorage, + conn_state = CS, + net_module = NetModule, + next_req_fun = Next} = State, + Database = CS#conn_state.database, + {ok, PacketSize, Id} = mc_worker_logic:make_request(Socket, NetModule, Database, Request), + UState = need_hibernate(PacketSize, State), + case get_op_msg_write_concern(Request) of + {_, {<<"w">>, 0}} -> %no concern request + Next(), + {reply, + #op_msg_response{response_doc = #{<<"ok">> => 1.0}}, + UState}; + _ -> %ordinary request with response + Next(), + RespFun = mc_worker_logic:get_resp_fun(Request, From), % save function, which will be called on response + URStorage = RequestStorage#{Id => RespFun}, + {noreply, UState#state{request_storage = URStorage}} + end. + +get_op_msg_write_concern(#op_msg_write_op{extra_fields = ExtraFields}) -> + case lists:keyfind(<<"writeConcern">>, 1, ExtraFields) of + {_, WC} -> WC; + _ -> not_found + end; +get_op_msg_write_concern(#op_msg_command{command_doc = DocList}) -> + case lists:keyfind(<<"writeConcern">>, 1, DocList) of + {_, WC} -> WC; + _ -> not_found + end. + %% @private process_read_request(Request, From, State) -> #state{socket = Socket, diff --git a/src/connection/mc_worker_logic.erl b/src/connection/mc_worker_logic.erl index 73447367..166e7ebd 100644 --- a/src/connection/mc_worker_logic.erl +++ b/src/connection/mc_worker_logic.erl @@ -15,6 +15,8 @@ -export([decode_responses/1, process_responses/2, connect/1]). -export([make_request/4, get_resp_fun/2, update_dbcoll/2, collection/1, ensure_index/3]). +-dialyzer({no_fail_call, ensure_index/3}). + %% Make connection to database and return socket -spec connect(proplists:proplist()) -> {ok, port()} | {error, inet:posix()}. connect(Conf) -> @@ -30,6 +32,7 @@ connect(Conf) -> -spec make_request(gen_tcp:socket() | ssl:sslsocket(), atom(), mc_worker_api:database(), mongo_protocol:message() | list(mongo_protocol:message())) -> {ok | {error, any()}, integer(), pos_integer()}. + make_request(Socket, NetModule, Database, Request) -> {Packet, Id} = encode_request(Database, Request), {NetModule:send(Socket, Packet), iolist_size(Packet), Id}. @@ -37,11 +40,14 @@ make_request(Socket, NetModule, Database, Request) -> decode_responses(Data) -> decode_responses(Data, []). --spec get_resp_fun(#query{} | #getmore{} | #insert{} | #update{} | #delete{}, pid()) -> fun(). +-spec get_resp_fun(#query{} | #getmore{} | #insert{} | #update{} | #delete{} | #op_msg_command{} | #op_msg_write_op{}, + pid()) -> fun(). get_resp_fun(Read, From) when is_record(Read, query); is_record(Read, getmore) -> fun(Response) -> gen_server:reply(From, Response) end; get_resp_fun(Write, From) when is_record(Write, insert); is_record(Write, update); is_record(Write, delete) -> - process_write_response(From). + process_write_response(From); +get_resp_fun(OpMsg, From) when is_record(OpMsg, op_msg_write_op); is_record(OpMsg, op_msg_command) -> + process_op_msg_response(From). -spec process_responses(Responses :: list(), RequestStorage :: map()) -> UpdStorage :: map(). process_responses(Responses, RequestStorage) -> @@ -125,6 +131,11 @@ process_write_response(From) -> end end. +process_op_msg_response(From) -> + fun(#op_msg_response{} = OpMsg) -> + gen_server:reply(From, OpMsg) + end. + %% @private do_connect(Host, Port, Timeout, true, Opts) -> {ok, _} = application:ensure_all_started(ssl), diff --git a/src/connection/mongo_protocol.erl b/src/connection/mongo_protocol.erl index 7c0d35e0..3ee70f5d 100644 --- a/src/connection/mongo_protocol.erl +++ b/src/connection/mongo_protocol.erl @@ -19,11 +19,12 @@ -type reply() :: #reply{}. % message id -type requestid() :: integer(). --type message() :: notice() | request(). +-type message() :: notice() | request() | #op_msg_command{} | #op_msg_write_op{}. % RequestId expected to be in scope at call site -define(put_header(Opcode), ?put_int32(_RequestId), ?put_int32(0), ?put_int32(Opcode)). -define(get_header(Opcode, ResponseTo), ?get_int32(_RequestId), ?get_int32(ResponseTo), ?get_int32(Opcode)). +-define(get_header_ignore_req_id(Opcode, ResponseTo), ?get_int32(_), ?get_int32(ResponseTo), ?get_int32(Opcode)). -define(ReplyOpcode, 1). -define(UpdateOpcode, 2001). @@ -32,7 +33,14 @@ -define(GetmoreOpcode, 2005). -define(DeleteOpcode, 2006). -define(KillcursorOpcode, 2007). +-define(OpMsgOpcode, 2013). +-define(OpMsgDbFieldIndex, 4). + +-define (put_uint32 (N), (N):32/unsigned-little). +-define (put_uint8 (N), (N):8/unsigned-little). +-define (get_uint32 (N), N:32/unsigned-little). +-define (get_uint8 (N), N:8/unsigned-little). -spec dbcoll(database(), colldb()) -> bson:utf8(). @@ -82,12 +90,56 @@ put_message(Db, #getmore{collection = Coll, batchsize = Batch, cursorid = Cid}, ?put_int32(0), (bson_binary:put_cstring(dbcoll(Db, Coll)))/binary, ?put_int32(Batch), - ?put_int64(Cid)>>. + ?put_int64(Cid)>>; +put_message(Db, #op_msg_write_op{} = OpMsg, _RequestId) -> + << + ?put_header(?OpMsgOpcode), + ?put_uint32(0), % Flags + (put_section_type_zero(OpMsg#op_msg_write_op{database = make_bin(Db)}))/binary + >>; +put_message(Db, #op_msg_command{} = OpMsg, _RequestId) -> + << + ?put_header(?OpMsgOpcode), + ?put_uint32(0), % Flags + (put_section_type_zero(OpMsg#op_msg_command{database = make_bin(Db)}))/binary + >>. + +make_bin(Atom) when is_atom(Atom) -> + erlang:atom_to_binary(Atom, utf8); +make_bin(Bin) -> + Bin. +put_section_type_zero(#op_msg_command{ + command_doc = Doc, + database = Database +}) -> + << + ?put_uint8(0), + (bson_binary:put_document(bson:merge(Doc, {<<"$db">>, Database})))/binary + >>; +put_section_type_zero(#op_msg_write_op{ + command = Command, + collection = Collection, + database = Database, + extra_fields = ExtraFields, + documents_name = DocumentsName, + documents = Documents +}) -> + Msg = [ + {erlang:atom_to_binary(Command), Collection}, + {<<"$db">>, Database} + ] ++ + ExtraFields + ++ + [{DocumentsName, Documents}], + << + ?put_uint8(0), + (bson_binary:put_document(bson:document(Msg)))/binary + >>. -spec get_reply(binary()) -> {requestid(), reply(), binary()}. -get_reply(Message) -> - <> = Message) -> + < startingfrom = StartingFrom, documents = Docs }, - {ResponseTo, Reply, BinRest}. + {ResponseTo, Reply, BinRest}; +get_reply(<> = Message) -> + <> = Message, + %% For now assume the sequence type is zero + <> = Bin1, + {Doc, Rest} = bson_binary:get_map(Bin2), + Reply = #op_msg_response{ + response_doc = Doc + }, + {ResponseTo, Reply, Rest}. -spec binarize(binary() | atom()) -> binary(). %@doc Ensures the given term is converted to a UTF-8 binary. diff --git a/src/main/mc_super_sup.erl b/src/main/mc_super_sup.erl index 91f3bce0..485e0491 100644 --- a/src/main/mc_super_sup.erl +++ b/src/main/mc_super_sup.erl @@ -17,4 +17,5 @@ start_link() -> init(app) -> MongoIdServer = ?CHILD(mongo_id_server, worker), McPoolSup = ?CHILD(mc_pool_sup, supervisor), - {ok, {{one_for_one, 1000, 3600}, [MongoIdServer, McPoolSup]}}. + PIDInfoServer = ?CHILD(mc_worker_pid_info, worker), + {ok, {{one_for_one, 1000, 3600}, [MongoIdServer, McPoolSup, PIDInfoServer]}}. diff --git a/src/main/mc_worker_pid_info.erl b/src/main/mc_worker_pid_info.erl new file mode 100644 index 00000000..91f21bfd --- /dev/null +++ b/src/main/mc_worker_pid_info.erl @@ -0,0 +1,136 @@ +-module(mc_worker_pid_info). + +%% This module is used to get information (currently only the protocol type) +%% ragaring mc_worker processes. This is useful so we can encode messages in +%% the right way befor sending them to an mc_worker process. + +-behaviour(gen_server). + +-include("mongo_protocol.hrl"). + +-export([start_link/0, + init/1, + terminate/2, + handle_cast/2, + handle_call/3, + get_info/1, + set_info/2, + discard_info/1, + get_protocol_type/1, + handle_info/2, + install_mc_worker_info/4]). + +-dialyzer({no_fail_call, detect_protocol_type/4}). + +-define(CLEAN_TABLE_PERIOD_MINS, 30). +-define(CLEAN_TABLE_MESSAGE, clean_table). +-define(MC_WORKER_PID_INFO_TAB_NAME, mc_worker_pid_info_tab). + +-spec start_link() -> {ok, pid()}. +start_link() -> + gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). + +init(_) -> + ets:new(?MC_WORKER_PID_INFO_TAB_NAME, [public, named_table, {read_concurrency, true}]), + {ok, start_cleanup_timer_update_state(#{})}. + +start_cleanup_timer_update_state(State) -> + State#{ + cleanup_timer_ref => + erlang:start_timer(timer:minutes(?CLEAN_TABLE_PERIOD_MINS), + self(), + ?CLEAN_TABLE_MESSAGE, + []) + }. + +terminate(_,_) -> + ets:delete(?MC_WORKER_PID_INFO_TAB_NAME). + +%% These functions does not do anything as this server is just a holder of an ETS table +handle_cast(_Request, State) -> {noreply, State}. +handle_call(_Request, _From, State) -> {reply, {error, ignore}, State}. + +handle_info({timeout, + TimerRef, + ?CLEAN_TABLE_MESSAGE}, + #{cleanup_timer_ref := TimerRef} = State) -> + PidInfos = ets:tab2list(?MC_WORKER_PID_INFO_TAB_NAME), + [delete_pid_if_dead(Pid) || {Pid, _} <- PidInfos], + {noreply, start_cleanup_timer_update_state(State)}; +handle_info(_, State) -> + {noreply, State}. + +delete_pid_if_dead(Pid) -> + case erlang:is_process_alive(Pid) of + true -> ok; + false -> ets:delete(?MC_WORKER_PID_INFO_TAB_NAME, Pid) + end. + +get_info(MCWorkerPID) -> + try + case ets:lookup(?MC_WORKER_PID_INFO_TAB_NAME, MCWorkerPID) of + [{MCWorkerPID, InfoMap}] -> + {ok, InfoMap}; + [] -> + not_found + end + catch + _:_:_ -> + not_found + end. + + +set_info(MCWorkerPID, InfoMap) -> + ets:insert(?MC_WORKER_PID_INFO_TAB_NAME, {MCWorkerPID, InfoMap}). + +discard_info(MCWorkerPID) -> + ets:delete(?MC_WORKER_PID_INFO_TAB_NAME, MCWorkerPID). + +get_protocol_type(MCWorkerPID) -> + case get_info(MCWorkerPID) of + {ok, #{protocol_type := ProtocolType}} -> + ProtocolType; + _ -> + %% Not found means that this library has been hot upgraded and the + %% mc_worker process was created before the hot_upgrade so we use + %% the legacy protocol as this was what existed before + legacy + end. + +%% This process should be called from mc_worker processes to install their info +%% in the ?MC_WORKER_PID_INFO_TAB_NAME ETS table +install_mc_worker_info(Options, NetModule, Database, Opts) -> + UseLegacyProtocol = maps:get(use_legacy_protocol, Opts), + try + ProtocolType = detect_protocol_type(UseLegacyProtocol, Options, NetModule, Database), + mc_worker_pid_info:set_info(self(), #{protocol_type => ProtocolType}), + ok + catch + What:Reason -> + {error, {What, Reason}} + end. + +%% true - use the legacy message protocol +%% false - use the modern message protocol based on the op_msg opcode +%% auto - detect which message protocol to use based on the database +detect_protocol_type(true, _ConnectionOpts, _NetModule, _Database) -> legacy; +detect_protocol_type(false, _ConnectionOpts, _NetModule, _Database) -> op_msg; +detect_protocol_type(auto, ConnectionOpts, NetModule, Database) -> + %% Automatically detect which protocol to use. We send a `hello' command using + %% the new protocol. If we have our connection closed by the server, it means + %% it doesn't support the new protocol. + %% See also: + %% * https://github.com/mongodb/mongo/blob/5e494138af456f42381ad08748cc7fbc4ace7a60/src/mongo/base/error_codes.yml + %% * https://www.mongodb.com/docs/manual/reference/command/hello/#mongodb-dbcommand-dbcmd.hello + %% * https://www.mongodb.com/docs/manual/reference/mongodb-wire-protocol/#std-label-wire-op-msg + {ok, Socket} = mc_worker_logic:connect(ConnectionOpts), + Command = bson:fields({hello, 1}), + Request = #op_msg_command{command_doc = Command, database = Database}, + try mc_connection_man:request_raw_no_parse(Socket, Database, Request, NetModule) of + [{_, #op_msg_response{response_doc = #{<<"ok">> := _}}} | _] -> op_msg; + _ErrorResponse -> legacy + catch + _:_ -> legacy + after + NetModule:close(Socket) + end. diff --git a/src/mongoc/mc_monitor.erl b/src/mongoc/mc_monitor.erl index 5324e59b..0b535323 100644 --- a/src/mongoc/mc_monitor.erl +++ b/src/mongoc/mc_monitor.erl @@ -153,11 +153,16 @@ maybe_recheck(_, Topology, Server, ConnectArgs, HB_MS, MinHB_MS) -> check(ConnectArgs, Server) -> Start = os:timestamp(), {ok, Conn} = mc_worker_api:connect(ConnectArgs), - {true, IsMaster} = mc_worker_api:command(Conn, {isMaster, 1}), + {true, IsMaster} = is_master_check(mc_utils:use_legacy_protocol(Conn), Conn), Finish = os:timestamp(), mc_worker_api:disconnect(Conn), {monitor_ismaster, Server, IsMaster, timer:now_diff(Finish, Start)}. +is_master_check(true, Connection) -> + mc_worker_api:command(Connection, {isMaster, 1}); +is_master_check(false, Connection) -> + mc_worker_api:command(Connection, {hello, 1}). + %% @private do_timeout(Pid, TO) when TO > 0 -> receive diff --git a/src/mongoc/mc_topology.erl b/src/mongoc/mc_topology.erl index 561a3510..5f52b1c6 100644 --- a/src/mongoc/mc_topology.erl +++ b/src/mongoc/mc_topology.erl @@ -218,13 +218,17 @@ parse_ismaster(Server, IsMaster, RTT, State = #topology_state{servers = Tab}) -> arbiters = maps:get(<<"arbiters">>, IsMaster, []), electionId = maps:get(<<"electionId">>, IsMaster, undefined), primary = maps:get(<<"primary">>, IsMaster, undefined), - ismaster = maps:get(<<"ismaster">>, IsMaster, undefined), + ismaster = is_master(IsMaster), secondary = maps:get(<<"secondary">>, IsMaster, undefined) }, ets:insert(Tab, ToUpdate), mc_server:update_ismaster(ToUpdate#mc_server.pid, {SType, ToUpdate}), mc_topology_logics:update_topology_state(ToUpdate, State). +is_master(#{<<"ismaster">> := V} = _IsMaster) -> V; +is_master(#{<<"isWritablePrimary">> := V} = _IsMaster) -> V; +is_master(_) -> undefined. + %% @private parse_rtt(_, undefined, RTT) -> {RTT, RTT}; parse_rtt(OldRTT, CurRTT, RTT) -> @@ -234,8 +238,12 @@ parse_rtt(OldRTT, CurRTT, RTT) -> %% @private server_type(#{<<"ismaster">> := true, <<"secondary">> := false, <<"setName">> := _}) -> rsPrimary; +server_type(#{<<"isWritablePrimary">> := true, <<"secondary">> := false, <<"setName">> := _}) -> + rsPrimary; server_type(#{<<"ismaster">> := false, <<"secondary">> := true, <<"setName">> := _}) -> rsSecondary; +server_type(#{<<"isWritablePrimary">> := false, <<"secondary">> := true, <<"setName">> := _}) -> + rsSecondary; server_type(#{<<"arbiterOnly">> := true, <<"setName">> := _}) -> rsArbiter; server_type(#{<<"hidden">> := true, <<"setName">> := _}) -> @@ -246,6 +254,8 @@ server_type(#{<<"msg">> := <<"isdbgrid">>}) -> mongos; server_type(#{<<"isreplicaset">> := true}) -> rsGhost; +server_type(#{<<"isWritablePrimary">> := true}) -> + standalone; server_type(#{<<"ok">> := _}) -> unknown; server_type(_) -> diff --git a/src/support/mc_utils.erl b/src/support/mc_utils.erl index f3993489..6fad93ec 100644 --- a/src/support/mc_utils.erl +++ b/src/support/mc_utils.erl @@ -31,7 +31,9 @@ hmac/2, is_proplist/1, to_binary/1, - get_srv_seeds/1]). + get_srv_seeds/1, + use_legacy_protocol/1, + get_connection_pid/1]). get_value(Key, List) -> get_value(Key, List, undefined). @@ -122,3 +124,18 @@ valid_endpoint(Host, Srv) -> [_ | HostBaseDomain] = string:split(Host, "."), [_ | SrvBaseDomain] = string:split(Srv, "."), HostBaseDomain == SrvBaseDomain. + +use_legacy_protocol(Connection) -> + %% Latest MongoDB version that supported the non-op-msg based opcodes was + %% 5.0.x (at the time of writing 5.0.14). The non-op-msg based opcodes were + %% removed in MongoDB version 5.1.0. See + %% https://www.mongodb.com/docs/manual/legacy-opcodes/ + case mc_worker_pid_info:get_protocol_type(Connection) of + legacy -> true; + op_msg -> false + end. + +get_connection_pid(Connection) when is_pid(Connection) -> + Connection; +get_connection_pid(#{connection_pid := Pid}) -> + Pid. diff --git a/test/mc_test_utils.erl b/test/mc_test_utils.erl index dce30b91..fcd311f2 100644 --- a/test/mc_test_utils.erl +++ b/test/mc_test_utils.erl @@ -10,14 +10,15 @@ -author("tihon"). %% API --export([collection/1]). +-export([collection/2]). -collection(Case) -> +collection(Mod, Case) -> Now = now_to_seconds(os:timestamp()), <<(atom_to_binary(?MODULE, utf8))/binary, $-, + (atom_to_binary(Mod, utf8))/binary, $-, (atom_to_binary(Case, utf8))/binary, $-, (list_to_binary(integer_to_list(Now)))/binary>>. %% @private now_to_seconds({Mega, Sec, _}) -> - (Mega * 1000000) + Sec. \ No newline at end of file + (Mega * 1000000) + Sec. diff --git a/test/mc_worker_api_SUITE.erl b/test/mc_worker_api_SUITE.erl index 4e607900..0befd43f 100644 --- a/test/mc_worker_api_SUITE.erl +++ b/test/mc_worker_api_SUITE.erl @@ -6,7 +6,7 @@ -include("mongo_protocol.hrl"). --compile(export_all). +-compile([export_all, nowarn_export_all]). all() -> [ @@ -31,7 +31,7 @@ end_per_suite(_Config) -> init_per_testcase(Case, Config) -> {ok, Connection} = mc_worker_api:connect([{database, ?config(database, Config)}, {login, <<"user">>}, {password, <<"test">>}, {w_mode, safe}]), - [{connection, Connection}, {collection, mc_test_utils:collection(Case)} | Config]. + [{connection, Connection}, {collection, mc_test_utils:collection(?MODULE, Case)} | Config]. end_per_testcase(_Case, Config) -> Connection = ?config(connection, Config), @@ -65,7 +65,6 @@ insert_and_find(Config) -> 4 = mc_worker_api:count(Connection, Collection, #{}), {ok, TeamsCur} = mc_worker_api:find(Connection, Collection, #{}), TeamsFound = mc_cursor:rest(TeamsCur), - undefined = process_info(TeamsCur), ?assertEqual(Teams, TeamsFound), {ok, NationalTeamsCur} = mc_worker_api:find( @@ -134,6 +133,12 @@ insert_and_delete(Config) -> mc_worker_api:delete_one(Connection, Collection, #{}), 3 = mc_worker_api:count(Connection, Collection, #{}), + + mc_worker_api:delete_limit(Connection, Collection, #{}, 1), + 2 = mc_worker_api:count(Connection, Collection, #{}), + + mc_worker_api:delete(Connection, Collection, #{}), + 0 = mc_worker_api:count(Connection, Collection, #{}), Config. insert_map(Config) -> diff --git a/test/mongo_api_SUITE.erl b/test/mongo_api_SUITE.erl index 7bc77ca8..d257ae5d 100644 --- a/test/mongo_api_SUITE.erl +++ b/test/mongo_api_SUITE.erl @@ -4,11 +4,12 @@ -include_lib("common_test/include/ct.hrl"). -include_lib("eunit/include/eunit.hrl"). --compile(export_all). +-compile([export_all, nowarn_export_all]). all() -> [ ensure_index_test, + per_connection_protocol_type_test, count_test, find_one_test, find_test, @@ -25,7 +26,7 @@ end_per_suite(_Config) -> init_per_testcase(Case, Config) -> {ok, Pid} = mongo_api:connect(single, ["localhost:27017"], [{pool_size, 1}, {max_overflow, 0}], [{database, ?config(database, Config)}, {login, <<"user">>}, {password, <<"test">>}]), - [{connection, Pid}, {collection, mc_test_utils:collection(Case)} | Config]. + [{connection, Pid}, {collection, mc_test_utils:collection(?MODULE, Case)} | Config]. end_per_testcase(_Case, Config) -> Connection = ?config(connection, Config), @@ -33,6 +34,15 @@ end_per_testcase(_Case, Config) -> mongo_api:delete(Connection, Collection, #{}), mongo_api:disconnect(Connection). +parse_mongo_version(StrVersion) -> + Segments = string:split(StrVersion, ".", all), + case lists:map(fun list_to_integer/1, Segments) of + [Major, Minor] -> + {Major, Minor, 0}; + [Major, Minor, Patch] -> + {Major, Minor, Patch} + end. + %% Tests ensure_index_test(Config) -> Pid = ?config(connection, Config), @@ -41,6 +51,38 @@ ensure_index_test(Config) -> ok = mongo_api:ensure_index(Pid, Collection, {<<"key">>, {<<"z_first">>, 1, <<"a_last">>, 1}}), Config. +%% regardless of application env, we can set the protocol type per connection +per_connection_protocol_type_test(Config) -> + {ok, MCWorkerConnection0} = + mc_worker:start_link([{database, ?config(database, Config)} + , {w_mode, safe} + , {use_legacy_protocol, true} + ]), + ?assert(mc_utils:use_legacy_protocol(MCWorkerConnection0)), + MRef0 = monitor(process, MCWorkerConnection0), + mc_worker:disconnect(MCWorkerConnection0), + receive + {'DOWN', MRef0, process, MCWorkerConnection0, _} -> + ok + after + 1_000 -> ct:fail("worker didn't halt") + end, + {ok, MCWorkerConnection1} = + mc_worker:start_link([{database, ?config(database, Config)} + , {w_mode, safe} + , {use_legacy_protocol, false} + ]), + ?assertNot(mc_utils:use_legacy_protocol(MCWorkerConnection1)), + MRef1 = monitor(process, MCWorkerConnection1), + mc_worker:disconnect(MCWorkerConnection1), + receive + {'DOWN', MRef1, process, MCWorkerConnection1, _} -> + ok + after + 1_000 -> ct:fail("worker didn't halt") + end, + ok. + count_test(Config) -> Collection = ?config(collection, Config), Pid = ?config(connection, Config), diff --git a/test/switch_db_SUITE.erl b/test/switch_db_SUITE.erl index db0d2abb..9f30ed29 100644 --- a/test/switch_db_SUITE.erl +++ b/test/switch_db_SUITE.erl @@ -33,7 +33,7 @@ end_per_suite(_Config) -> init_per_testcase(Case, Config) -> {ok, Connection} = mc_worker_api:connect([{database, ?config(database, Config)}, {login, <<"user">>}, {password, <<"test">>}, {w_mode, safe}]), - [{connection, Connection}, {collection, mc_test_utils:collection(Case)} | Config]. + [{connection, Connection}, {collection, mc_test_utils:collection(?MODULE, Case)} | Config]. end_per_testcase(_Case, Config) -> Connection = ?config(connection, Config),