Skip to content

Commit 98cab57

Browse files
committed
Split out logging from #1353
1 parent 216cb27 commit 98cab57

File tree

2 files changed

+119
-33
lines changed

2 files changed

+119
-33
lines changed

src/analysis_and_optimization/Memory_patterns.ml

Lines changed: 118 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,27 @@ open Core
22
open Core.Poly
33
open Middle
44

5+
let log_demotions = ref true
6+
7+
let user_warning (mem_pattern : Mem_pattern.t) (linenum : int) (msg : string) =
8+
if !log_demotions then
9+
let mem_name =
10+
match mem_pattern with Mem_pattern.SoA -> "SoA" | AoS -> "AoS" in
11+
Printf.eprintf "%s (Line: %i) warning: %s\n" mem_name linenum msg
12+
13+
let user_warning_op (mem_pattern : Mem_pattern.t) (linenum : int) (msg : string)
14+
(names : string) =
15+
if !log_demotions then
16+
let mem_name =
17+
match mem_pattern with Mem_pattern.SoA -> "SoA" | AoS -> "AoS" in
18+
if not (String.is_empty names || String.is_empty msg) then
19+
Printf.eprintf "%s (Line %i) warning: %s\n" mem_name linenum (msg ^ names)
20+
21+
let concat_set_str (set : string Set.Poly.t) =
22+
Set.fold
23+
~f:(fun acc elem -> if acc = "" then acc ^ elem else acc ^ ", " ^ elem)
24+
~init:"" set
25+
526
(**
627
Return a Var expression of the name for each type
728
containing an eigen matrix
@@ -99,7 +120,7 @@ let query_stan_math_mem_pattern_support (name : string)
99120
Frontend.SignatureMismatch.check_compatible_arguments_mod_conv x args
100121
|> Result.is_ok)
101122
namematches in
102-
let is_soa = function _, _, _, Mem_pattern.SoA -> true | _ -> false in
123+
let is_soa (_, _, _, p) = p = Mem_pattern.SoA in
103124
List.exists ~f:is_soa filteredmatches
104125

