Skip to content

Commit 4ab411b

Browse files
authored
Merge pull request #1459 from fnc12/bugfix/cross-join-serialization
fixed cross join serialization
2 parents ca580cf + 698809e commit 4ab411b

File tree

7 files changed

+229
-74
lines changed

7 files changed

+229
-74
lines changed

dev/ast/cross_join.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "../functional/cxx_type_traits_polyfill.h"
2+
3+
namespace sqlite_orm {
4+
namespace internal {
5+
6+
/**
7+
* CROSS JOIN holder.
8+
* T is joined type which represents any mapped table.
9+
*/
10+
template<class T>
11+
struct cross_join_t {
12+
using type = T;
13+
};
14+
}
15+
}
16+
17+
SQLITE_ORM_EXPORT namespace sqlite_orm {
18+
19+
/**
20+
* CROSS JOIN function. Usage:
21+
* `cross_join<User>();`
22+
*/
23+
template<class T>
24+
internal::cross_join_t<T> cross_join() {
25+
return {};
26+
}
27+
}

dev/conditions.h

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
#include "tags.h"
2727
#include "type_printer.h"
2828
#include "literal.h"
29+
#include "ast/cross_join.h"
2930

3031
namespace sqlite_orm {
3132

3233
namespace internal {
33-
3434
/**
3535
* Collated something
3636
*/
@@ -595,33 +595,12 @@ namespace sqlite_orm {
595595
glob_t(arg_t arg_, pattern_t pattern_) : arg(std::move(arg_)), pattern(std::move(pattern_)) {}
596596
};
597597

598-
struct cross_join_string {
599-
operator std::string() const {
600-
return "CROSS JOIN";
601-
}
602-
};
603-
604-
/**
605-
* CROSS JOIN holder.
606-
* T is joined type which represents any mapped table.
607-
*/
608-
template<class T>
609-
struct cross_join_t : cross_join_string {
610-
using type = T;
611-
};
612-
613-
struct natural_join_string {
614-
operator std::string() const {
615-
return "NATURAL JOIN";
616-
}
617-
};
618-
619598
/**
620599
* NATURAL JOIN holder.
621600
* T is joined type which represents any mapped table.
622601
*/
623602
template<class T>
624-
struct natural_join_t : natural_join_string {
603+
struct natural_join_t {
625604
using type = T;
626605
};
627606

@@ -750,6 +729,13 @@ namespace sqlite_orm {
750729

751730
template<class T>
752731
using is_constrained_join = polyfill::is_detected<on_type_t, T>;
732+
733+
template<class T>
734+
using is_any_join = mpl::invoke_t<mpl::disjunction<check_if<is_constrained_join>,
735+
check_if_is_template<cross_join_t>,
736+
check_if_is_template<natural_join_t>>,
737+
T>;
738+
753739
}
754740
}
755741

@@ -907,11 +893,6 @@ SQLITE_ORM_EXPORT namespace sqlite_orm {
907893
return {std::move(t)};
908894
}
909895

910-
template<class T>
911-
internal::cross_join_t<T> cross_join() {
912-
return {};
913-
}
914-
915896
#ifdef SQLITE_ORM_WITH_CPP20_ALIASES
916897
template<orm_refers_to_recordset auto alias>
917898
auto cross_join() {

dev/statement_serializer.h

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1946,17 +1946,24 @@ namespace sqlite_orm {
19461946
using conditions_tuple = typename statement_type::conditions_type;
19471947
constexpr bool hasExplicitFrom = tuple_has<conditions_tuple, is_from>::value;
19481948
if constexpr (!hasExplicitFrom) {
1949+
using joins_index_sequence = filter_tuple_sequence_t<conditions_tuple, is_any_join>;
1950+
19491951
auto tableNames = collect_table_names(sel, context);
1950-
using joins_index_sequence = filter_tuple_sequence_t<conditions_tuple, is_constrained_join>;
19511952
// deduplicate table names of constrained join statements
19521953
iterate_tuple(sel.conditions, joins_index_sequence{}, [&tableNames, &context](auto& join) {
19531954
using original_join_type = typename std::remove_reference_t<decltype(join)>::type;
1954-
using cross_join_type = mapped_type_proxy_t<original_join_type>;
1955-
std::pair<const std::string&, std::string> tableNameWithAlias{
1956-
lookup_table_name<cross_join_type>(context.db_objects),
1957-
alias_extractor<original_join_type>::as_alias()};
1958-
tableNames.erase(tableNameWithAlias);
1955+
using join_type = mapped_type_proxy_t<original_join_type>;
1956+
1957+
const auto& tableName = lookup_table_name<join_type>(context.db_objects);
1958+
auto it = std::find_if(tableNames.begin(), tableNames.end(), [&tableName](const auto& pair) {
1959+
return pair.first == tableName;
1960+
});
1961+
if (it == tableNames.end()) {
1962+
return;
1963+
}
1964+
tableNames.erase(it);
19591965
});
1966+
19601967
if (!tableNames.empty() && !is_compound_operator<T>::value) {
19611968
ss << " FROM " << streaming_identifiers(tableNames);
19621969
}
@@ -2290,11 +2297,19 @@ namespace sqlite_orm {
22902297
using statement_type = Join;
22912298

22922299
template<class Ctx>
2293-
SQLITE_ORM_STATIC_CALLOP std::string operator()(const statement_type& join,
2300+
SQLITE_ORM_STATIC_CALLOP std::string operator()(const statement_type& /*join*/,
22942301
const Ctx& context) SQLITE_ORM_OR_CONST_CALLOP {
22952302
std::stringstream ss;
2296-
ss << static_cast<std::string>(join) << " "
2297-
<< streaming_identifier(lookup_table_name<type_t<Join>>(context.db_objects));
2303+
if constexpr (polyfill::is_specialization_of<statement_type, cross_join_t>::value) {
2304+
ss << "CROSS JOIN";
2305+
} else if constexpr (polyfill::is_specialization_of<statement_type, natural_join_t>::value) {
2306+
ss << "NATURAL JOIN";
2307+
} else {
2308+
static_assert(polyfill::always_false_v<statement_type>);
2309+
}
2310+
ss << " "
2311+
<< streaming_identifier(
2312+
lookup_table_name<mapped_type_proxy_t<type_t<statement_type>>>(context.db_objects));
22982313
return ss.str();
22992314
}
23002315
};

include/sqlite_orm/sqlite_orm.h

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5197,10 +5197,38 @@ namespace sqlite_orm {
51975197
}
51985198
}
51995199

5200-
namespace sqlite_orm {
5200+
// #include "ast/cross_join.h"
5201+
// #include "../functional/cxx_type_traits_polyfill.h"
52015202

5203+
namespace sqlite_orm {
52025204
namespace internal {
52035205

5206+
/**
5207+
* CROSS JOIN holder.
5208+
* T is joined type which represents any mapped table.
5209+
*/
5210+
template<class T>
5211+
struct cross_join_t {
5212+
using type = T;
5213+
};
5214+
}
5215+
}
5216+
5217+
SQLITE_ORM_EXPORT namespace sqlite_orm {
5218+
5219+
/**
5220+
* CROSS JOIN function. Usage:
5221+
* `cross_join<User>();`
5222+
*/
5223+
template<class T>
5224+
internal::cross_join_t<T> cross_join() {
5225+
return {};
5226+
}
5227+
}
5228+
5229+
namespace sqlite_orm {
5230+
5231+
namespace internal {
52045232
/**
52055233
* Collated something
52065234
*/
@@ -5765,33 +5793,12 @@ namespace sqlite_orm {
57655793
glob_t(arg_t arg_, pattern_t pattern_) : arg(std::move(arg_)), pattern(std::move(pattern_)) {}
57665794
};
57675795

5768-
struct cross_join_string {
5769-
operator std::string() const {
5770-
return "CROSS JOIN";
5771-
}
5772-
};
5773-
5774-
/**
5775-
* CROSS JOIN holder.
5776-
* T is joined type which represents any mapped table.
5777-
*/
5778-
template<class T>
5779-
struct cross_join_t : cross_join_string {
5780-
using type = T;
5781-
};
5782-
5783-
struct natural_join_string {
5784-
operator std::string() const {
5785-
return "NATURAL JOIN";
5786-
}
5787-
};
5788-
57895796
/**
57905797
* NATURAL JOIN holder.
57915798
* T is joined type which represents any mapped table.
57925799
*/
57935800
template<class T>
5794-
struct natural_join_t : natural_join_string {
5801+
struct natural_join_t {
57955802
using type = T;
57965803
};
57975804

@@ -5920,6 +5927,13 @@ namespace sqlite_orm {
59205927

59215928
template<class T>
59225929
using is_constrained_join = polyfill::is_detected<on_type_t, T>;
5930+
5931+
template<class T>
5932+
using is_any_join = mpl::invoke_t<mpl::disjunction<check_if<is_constrained_join>,
5933+
check_if_is_template<cross_join_t>,
5934+
check_if_is_template<natural_join_t>>,
5935+
T>;
5936+
59235937
}
59245938
}
59255939

@@ -6077,11 +6091,6 @@ SQLITE_ORM_EXPORT namespace sqlite_orm {
60776091
return {std::move(t)};
60786092
}
60796093

6080-
template<class T>
6081-
internal::cross_join_t<T> cross_join() {
6082-
return {};
6083-
}
6084-
60856094
#ifdef SQLITE_ORM_WITH_CPP20_ALIASES
60866095
template<orm_refers_to_recordset auto alias>
60876096
auto cross_join() {
@@ -22358,17 +22367,24 @@ namespace sqlite_orm {
2235822367
using conditions_tuple = typename statement_type::conditions_type;
2235922368
constexpr bool hasExplicitFrom = tuple_has<conditions_tuple, is_from>::value;
2236022369
if constexpr (!hasExplicitFrom) {
22370+
using joins_index_sequence = filter_tuple_sequence_t<conditions_tuple, is_any_join>;
22371+
2236122372
auto tableNames = collect_table_names(sel, context);
22362-
using joins_index_sequence = filter_tuple_sequence_t<conditions_tuple, is_constrained_join>;
2236322373
// deduplicate table names of constrained join statements
2236422374
iterate_tuple(sel.conditions, joins_index_sequence{}, [&tableNames, &context](auto& join) {
2236522375
using original_join_type = typename std::remove_reference_t<decltype(join)>::type;
22366-
using cross_join_type = mapped_type_proxy_t<original_join_type>;
22367-
std::pair<const std::string&, std::string> tableNameWithAlias{
22368-
lookup_table_name<cross_join_type>(context.db_objects),
22369-
alias_extractor<original_join_type>::as_alias()};
22370-
tableNames.erase(tableNameWithAlias);
22376+
using join_type = mapped_type_proxy_t<original_join_type>;
22377+
22378+
const auto& tableName = lookup_table_name<join_type>(context.db_objects);
22379+
auto it = std::find_if(tableNames.begin(), tableNames.end(), [&tableName](const auto& pair) {
22380+
return pair.first == tableName;
22381+
});
22382+
if (it == tableNames.end()) {
22383+
return;
22384+
}
22385+
tableNames.erase(it);
2237122386
});
22387+
2237222388
if (!tableNames.empty() && !is_compound_operator<T>::value) {
2237322389
ss << " FROM " << streaming_identifiers(tableNames);
2237422390
}
@@ -22702,11 +22718,19 @@ namespace sqlite_orm {
2270222718
using statement_type = Join;
2270322719

2270422720
template<class Ctx>
22705-
SQLITE_ORM_STATIC_CALLOP std::string operator()(const statement_type& join,
22721+
SQLITE_ORM_STATIC_CALLOP std::string operator()(const statement_type& /*join*/,
2270622722
const Ctx& context) SQLITE_ORM_OR_CONST_CALLOP {
2270722723
std::stringstream ss;
22708-
ss << static_cast<std::string>(join) << " "
22709-
<< streaming_identifier(lookup_table_name<type_t<Join>>(context.db_objects));
22724+
if constexpr (polyfill::is_specialization_of<statement_type, cross_join_t>::value) {
22725+
ss << "CROSS JOIN";
22726+
} else if constexpr (polyfill::is_specialization_of<statement_type, natural_join_t>::value) {
22727+
ss << "NATURAL JOIN";
22728+
} else {
22729+
static_assert(polyfill::always_false_v<statement_type>);
22730+
}
22731+
ss << " "
22732+
<< streaming_identifier(
22733+
lookup_table_name<mapped_type_proxy_t<type_t<statement_type>>>(context.db_objects));
2271022734
return ss.str();
2271122735
}
2271222736
};
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include <sqlite_orm/sqlite_orm.h>
2+
#include <catch2/catch_all.hpp>
3+
4+
using namespace sqlite_orm;
5+
6+
TEST_CASE("cross_join") {
7+
using internal::serialize;
8+
9+
struct User {
10+
int id = 0;
11+
std::string name;
12+
};
13+
auto table = make_table("users", make_column("id", &User::id), make_column("name", &User::name));
14+
using db_objects_t = internal::db_objects_tuple<decltype(table)>;
15+
auto dbObjects = db_objects_t{table};
16+
using context_t = internal::serializer_context<db_objects_t>;
17+
context_t context{dbObjects};
18+
std::string value;
19+
SECTION("straight") {
20+
auto node = cross_join<User>();
21+
value = serialize(node, context);
22+
}
23+
SECTION("alias") {
24+
using user_s = alias_s<User>;
25+
auto node = cross_join<user_s>();
26+
value = serialize(node, context);
27+
}
28+
REQUIRE(value == R"(CROSS JOIN "users")");
29+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include <sqlite_orm/sqlite_orm.h>
2+
#include <catch2/catch_all.hpp>
3+
4+
using namespace sqlite_orm;
5+
6+
TEST_CASE("natural_join") {
7+
using internal::serialize;
8+
9+
struct User {
10+
int id = 0;
11+
std::string name;
12+
};
13+
auto table = make_table("users", make_column("id", &User::id), make_column("name", &User::name));
14+
using db_objects_t = internal::db_objects_tuple<decltype(table)>;
15+
auto dbObjects = db_objects_t{table};
16+
using context_t = internal::serializer_context<db_objects_t>;
17+
context_t context{dbObjects};
18+
std::string value;
19+
SECTION("straight") {
20+
auto node = natural_join<User>();
21+
value = serialize(node, context);
22+
}
23+
SECTION("alias") {
24+
using user_s = alias_s<User>;
25+
auto node = natural_join<user_s>();
26+
value = serialize(node, context);
27+
}
28+
REQUIRE(value == R"(NATURAL JOIN "users")");
29+
}

0 commit comments

Comments
 (0)