@@ -29,16 +29,64 @@ def visit_Arel_Nodes_Concat(o, collector)
2929 visit o . right , collector
3030 end
3131
32+ # Same as SQLite and PostgreSQL.
3233 def visit_Arel_Nodes_UpdateStatement ( o , collector )
33- if has_join_and_composite_primary_key? ( o )
34- update_statement_using_join ( o , collector )
34+ collector . retryable = false
35+ o = prepare_update_statement ( o )
36+
37+ collector << "UPDATE "
38+
39+ # UPDATE with JOIN is in the form of:
40+ #
41+ # UPDATE t1
42+ # SET ..
43+ # FROM t2
44+ # WHERE t1.join_id = t2.join_id
45+ #
46+ # Or if more than one join is present:
47+ #
48+ # UPDATE t1
49+ # SET ..
50+ # FROM t2
51+ # JOIN t3 ON t2.join_id = t3.join_id
52+ # WHERE t1.join_id = t2.join_id
53+ if has_join_sources? ( o )
54+ visit o . relation . left , collector
55+ collect_nodes_for o . values , collector , " SET "
56+ collector << " FROM "
57+ first_join , *remaining_joins = o . relation . right
58+ visit first_join . left , collector
59+
60+ if remaining_joins && !remaining_joins . empty?
61+ collector << " "
62+ remaining_joins . each do |join |
63+ visit join , collector
64+ end
65+ end
66+
67+ collect_nodes_for [ first_join . right . expr ] + o . wheres , collector , " WHERE " , " AND "
68+ else
69+ collector = visit o . relation , collector
70+ collect_nodes_for o . values , collector , " SET "
71+ collect_nodes_for o . wheres , collector , " WHERE " , " AND "
72+ end
73+
74+ collect_nodes_for o . orders , collector , " ORDER BY "
75+ maybe_visit o . limit , collector
76+ end
77+
78+ # Same as PostgreSQL except we need to add limit if using subquery.
79+ def prepare_update_statement ( o )
80+ if has_join_sources? ( o ) && !has_limit_or_offset_or_orders? ( o ) && !has_group_by_and_having? ( o )
81+ o
3582 else
3683 o . limit = Nodes ::Limit . new ( 9_223_372_036_854_775_807 ) if o . orders . any? && o . limit . nil?
3784
3885 super
3986 end
4087 end
4188
89+
4290 def visit_Arel_Nodes_DeleteStatement ( o , collector )
4391 if has_join_and_composite_primary_key? ( o )
4492 delete_statement_using_join ( o , collector )
@@ -61,17 +109,6 @@ def delete_statement_using_join(o, collector)
61109 collect_nodes_for o . wheres , collector , " WHERE " , " AND "
62110 end
63111
64- def update_statement_using_join ( o , collector )
65- collector . retryable = false
66-
67- collector << "UPDATE "
68- visit o . relation . left , collector
69- collect_nodes_for o . values , collector , " SET "
70- collector << " FROM "
71- visit o . relation , collector
72- collect_nodes_for o . wheres , collector , " WHERE " , " AND "
73- end
74-
75112 def visit_Arel_Nodes_Lock ( o , collector )
76113 o . expr = Arel . sql ( "WITH(UPDLOCK)" ) if o . expr . to_s =~ /FOR UPDATE/
77114 collector << " "
0 commit comments