@@ -2,6 +2,27 @@ open Core
22open Core.Poly
33open Middle
44
5+ let log_demotions = ref true
6+
7+ let user_warning (mem_pattern : Mem_pattern.t ) (linenum : int ) (msg : string ) =
8+ if ! log_demotions then
9+ let mem_name =
10+ match mem_pattern with Mem_pattern. SoA -> " SoA" | AoS -> " AoS" in
11+ Printf. eprintf " %s (Line: %i) warning: %s\n " mem_name linenum msg
12+
13+ let user_warning_op (mem_pattern : Mem_pattern.t ) (linenum : int ) (msg : string )
14+ (names : string ) =
15+ if ! log_demotions then
16+ let mem_name =
17+ match mem_pattern with Mem_pattern. SoA -> " SoA" | AoS -> " AoS" in
18+ if not (String. is_empty names || String. is_empty msg) then
19+ Printf. eprintf " %s (Line %i) warning: %s\n " mem_name linenum (msg ^ names)
20+
21+ let concat_set_str (set : string Set.Poly.t ) =
22+ Set. fold
23+ ~f: (fun acc elem -> if acc = " " then acc ^ elem else acc ^ " , " ^ elem)
24+ ~init: " " set
25+
526(* *
627 Return a Var expression of the name for each type
728 containing an eigen matrix
@@ -98,7 +119,7 @@ let query_stan_math_mem_pattern_support (name : string)
98119 Frontend.SignatureMismatch. check_compatible_arguments_mod_conv x args
99120 |> Result. is_ok)
100121 namematches in
101- let is_soa = function _ , _ , _ , Mem_pattern. SoA -> true | _ -> false in
122+ let is_soa ( _ , _ , _ , p ) = p = Mem_pattern. SoA in
102123 List. exists ~f: is_soa filteredmatches
103124
104125(* Validate whether a function can support SoA matrices*)
@@ -116,13 +137,13 @@ let is_fun_soa_supported name exprs =
116137 will be returned if the matrix or vector is accessed by single
117138 cell indexing.
118139 *)
119- let rec query_initial_demotable_expr (in_loop : bool ) ~( acc : string Set.Poly.t )
120- Expr. {pattern; _} : string Set.Poly.t =
140+ let rec query_initial_demotable_expr (in_loop : bool ) ( stmt_linenum : int )
141+ ~( acc : string Set.Poly.t ) Expr. {pattern; _} : string Set.Poly.t =
121142 let query_expr (accum : string Set.Poly.t ) =
122- query_initial_demotable_expr in_loop ~acc: accum in
143+ query_initial_demotable_expr in_loop stmt_linenum ~acc: accum in
123144 match pattern with
124145 | FunApp (kind , (exprs : Expr.Typed.t list )) ->
125- query_initial_demotable_funs in_loop acc kind exprs
146+ query_initial_demotable_funs in_loop stmt_linenum acc kind exprs
126147 | Indexed ((Expr. {meta = {type_; _} ; _} as expr ), indexed ) ->
127148 let index_set =
128149 Set.Poly. union_list
@@ -132,18 +153,29 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
132153 (query_expr acc))
133154 indexed) in
134155 let index_demotes =
135- if is_uni_eigen_loop_indexing in_loop type_ indexed then
136- Set. union (query_var_eigen_names expr) index_set
156+ if is_uni_eigen_loop_indexing in_loop type_ indexed then (
157+ let single_index_set = query_var_eigen_names expr in
158+ let failure_str = concat_set_str (Set. inter acc single_index_set) in
159+ let msg = " Accessed by element in a for loop: " in
160+ user_warning_op SoA stmt_linenum msg failure_str;
161+ Set. union single_index_set index_set)
137162 else Set. union (query_expr acc expr) index_set in
138163 Set. union acc index_demotes
139164 | Var (_ : string ) | Lit ((_ : Expr.Pattern.litType ), (_ : string )) -> acc
140165 | Promotion (expr , _ , _ ) -> query_expr acc expr
141166 | TupleProjection (expr , _ ) -> query_expr acc expr
142167 | TernaryIf (predicate , texpr , fexpr ) ->
143168 let predicate_demotes = query_expr acc predicate in
144- Set. union
145- (Set. union predicate_demotes (query_var_eigen_names texpr))
146- (query_var_eigen_names fexpr)
169+ let full_set =
170+ Set. union
171+ (Set. union predicate_demotes (query_var_eigen_names texpr))
172+ (query_var_eigen_names fexpr) in
173+ if Set. is_empty full_set then full_set
174+ else
175+ let failure_str = concat_set_str (Set. inter acc full_set) in
176+ let msg = " Used in a ternary operator which is not allowed: " in
177+ user_warning_op SoA stmt_linenum msg failure_str;
178+ full_set
147179 | EAnd (lhs , rhs ) | EOr (lhs , rhs ) ->
148180 (* We need to get the demotes from both sides*)
149181 let full_lhs_rhs = Set. union (query_expr acc lhs) (query_expr acc rhs) in
@@ -166,9 +198,11 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
166198 to the UDF.
167199 exprs The expression list passed to the functions.
168200 *)
169- and query_initial_demotable_funs (in_loop : bool ) (acc : string Set.Poly.t )
170- (kind : 'a Fun_kind.t ) (exprs : Expr.Typed.t list ) : string Set.Poly.t =
171- let query_expr accum = query_initial_demotable_expr in_loop ~acc: accum in
201+ and query_initial_demotable_funs (in_loop : bool ) (stmt_linenum : int )
202+ (acc : string Set.Poly.t ) (kind : 'a Fun_kind.t ) (exprs : Expr.Typed.t list )
203+ : string Set.Poly.t =
204+ let query_expr accum =
205+ query_initial_demotable_expr in_loop stmt_linenum ~acc: accum in
172206 let top_level_eigen_names =
173207 Set.Poly. union_list (List. map ~f: query_var_eigen_names exprs) in
174208 let demoted_eigen_names = List. fold ~init: acc ~f: query_expr exprs in
@@ -181,11 +215,26 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
181215 | name -> (
182216 match is_fun_soa_supported name exprs with
183217 | true -> Set. union acc demoted_eigen_names
184- | false -> Set. union acc demoted_and_top_level_names))
218+ | false ->
219+ let fail_names =
220+ concat_set_str (Set. inter acc top_level_eigen_names) in
221+ user_warning_op SoA stmt_linenum
222+ (" Function " ^ name ^ " is not supported: " )
223+ fail_names;
224+ Set. union acc demoted_and_top_level_names))
185225 | CompilerInternal (Internal_fun. FnMakeArray | FnMakeRowVec | FnMakeTuple ) ->
226+ let fail_names =
227+ concat_set_str (Set. inter acc demoted_and_top_level_names) in
228+ user_warning_op SoA stmt_linenum
229+ " Used in {} make array or make row vector compiler functions: "
230+ fail_names;
186231 Set. union acc demoted_and_top_level_names
187232 | CompilerInternal (_ : 'a Internal_fun.t ) -> acc
188233 | UserDefined ((_ : string ), (_ : bool Fun_kind.suffix )) ->
234+ let fail_names =
235+ concat_set_str (Set. inter acc demoted_and_top_level_names) in
236+ user_warning_op SoA stmt_linenum " Used in user defined function:"
237+ fail_names;
189238 Set. union acc demoted_and_top_level_names
190239
191240(* *
@@ -283,9 +332,10 @@ let contains_at_least_one_ad_matrix_or_all_data
283332 [query_initial_demotable_expr] for an explanation of the logic.
284333 *)
285334let rec query_initial_demotable_stmt (in_loop : bool ) (acc : string Set.Poly.t )
286- (Stmt. {pattern; _} : Stmt.Located.t ) : string Set.Poly.t =
335+ (Stmt. {pattern; meta} : Stmt.Located.t ) : string Set.Poly.t =
336+ let linenum = meta.end_loc.line_num in
287337 let query_expr (accum : string Set.Poly.t ) =
288- query_initial_demotable_expr in_loop ~acc: accum in
338+ query_initial_demotable_expr in_loop linenum ~acc: accum in
289339 match pattern with
290340 | Stmt.Pattern. Assignment
291341 ( lval
@@ -299,21 +349,25 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
299349 List. fold ~init: acc
300350 ~f: (fun accum x ->
301351 Index. folder accum
302- (fun acc -> query_initial_demotable_expr in_loop ~acc )
352+ (fun acc -> query_initial_demotable_expr in_loop linenum ~acc )
303353 x)
304354 idx in
305355 match is_uni_eigen_loop_indexing in_loop ut idx with
306- | true -> Set. add idx_list name
356+ | true ->
357+ user_warning_op SoA linenum " Accessed by element in a for loop: "
358+ (if Set. mem acc name then " " else name);
359+ Set. add idx_list name
307360 | false -> idx_list in
308361 let rhs_demotable_names = query_expr acc rhs in
309362 let rhs_and_idx_demotions = Set. union idx_demotable rhs_demotable_names in
310363 (* RHS (1)*)
311364 let tuple_demotions =
312365 match lval with
313366 | LTupleProjection _ , _ ->
314- Set. add
315- (Set. union rhs_and_idx_demotions (query_var_eigen_names rhs))
316- name
367+ let tuple_set = query_var_eigen_names rhs in
368+ let fail_set = concat_set_str tuple_set in
369+ user_warning_op SoA linenum " Used in tuple: " fail_set;
370+ Set. add (Set. union rhs_and_idx_demotions tuple_set) name
317371 | _ -> rhs_and_idx_demotions in
318372 let assign_demotions =
319373 let is_eigen_stmt = UnsizedType. contains_eigen_type rhs.meta.type_ in
@@ -344,13 +398,34 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
344398 if
345399 is_all_rhs_aos || is_rhs_not_promoteable_to_soa
346400 || is_not_supported_func
347- then
348- Set. add (Set. union tuple_demotions (query_var_eigen_names rhs)) name
401+ then (
402+ let rhs_set = query_var_eigen_names rhs in
403+ let all_rhs_warn =
404+ if is_all_rhs_aos then
405+ " Right hand side of assignment is all AoS: "
406+ else " " in
407+ let rhs_not_promotable_to_soa_warn =
408+ if is_rhs_not_promoteable_to_soa then
409+ " The right hand side of the assignment only contains data and \
410+ scalar operations that are not promotable to SoA: "
411+ else " " in
412+ let not_supported_func_warn =
413+ if is_not_supported_func then
414+ " Function on right hand side of assignment is not supported by \
415+ SoA: "
416+ else " " in
417+ let rhs_name_set = Set. add rhs_set name in
418+ let rhs_name_set_str = concat_set_str rhs_name_set in
419+ user_warning_op SoA linenum all_rhs_warn rhs_name_set_str;
420+ user_warning_op SoA linenum rhs_not_promotable_to_soa_warn
421+ rhs_name_set_str;
422+ user_warning_op SoA linenum not_supported_func_warn rhs_name_set_str;
423+ Set. add (Set. union tuple_demotions rhs_set) name)
349424 else tuple_demotions
350425 else tuple_demotions in
351426 Set. union acc assign_demotions
352427 | NRFunApp (kind , exprs ) ->
353- query_initial_demotable_funs in_loop acc kind exprs
428+ query_initial_demotable_funs in_loop linenum acc kind exprs
354429 | IfElse (predicate , true_stmt , op_false_stmt ) ->
355430 let predicate_acc = query_expr acc predicate in
356431 Set. union acc
@@ -408,24 +483,36 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
408483 @param pattern The Stmt pattern to query.
409484 *)
410485let query_demotable_stmt (aos_exits : string Set.Poly.t )
411- (pattern : (Expr.Typed.t, int) Stmt.Pattern.t ) : string Set.Poly.t =
412- match pattern with
486+ (stmt : Stmt.Located.Non_recursive.t ) : string Set.Poly.t =
487+ let linenum = stmt.meta.end_loc.line_num in
488+ match stmt.pattern with
413489 | Stmt.Pattern. Assignment (lval, (_ : UnsizedType.t ), (rhs : Expr.Typed.t ))
414490 -> (
415491 let assign_name = Stmt.Helpers. lhs_variable lval in
416492 let all_rhs_eigen_names = query_var_eigen_names rhs in
417- if Set. mem aos_exits assign_name then
418- Set. add all_rhs_eigen_names assign_name
493+ if Set. mem aos_exits assign_name then (
494+ user_warning_op SoA linenum
495+ " Right hand side contains only AoS expressions: " assign_name;
496+ Set. add all_rhs_eigen_names assign_name)
419497 else
420498 match is_nonzero_subset ~set: aos_exits ~subset: all_rhs_eigen_names with
421- | true -> Set. add all_rhs_eigen_names assign_name
499+ | true ->
500+ user_warning_op SoA linenum
501+ " Right hand side contains only AoS expressions: " assign_name;
502+ Set. add all_rhs_eigen_names assign_name
422503 | false -> Set.Poly. empty)
423504 | Decl {decl_id; initialize = Assign e ; _} -> (
424505 let all_rhs_eigen_names = query_var_eigen_names e in
425- if Set. mem aos_exits decl_id then Set. add all_rhs_eigen_names decl_id
506+ if Set. mem aos_exits decl_id then (
507+ user_warning_op SoA linenum
508+ " Right hand side contains only AoS expressions: " decl_id;
509+ Set. add all_rhs_eigen_names decl_id)
426510 else
427511 match is_nonzero_subset ~set: aos_exits ~subset: all_rhs_eigen_names with
428- | true -> Set. add all_rhs_eigen_names decl_id
512+ | true ->
513+ user_warning_op SoA linenum
514+ " Right hand side contains only AoS expressions: " decl_id;
515+ Set. add all_rhs_eigen_names decl_id
429516 | false -> Set.Poly. empty)
430517 (* All other statements do not need logic here*)
431518 | _ -> Set.Poly. empty
0 commit comments