Skip to content

Commit e46c13f

Browse files
committed
Split out logging from #1353
1 parent a0ecb22 commit e46c13f

File tree

2 files changed

+119
-33
lines changed

2 files changed

+119
-33
lines changed

src/analysis_and_optimization/Memory_patterns.ml

Lines changed: 118 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,27 @@ open Core
22
open Core.Poly
33
open Middle
44

5+
let log_demotions = ref true
6+
7+
let user_warning (mem_pattern : Mem_pattern.t) (linenum : int) (msg : string) =
8+
if !log_demotions then
9+
let mem_name =
10+
match mem_pattern with Mem_pattern.SoA -> "SoA" | AoS -> "AoS" in
11+
Printf.eprintf "%s (Line: %i) warning: %s\n" mem_name linenum msg
12+
13+
let user_warning_op (mem_pattern : Mem_pattern.t) (linenum : int) (msg : string)
14+
(names : string) =
15+
if !log_demotions then
16+
let mem_name =
17+
match mem_pattern with Mem_pattern.SoA -> "SoA" | AoS -> "AoS" in
18+
if not (String.is_empty names || String.is_empty msg) then
19+
Printf.eprintf "%s (Line %i) warning: %s\n" mem_name linenum (msg ^ names)
20+
21+
let concat_set_str (set : string Set.Poly.t) =
22+
Set.fold
23+
~f:(fun acc elem -> if acc = "" then acc ^ elem else acc ^ ", " ^ elem)
24+
~init:"" set
25+
526
(**
627
Return a Var expression of the name for each type
728
containing an eigen matrix
@@ -98,7 +119,7 @@ let query_stan_math_mem_pattern_support (name : string)
98119
Frontend.SignatureMismatch.check_compatible_arguments_mod_conv x args
99120
|> Result.is_ok)
100121
namematches in
101-
let is_soa = function _, _, _, Mem_pattern.SoA -> true | _ -> false in
122+
let is_soa (_, _, _, p) = p = Mem_pattern.SoA in
102123
List.exists ~f:is_soa filteredmatches
103124