105126
(*Validate whether a function can support SoA matrices*)
@@ -117,13 +138,13 @@ let is_fun_soa_supported name exprs =
117138
will be returned if the matrix or vector is accessed by single
118139
cell indexing.
119140
*)
120-
let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
121-
Expr.Fixed.{pattern; _} : string Set.Poly.t =
141+
let rec query_initial_demotable_expr (in_loop : bool) (stmt_linenum : int)
142+
~(acc : string Set.Poly.t) Expr.Fixed.{pattern; _} : string Set.Poly.t =
122143
let query_expr (accum : string Set.Poly.t) =
123-
query_initial_demotable_expr in_loop ~acc:accum in
144+
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
124145
match pattern with
125146
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
126-
query_initial_demotable_funs in_loop acc kind exprs
147+
query_initial_demotable_funs in_loop stmt_linenum acc kind exprs
127148
| Indexed ((Expr.Fixed.{meta= {type_; _}; _} as expr), indexed) ->
128149
let index_set =
129150
Set.Poly.union_list
@@ -133,8 +154,12 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
133154
(query_expr acc))
134155
indexed) in
135156
let index_demotes =
136-
if is_uni_eigen_loop_indexing in_loop type_ indexed then
137-
Set.union (query_var_eigen_names expr) index_set
157+
if is_uni_eigen_loop_indexing in_loop type_ indexed then (
158+
let single_index_set = query_var_eigen_names expr in
159+
let failure_str = concat_set_str (Set.inter acc single_index_set) in
160+
let msg = "Accessed by element in a for loop: " in
161+
user_warning_op SoA stmt_linenum msg failure_str;
162+
Set.union single_index_set index_set)
138163
else Set.union (query_expr acc expr) index_set in
139164
Set.union acc index_demotes
140165
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
@@ -143,9 +168,16 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
143168
| TupleProjection (expr, _) -> query_expr acc expr
144169
| TernaryIf (predicate, texpr, fexpr) ->
145170
let predicate_demotes = query_expr acc predicate in
146-
Set.union
147-
(Set.union predicate_demotes (query_var_eigen_names texpr))
148-
(query_var_eigen_names fexpr)
171+
let full_set =
172+
Set.union
173+
(Set.union predicate_demotes (query_var_eigen_names texpr))
174+
(query_var_eigen_names fexpr) in
175+
if Set.is_empty full_set then full_set
176+
else
177+
let failure_str = concat_set_str (Set.inter acc full_set) in
178+
let msg = "Used in a ternary operator which is not allowed: " in
179+
user_warning_op SoA stmt_linenum msg failure_str;
180+
full_set
149181
| EAnd (lhs, rhs) | EOr (lhs, rhs) ->
150182
(*We need to get the demotes from both sides*)
151183
let full_lhs_rhs = Set.union (query_expr acc lhs) (query_expr acc rhs) in
@@ -168,9 +200,11 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
168200
to the UDF.
169201
exprs The expression list passed to the functions.
170202
*)
171-
and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
172-
(kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t =
173-
let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in
203+
and query_initial_demotable_funs (in_loop : bool) (stmt_linenum : int)
204+
(acc : string Set.Poly.t) (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list)
205+
: string Set.Poly.t =
206+
let query_expr accum =
207+
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
174208
let top_level_eigen_names =
175209
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
176210
let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in
@@ -183,11 +217,26 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
183217
| name -> (
184218
match is_fun_soa_supported name exprs with
185219
| true -> Set.union acc demoted_eigen_names
186-
| false -> Set.union acc demoted_and_top_level_names))
220+
| false ->
221+
let fail_names =
222+
concat_set_str (Set.inter acc top_level_eigen_names) in
223+
user_warning_op SoA stmt_linenum
224+
("Function " ^ name ^ " is not supported: ")
225+
fail_names;
226+
Set.union acc demoted_and_top_level_names))
187227
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec | FnMakeTuple) ->
228+
let fail_names =
229+
concat_set_str (Set.inter acc demoted_and_top_level_names) in
230+
user_warning_op SoA stmt_linenum
231+
"Used in {} make array or make row vector compiler functions: "
232+
fail_names;
188233
Set.union acc demoted_and_top_level_names
189234
| CompilerInternal (_ : 'a Internal_fun.t) -> acc
190235
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
236+
let fail_names =
237+
concat_set_str (Set.inter acc demoted_and_top_level_names) in
238+
user_warning_op SoA stmt_linenum "Used in user defined function:"
239+
fail_names;
191240
Set.union acc demoted_and_top_level_names
192241

193242
(**
@@ -285,9 +334,10 @@ let contains_at_least_one_ad_matrix_or_all_data
285334
[query_initial_demotable_expr] for an explanation of the logic.
286335
*)
287336
let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
288-
(Stmt.Fixed.{pattern; _} : Stmt.Located.t) : string Set.Poly.t =
337+
(Stmt.Fixed.{pattern; meta} : Stmt.Located.t) : string Set.Poly.t =
338+
let linenum = meta.end_loc.line_num in
289339
let query_expr (accum : string Set.Poly.t) =
290-
query_initial_demotable_expr in_loop ~acc:accum in
340+
query_initial_demotable_expr in_loop linenum ~acc:accum in
291341
match pattern with
292342
| Stmt.Fixed.Pattern.Assignment
293343
( lval
@@ -301,21 +351,25 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
301351
List.fold ~init:acc
302352
~f:(fun accum x ->
303353
Index.folder accum
304-
(fun acc -> query_initial_demotable_expr in_loop ~acc)
354+
(fun acc -> query_initial_demotable_expr in_loop linenum ~acc)
305355
x)
306356
idx in
307357
match is_uni_eigen_loop_indexing in_loop ut idx with
308-
| true -> Set.add idx_list name
358+
| true ->
359+
user_warning_op SoA linenum "Accessed by element in a for loop: "
360+
(if Set.mem acc name then "" else name);
361+
Set.add idx_list name
309362
| false -> idx_list in
310363
let rhs_demotable_names = query_expr acc rhs in
311364
let rhs_and_idx_demotions = Set.union idx_demotable rhs_demotable_names in
312365
(* RHS (1)*)
313366
let tuple_demotions =
314367
match lval with
315368
| LTupleProjection _, _ ->
316-
Set.add
317-
(Set.union rhs_and_idx_demotions (query_var_eigen_names rhs))
318-
name
369+
let tuple_set = query_var_eigen_names rhs in
370+
let fail_set = concat_set_str tuple_set in
371+
user_warning_op SoA linenum "Used in tuple: " fail_set;
372+
Set.add (Set.union rhs_and_idx_demotions tuple_set) name
319373
| _ -> rhs_and_idx_demotions in
320374
let assign_demotions =
321375
let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in
@@ -346,13 +400,34 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
346400
if
347401
is_all_rhs_aos || is_rhs_not_promoteable_to_soa
348402
|| is_not_supported_func
349-
then
350-
Set.add (Set.union tuple_demotions (query_var_eigen_names rhs)) name
403+
then (
404+
let rhs_set = query_var_eigen_names rhs in
405+
let all_rhs_warn =
406+
if is_all_rhs_aos then
407+
"Right hand side of assignment is all AoS: "
408+
else "" in
409+
let rhs_not_promotable_to_soa_warn =
410+
if is_rhs_not_promoteable_to_soa then
411+
"The right hand side of the assignment only contains data and \
412+
scalar operations that are not promotable to SoA: "
413+
else "" in
414+
let not_supported_func_warn =
415+
if is_not_supported_func then
416+
"Function on right hand side of assignment is not supported by \
417+
SoA: "
418+
else "" in
419+
let rhs_name_set = Set.add rhs_set name in
420+
let rhs_name_set_str = concat_set_str rhs_name_set in
421+
user_warning_op SoA linenum all_rhs_warn rhs_name_set_str;
422+
user_warning_op SoA linenum rhs_not_promotable_to_soa_warn
423+
rhs_name_set_str;
424+
user_warning_op SoA linenum not_supported_func_warn rhs_name_set_str;
425+
Set.add (Set.union tuple_demotions rhs_set) name)
351426
else tuple_demotions
352427
else tuple_demotions in
353428
Set.union acc assign_demotions
354429
| NRFunApp (kind, exprs) ->
355-
query_initial_demotable_funs in_loop acc kind exprs
430+
query_initial_demotable_funs in_loop linenum acc kind exprs
356431
| IfElse (predicate, true_stmt, op_false_stmt) ->
357432
let predicate_acc = query_expr acc predicate in
358433
Set.union acc
@@ -410,24 +485,36 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
410485
@param pattern The Stmt pattern to query.
411486
*)
412487
let query_demotable_stmt (aos_exits : string Set.Poly.t)
413-
(pattern : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) : string Set.Poly.t =
414-
match pattern with
488+
(stmt : Stmt.Located.Non_recursive.t) : string Set.Poly.t =
489+
let linenum = stmt.meta.end_loc.line_num in
490+
match stmt.pattern with
415491
| Stmt.Fixed.Pattern.Assignment
416492
(lval, (_ : UnsizedType.t), (rhs : Expr.Typed.t)) -> (
417493
let assign_name = Stmt.Helpers.lhs_variable lval in
418494
let all_rhs_eigen_names = query_var_eigen_names rhs in
419-
if Set.mem aos_exits assign_name then
420-
Set.add all_rhs_eigen_names assign_name
495+
if Set.mem aos_exits assign_name then (
496+
user_warning_op SoA linenum
497+
"Right hand side contains only AoS expressions: " assign_name;
498+
Set.add all_rhs_eigen_names assign_name)
421499
else
422500
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
423-
| true -> Set.add all_rhs_eigen_names assign_name
501+
| true ->
502+
user_warning_op SoA linenum
503+
"Right hand side contains only AoS expressions: " assign_name;
504+
Set.add all_rhs_eigen_names assign_name
424505
| false -> Set.Poly.empty)
425506
| Decl {decl_id; initialize= Assign e; _} -> (
426507
let all_rhs_eigen_names = query_var_eigen_names e in
427-
if Set.mem aos_exits decl_id then Set.add all_rhs_eigen_names decl_id
508+
if Set.mem aos_exits decl_id then (
509+
user_warning_op SoA linenum
510+
"Right hand side contains only AoS expressions: " decl_id;
511+
Set.add all_rhs_eigen_names decl_id)
428512
else
429513
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
430-
| true -> Set.add all_rhs_eigen_names decl_id
514+
| true ->
515+
user_warning_op SoA linenum
516+
"Right hand side contains only AoS expressions: " decl_id;
517+
Set.add all_rhs_eigen_names decl_id
431518
| false -> Set.Poly.empty)
432519
(* All other statements do not need logic here*)
433520
| _ -> Set.Poly.empty

src/analysis_and_optimization/Optimize.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,8 +1240,7 @@ let optimize_soa (mir : Program.Typed.t) =
12401240
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
12411241
(l : int) (aos_variables : string Set.Poly.t) =
12421242
let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in
1243-
match (mir_node l).pattern with
1244-
| stmt -> Memory_patterns.query_demotable_stmt aos_variables stmt in
1243+
Memory_patterns.query_demotable_stmt aos_variables (mir_node l) in
12451244
let initial_variables =
12461245
List.fold ~init:Set.Poly.empty
12471246
~f:(Memory_patterns.query_initial_demotable_stmt false)

0 commit comments

Comments
 (0)