@@ -2,6 +2,27 @@ open Core
22open Core.Poly
33open Middle
44
5+ let log_demotions = ref true
6+
7+ let user_warning (mem_pattern : Mem_pattern.t ) (linenum : int ) (msg : string ) =
8+ if ! log_demotions then
9+ let mem_name =
10+ match mem_pattern with Mem_pattern. SoA -> " SoA" | AoS -> " AoS" in
11+ Printf. eprintf " %s (Line: %i) warning: %s\n " mem_name linenum msg
12+
13+ let user_warning_op (mem_pattern : Mem_pattern.t ) (linenum : int ) (msg : string )
14+ (names : string ) =
15+ if ! log_demotions then
16+ let mem_name =
17+ match mem_pattern with Mem_pattern. SoA -> " SoA" | AoS -> " AoS" in
18+ if not (String. is_empty names || String. is_empty msg) then
19+ Printf. eprintf " %s (Line %i) warning: %s\n " mem_name linenum (msg ^ names)
20+
21+ let concat_set_str (set : string Set.Poly.t ) =
22+ Set. fold
23+ ~f: (fun acc elem -> if acc = " " then acc ^ elem else acc ^ " , " ^ elem)
24+ ~init: " " set
25+
526(* *
627 Return a Var expression of the name for each type
728 containing an eigen matrix
@@ -99,7 +120,7 @@ let query_stan_math_mem_pattern_support (name : string)
99120 Frontend.SignatureMismatch. check_compatible_arguments_mod_conv x args
100121 |> Result. is_ok)
101122 namematches in
102- let is_soa = function _ , _ , _ , Mem_pattern. SoA -> true | _ -> false in
123+ let is_soa ( _ , _ , _ , p ) = p = Mem_pattern. SoA in
103124 List. exists ~f: is_soa filteredmatches
104125
105126(* Validate whether a function can support SoA matrices*)
@@ -117,13 +138,13 @@ let is_fun_soa_supported name exprs =
117138 will be returned if the matrix or vector is accessed by single
118139 cell indexing.
119140 *)
120- let rec query_initial_demotable_expr (in_loop : bool ) ~( acc : string Set.Poly.t )
121- Expr.Fixed. {pattern; _} : string Set.Poly.t =
141+ let rec query_initial_demotable_expr (in_loop : bool ) ( stmt_linenum : int )
142+ ~( acc : string Set.Poly.t ) Expr.Fixed. {pattern; _} : string Set.Poly.t =
122143 let query_expr (accum : string Set.Poly.t ) =
123- query_initial_demotable_expr in_loop ~acc: accum in
144+ query_initial_demotable_expr in_loop stmt_linenum ~acc: accum in
124145 match pattern with
125146 | FunApp (kind , (exprs : Expr.Typed.t list )) ->
126- query_initial_demotable_funs in_loop acc kind exprs
147+ query_initial_demotable_funs in_loop stmt_linenum acc kind exprs
127148 | Indexed ((Expr.Fixed. {meta = {type_; _} ; _} as expr ), indexed ) ->
128149 let index_set =
129150 Set.Poly. union_list
@@ -133,8 +154,12 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
133154 (query_expr acc))
134155 indexed) in
135156 let index_demotes =
136- if is_uni_eigen_loop_indexing in_loop type_ indexed then
137- Set. union (query_var_eigen_names expr) index_set
157+ if is_uni_eigen_loop_indexing in_loop type_ indexed then (
158+ let single_index_set = query_var_eigen_names expr in
159+ let failure_str = concat_set_str (Set. inter acc single_index_set) in
160+ let msg = " Accessed by element in a for loop: " in
161+ user_warning_op SoA stmt_linenum msg failure_str;
162+ Set. union single_index_set index_set)
138163 else Set. union (query_expr acc expr) index_set in
139164 Set. union acc index_demotes
140165 | Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
@@ -143,9 +168,16 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
143168 | TupleProjection (expr , _ ) -> query_expr acc expr
144169 | TernaryIf (predicate , texpr , fexpr ) ->
145170 let predicate_demotes = query_expr acc predicate in
146- Set. union
147- (Set. union predicate_demotes (query_var_eigen_names texpr))
148- (query_var_eigen_names fexpr)
171+ let full_set =
172+ Set. union
173+ (Set. union predicate_demotes (query_var_eigen_names texpr))
174+ (query_var_eigen_names fexpr) in
175+ if Set. is_empty full_set then full_set
176+ else
177+ let failure_str = concat_set_str (Set. inter acc full_set) in
178+ let msg = " Used in a ternary operator which is not allowed: " in
179+ user_warning_op SoA stmt_linenum msg failure_str;
180+ full_set
149181 | EAnd (lhs , rhs ) | EOr (lhs , rhs ) ->
150182 (* We need to get the demotes from both sides*)
151183 let full_lhs_rhs = Set. union (query_expr acc lhs) (query_expr acc rhs) in
@@ -168,9 +200,11 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
168200 to the UDF.
169201 exprs The expression list passed to the functions.
170202 *)
171- and query_initial_demotable_funs (in_loop : bool ) (acc : string Set.Poly.t )
172- (kind : 'a Fun_kind.t ) (exprs : Expr.Typed.t list ) : string Set.Poly.t =
173- let query_expr accum = query_initial_demotable_expr in_loop ~acc: accum in
203+ and query_initial_demotable_funs (in_loop : bool ) (stmt_linenum : int )
204+ (acc : string Set.Poly.t ) (kind : 'a Fun_kind.t ) (exprs : Expr.Typed.t list )
205+ : string Set.Poly.t =
206+ let query_expr accum =
207+ query_initial_demotable_expr in_loop stmt_linenum ~acc: accum in
174208 let top_level_eigen_names =
175209 Set.Poly. union_list (List. map ~f: query_var_eigen_names exprs) in
176210 let demoted_eigen_names = List. fold ~init: acc ~f: query_expr exprs in
@@ -183,11 +217,26 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
183217 | name -> (
184218 match is_fun_soa_supported name exprs with
185219 | true -> Set. union acc demoted_eigen_names
186- | false -> Set. union acc demoted_and_top_level_names))
220+ | false ->
221+ let fail_names =
222+ concat_set_str (Set. inter acc top_level_eigen_names) in
223+ user_warning_op SoA stmt_linenum
224+ (" Function " ^ name ^ " is not supported: " )
225+ fail_names;
226+ Set. union acc demoted_and_top_level_names))
187227 | CompilerInternal (Internal_fun. FnMakeArray | FnMakeRowVec | FnMakeTuple ) ->
228+ let fail_names =
229+ concat_set_str (Set. inter acc demoted_and_top_level_names) in
230+ user_warning_op SoA stmt_linenum
231+ " Used in {} make array or make row vector compiler functions: "
232+ fail_names;
188233 Set. union acc demoted_and_top_level_names
189234 | CompilerInternal (_ : 'a Internal_fun.t ) -> acc
190235 | UserDefined ((_ : string ), (_ : bool Fun_kind.suffix )) ->
236+ let fail_names =
237+ concat_set_str (Set. inter acc demoted_and_top_level_names) in
238+ user_warning_op SoA stmt_linenum " Used in user defined function:"
239+ fail_names;
191240 Set. union acc demoted_and_top_level_names
192241
193242(* *
@@ -285,9 +334,10 @@ let contains_at_least_one_ad_matrix_or_all_data
285334 [query_initial_demotable_expr] for an explanation of the logic.
286335 *)
287336let rec query_initial_demotable_stmt (in_loop : bool ) (acc : string Set.Poly.t )
288- (Stmt.Fixed. {pattern; _} : Stmt.Located.t ) : string Set.Poly.t =
337+ (Stmt.Fixed. {pattern; meta} : Stmt.Located.t ) : string Set.Poly.t =
338+ let linenum = meta.end_loc.line_num in
289339 let query_expr (accum : string Set.Poly.t ) =
290- query_initial_demotable_expr in_loop ~acc: accum in
340+ query_initial_demotable_expr in_loop linenum ~acc: accum in
291341 match pattern with
292342 | Stmt.Fixed.Pattern. Assignment
293343 ( lval
@@ -301,21 +351,25 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
301351 List. fold ~init: acc
302352 ~f: (fun accum x ->
303353 Index. folder accum
304- (fun acc -> query_initial_demotable_expr in_loop ~acc )
354+ (fun acc -> query_initial_demotable_expr in_loop linenum ~acc )
305355 x)
306356 idx in
307357 match is_uni_eigen_loop_indexing in_loop ut idx with
308- | true -> Set. add idx_list name
358+ | true ->
359+ user_warning_op SoA linenum " Accessed by element in a for loop: "
360+ (if Set. mem acc name then " " else name);
361+ Set. add idx_list name
309362 | false -> idx_list in
310363 let rhs_demotable_names = query_expr acc rhs in
311364 let rhs_and_idx_demotions = Set. union idx_demotable rhs_demotable_names in
312365 (* RHS (1)*)
313366 let tuple_demotions =
314367 match lval with
315368 | LTupleProjection _ , _ ->
316- Set. add
317- (Set. union rhs_and_idx_demotions (query_var_eigen_names rhs))
318- name
369+ let tuple_set = query_var_eigen_names rhs in
370+ let fail_set = concat_set_str tuple_set in
371+ user_warning_op SoA linenum " Used in tuple: " fail_set;
372+ Set. add (Set. union rhs_and_idx_demotions tuple_set) name
319373 | _ -> rhs_and_idx_demotions in
320374 let assign_demotions =
321375 let is_eigen_stmt = UnsizedType. contains_eigen_type rhs.meta.type_ in
@@ -346,13 +400,34 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
346400 if
347401 is_all_rhs_aos || is_rhs_not_promoteable_to_soa
348402 || is_not_supported_func
349- then
350- Set. add (Set. union tuple_demotions (query_var_eigen_names rhs)) name
403+ then (
404+ let rhs_set = query_var_eigen_names rhs in
405+ let all_rhs_warn =
406+ if is_all_rhs_aos then
407+ " Right hand side of assignment is all AoS: "
408+ else " " in
409+ let rhs_not_promotable_to_soa_warn =
410+ if is_rhs_not_promoteable_to_soa then
411+ " The right hand side of the assignment only contains data and \
412+ scalar operations that are not promotable to SoA: "
413+ else " " in
414+ let not_supported_func_warn =
415+ if is_not_supported_func then
416+ " Function on right hand side of assignment is not supported by \
417+ SoA: "
418+ else " " in
419+ let rhs_name_set = Set. add rhs_set name in
420+ let rhs_name_set_str = concat_set_str rhs_name_set in
421+ user_warning_op SoA linenum all_rhs_warn rhs_name_set_str;
422+ user_warning_op SoA linenum rhs_not_promotable_to_soa_warn
423+ rhs_name_set_str;
424+ user_warning_op SoA linenum not_supported_func_warn rhs_name_set_str;
425+ Set. add (Set. union tuple_demotions rhs_set) name)
351426 else tuple_demotions
352427 else tuple_demotions in
353428 Set. union acc assign_demotions
354429 | NRFunApp (kind , exprs ) ->
355- query_initial_demotable_funs in_loop acc kind exprs
430+ query_initial_demotable_funs in_loop linenum acc kind exprs
356431 | IfElse (predicate , true_stmt , op_false_stmt ) ->
357432 let predicate_acc = query_expr acc predicate in
358433 Set. union acc
@@ -410,24 +485,36 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
410485 @param pattern The Stmt pattern to query.
411486 *)
412487let query_demotable_stmt (aos_exits : string Set.Poly.t )
413- (pattern : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t ) : string Set.Poly.t =
414- match pattern with
488+ (stmt : Stmt.Located.Non_recursive.t ) : string Set.Poly.t =
489+ let linenum = stmt.meta.end_loc.line_num in
490+ match stmt.pattern with
415491 | Stmt.Fixed.Pattern. Assignment
416492 (lval, (_ : UnsizedType.t ), (rhs : Expr.Typed.t )) -> (
417493 let assign_name = Stmt.Helpers. lhs_variable lval in
418494 let all_rhs_eigen_names = query_var_eigen_names rhs in
419- if Set. mem aos_exits assign_name then
420- Set. add all_rhs_eigen_names assign_name
495+ if Set. mem aos_exits assign_name then (
496+ user_warning_op SoA linenum
497+ " Right hand side contains only AoS expressions: " assign_name;
498+ Set. add all_rhs_eigen_names assign_name)
421499 else
422500 match is_nonzero_subset ~set: aos_exits ~subset: all_rhs_eigen_names with
423- | true -> Set. add all_rhs_eigen_names assign_name
501+ | true ->
502+ user_warning_op SoA linenum
503+ " Right hand side contains only AoS expressions: " assign_name;
504+ Set. add all_rhs_eigen_names assign_name
424505 | false -> Set.Poly. empty)
425506 | Decl {decl_id; initialize = Assign e ; _} -> (
426507 let all_rhs_eigen_names = query_var_eigen_names e in
427- if Set. mem aos_exits decl_id then Set. add all_rhs_eigen_names decl_id
508+ if Set. mem aos_exits decl_id then (
509+ user_warning_op SoA linenum
510+ " Right hand side contains only AoS expressions: " decl_id;
511+ Set. add all_rhs_eigen_names decl_id)
428512 else
429513 match is_nonzero_subset ~set: aos_exits ~subset: all_rhs_eigen_names with
430- | true -> Set. add all_rhs_eigen_names decl_id
514+ | true ->
515+ user_warning_op SoA linenum
516+ " Right hand side contains only AoS expressions: " decl_id;
517+ Set. add all_rhs_eigen_names decl_id
431518 | false -> Set.Poly. empty)
432519 (* All other statements do not need logic here*)
433520 | _ -> Set.Poly. empty
0 commit comments