104125
(*Validate whether a function can support SoA matrices*)
@@ -116,13 +137,13 @@ let is_fun_soa_supported name exprs =
116137
will be returned if the matrix or vector is accessed by single
117138
cell indexing.
118139
*)
119-
let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
120-
Expr.{pattern; _} : string Set.Poly.t =
140+
let rec query_initial_demotable_expr (in_loop : bool) (stmt_linenum : int)
141+
~(acc : string Set.Poly.t) Expr.{pattern; _} : string Set.Poly.t =
121142
let query_expr (accum : string Set.Poly.t) =
122-
query_initial_demotable_expr in_loop ~acc:accum in
143+
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
123144
match pattern with
124145
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
125-
query_initial_demotable_funs in_loop acc kind exprs
146+
query_initial_demotable_funs in_loop stmt_linenum acc kind exprs
126147
| Indexed ((Expr.{meta= {type_; _}; _} as expr), indexed) ->
127148
let index_set =
128149
Set.Poly.union_list
@@ -132,18 +153,29 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
132153
(query_expr acc))
133154
indexed) in
134155
let index_demotes =
135-
if is_uni_eigen_loop_indexing in_loop type_ indexed then
136-
Set.union (query_var_eigen_names expr) index_set
156+
if is_uni_eigen_loop_indexing in_loop type_ indexed then (
157+
let single_index_set = query_var_eigen_names expr in
158+
let failure_str = concat_set_str (Set.inter acc single_index_set) in
159+
let msg = "Accessed by element in a for loop: " in
160+
user_warning_op SoA stmt_linenum msg failure_str;
161+
Set.union single_index_set index_set)
137162
else Set.union (query_expr acc expr) index_set in
138163
Set.union acc index_demotes
139164
| Var (_ : string) | Lit ((_ : Expr.Pattern.litType), (_ : string)) -> acc
140165
| Promotion (expr, _, _) -> query_expr acc expr
141166
| TupleProjection (expr, _) -> query_expr acc expr
142167
| TernaryIf (predicate, texpr, fexpr) ->
143168
let predicate_demotes = query_expr acc predicate in
144-
Set.union
145-
(Set.union predicate_demotes (query_var_eigen_names texpr))
146-
(query_var_eigen_names fexpr)
169+
let full_set =
170+
Set.union
171+
(Set.union predicate_demotes (query_var_eigen_names texpr))
172+
(query_var_eigen_names fexpr) in
173+
if Set.is_empty full_set then full_set
174+
else
175+
let failure_str = concat_set_str (Set.inter acc full_set) in
176+
let msg = "Used in a ternary operator which is not allowed: " in
177+
user_warning_op SoA stmt_linenum msg failure_str;
178+
full_set
147179
| EAnd (lhs, rhs) | EOr (lhs, rhs) ->
148180
(*We need to get the demotes from both sides*)
149181
let full_lhs_rhs = Set.union (query_expr acc lhs) (query_expr acc rhs) in
@@ -166,9 +198,11 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
166198
to the UDF.
167199
exprs The expression list passed to the functions.
168200
*)
169-
and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
170-
(kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t =
171-
let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in
201+
and query_initial_demotable_funs (in_loop : bool) (stmt_linenum : int)
202+
(acc : string Set.Poly.t) (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list)
203+
: string Set.Poly.t =
204+
let query_expr accum =
205+
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
172206
let top_level_eigen_names =
173207
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
174208
let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in
@@ -181,11 +215,26 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
181215
| name -> (
182216
match is_fun_soa_supported name exprs with
183217
| true -> Set.union acc demoted_eigen_names
184-
| false -> Set.union acc demoted_and_top_level_names))
218+
| false ->
219+
let fail_names =
220+
concat_set_str (Set.inter acc top_level_eigen_names) in
221+
user_warning_op SoA stmt_linenum
222+
("Function " ^ name ^ " is not supported: ")
223+
fail_names;
224+
Set.union acc demoted_and_top_level_names))
185225
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec | FnMakeTuple) ->
226+
let fail_names =
227+
concat_set_str (Set.inter acc demoted_and_top_level_names) in
228+
user_warning_op SoA stmt_linenum
229+
"Used in {} make array or make row vector compiler functions: "
230+
fail_names;
186231
Set.union acc demoted_and_top_level_names
187232
| CompilerInternal (_ : 'a Internal_fun.t) -> acc
188233
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
234+
let fail_names =
235+
concat_set_str (Set.inter acc demoted_and_top_level_names) in
236+
user_warning_op SoA stmt_linenum "Used in user defined function:"
237+
fail_names;
189238
Set.union acc demoted_and_top_level_names
190239

191240
(**
@@ -283,9 +332,10 @@ let contains_at_least_one_ad_matrix_or_all_data
283332
[query_initial_demotable_expr] for an explanation of the logic.
284333
*)
285334
let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
286-
(Stmt.{pattern; _} : Stmt.Located.t) : string Set.Poly.t =
335+
(Stmt.{pattern; meta} : Stmt.Located.t) : string Set.Poly.t =
336+
let linenum = meta.end_loc.line_num in
287337
let query_expr (accum : string Set.Poly.t) =
288-
query_initial_demotable_expr in_loop ~acc:accum in
338+
query_initial_demotable_expr in_loop linenum ~acc:accum in
289339
match pattern with
290340
| Stmt.Pattern.Assignment
291341
( lval
@@ -299,21 +349,25 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
299349
List.fold ~init:acc
300350
~f:(fun accum x ->
301351
Index.folder accum
302-
(fun acc -> query_initial_demotable_expr in_loop ~acc)
352+
(fun acc -> query_initial_demotable_expr in_loop linenum ~acc)
303353
x)
304354
idx in
305355
match is_uni_eigen_loop_indexing in_loop ut idx with
306-
| true -> Set.add idx_list name
356+
| true ->
357+
user_warning_op SoA linenum "Accessed by element in a for loop: "
358+
(if Set.mem acc name then "" else name);
359+
Set.add idx_list name
307360
| false -> idx_list in
308361
let rhs_demotable_names = query_expr acc rhs in
309362
let rhs_and_idx_demotions = Set.union idx_demotable rhs_demotable_names in
310363
(* RHS (1)*)
311364
let tuple_demotions =
312365
match lval with
313366
| LTupleProjection _, _ ->
314-
Set.add
315-
(Set.union rhs_and_idx_demotions (query_var_eigen_names rhs))
316-
name
367+
let tuple_set = query_var_eigen_names rhs in
368+
let fail_set = concat_set_str tuple_set in
369+
user_warning_op SoA linenum "Used in tuple: " fail_set;
370+
Set.add (Set.union rhs_and_idx_demotions tuple_set) name
317371
| _ -> rhs_and_idx_demotions in
318372
let assign_demotions =
319373
let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in
@@ -344,13 +398,34 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
344398
if
345399
is_all_rhs_aos || is_rhs_not_promoteable_to_soa
346400
|| is_not_supported_func
347-
then
348-
Set.add (Set.union tuple_demotions (query_var_eigen_names rhs)) name
401+
then (
402+
let rhs_set = query_var_eigen_names rhs in
403+
let all_rhs_warn =
404+
if is_all_rhs_aos then
405+
"Right hand side of assignment is all AoS: "
406+
else "" in
407+
let rhs_not_promotable_to_soa_warn =
408+
if is_rhs_not_promoteable_to_soa then
409+
"The right hand side of the assignment only contains data and \
410+
scalar operations that are not promotable to SoA: "
411+
else "" in
412+
let not_supported_func_warn =
413+
if is_not_supported_func then
414+
"Function on right hand side of assignment is not supported by \
415+
SoA: "
416+
else "" in
417+
let rhs_name_set = Set.add rhs_set name in
418+
let rhs_name_set_str = concat_set_str rhs_name_set in
419+
user_warning_op SoA linenum all_rhs_warn rhs_name_set_str;
420+
user_warning_op SoA linenum rhs_not_promotable_to_soa_warn
421+
rhs_name_set_str;
422+
user_warning_op SoA linenum not_supported_func_warn rhs_name_set_str;
423+
Set.add (Set.union tuple_demotions rhs_set) name)
349424
else tuple_demotions
350425
else tuple_demotions in
351426
Set.union acc assign_demotions
352427
| NRFunApp (kind, exprs) ->
353-
query_initial_demotable_funs in_loop acc kind exprs
428+
query_initial_demotable_funs in_loop linenum acc kind exprs
354429
| IfElse (predicate, true_stmt, op_false_stmt) ->
355430
let predicate_acc = query_expr acc predicate in
356431
Set.union acc
@@ -408,24 +483,36 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
408483
@param pattern The Stmt pattern to query.
409484
*)
410485
let query_demotable_stmt (aos_exits : string Set.Poly.t)
411-
(pattern : (Expr.Typed.t, int) Stmt.Pattern.t) : string Set.Poly.t =
412-
match pattern with
486+
(stmt : Stmt.Located.Non_recursive.t) : string Set.Poly.t =
487+
let linenum = stmt.meta.end_loc.line_num in
488+
match stmt.pattern with
413489
| Stmt.Pattern.Assignment (lval, (_ : UnsizedType.t), (rhs : Expr.Typed.t))
414490
-> (
415491
let assign_name = Stmt.Helpers.lhs_variable lval in
416492
let all_rhs_eigen_names = query_var_eigen_names rhs in
417-
if Set.mem aos_exits assign_name then
418-
Set.add all_rhs_eigen_names assign_name
493+
if Set.mem aos_exits assign_name then (
494+
user_warning_op SoA linenum
495+
"Right hand side contains only AoS expressions: " assign_name;
496+
Set.add all_rhs_eigen_names assign_name)
419497
else
420498
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
421-
| true -> Set.add all_rhs_eigen_names assign_name
499+
| true ->
500+
user_warning_op SoA linenum
501+
"Right hand side contains only AoS expressions: " assign_name;
502+
Set.add all_rhs_eigen_names assign_name
422503
| false -> Set.Poly.empty)
423504
| Decl {decl_id; initialize= Assign e; _} -> (
424505
let all_rhs_eigen_names = query_var_eigen_names e in
425-
if Set.mem aos_exits decl_id then Set.add all_rhs_eigen_names decl_id
506+
if Set.mem aos_exits decl_id then (
507+
user_warning_op SoA linenum
508+
"Right hand side contains only AoS expressions: " decl_id;
509+
Set.add all_rhs_eigen_names decl_id)
426510
else
427511
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
428-
| true -> Set.add all_rhs_eigen_names decl_id
512+
| true ->
513+
user_warning_op SoA linenum
514+
"Right hand side contains only AoS expressions: " decl_id;
515+
Set.add all_rhs_eigen_names decl_id
429516
| false -> Set.Poly.empty)
430517
(* All other statements do not need logic here*)
431518
| _ -> Set.Poly.empty

src/analysis_and_optimization/Optimize.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,8 +1228,7 @@ let optimize_soa (mir : Program.Typed.t) =
12281228
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
12291229
(l : int) (aos_variables : string Set.Poly.t) =
12301230
let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in
1231-
match (mir_node l).pattern with
1232-
| stmt -> Memory_patterns.query_demotable_stmt aos_variables stmt in
1231+
Memory_patterns.query_demotable_stmt aos_variables (mir_node l) in
12331232
let initial_variables =
12341233
List.fold ~init:Set.Poly.empty
12351234
~f:(Memory_patterns.query_initial_demotable_stmt false)

0 commit comments

Comments
 (0)