Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 143 additions & 41 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@ open Core
open Core.Poly
open Middle

type demotion = int * Mem_pattern.t * string [@@deriving compare]

let demotion_reasons = ref []

let get_warnings () =
let mem_name pattern =
match pattern with Mem_pattern.SoA -> "SoA" | AoS -> "AoS" in
!demotion_reasons
|> List.dedup_and_sort ~compare:compare_demotion
|> List.map ~f:(fun (linenum, pattern, msg) ->
Printf.sprintf "Optimization hazard warning (Line %i): %s warning: %s"
linenum (mem_name pattern) msg)

let user_warning_op (mem_pattern : Mem_pattern.t) (linenum : int) (msg : string)
(names : string) =
if not (String.is_empty names || String.is_empty msg) then
demotion_reasons :=
(linenum, mem_pattern, msg ^ " " ^ names) :: !demotion_reasons

let concat_set_str (set : string Set.Poly.t) =
Set.fold
~f:(fun acc elem -> if acc = "" then acc ^ elem else acc ^ ", " ^ elem)
~init:"" set

(**
Return a Var expression of the name for each type
containing an eigen matrix
Expand Down Expand Up @@ -98,7 +122,7 @@ let query_stan_math_mem_pattern_support (name : string)
Frontend.SignatureMismatch.check_compatible_arguments_mod_conv x args
|> Result.is_ok)
namematches in
let is_soa = function _, _, _, Mem_pattern.SoA -> true | _ -> false in
let is_soa (_, _, _, p) = p = Mem_pattern.SoA in
List.exists ~f:is_soa filteredmatches

(*Validate whether a function can support SoA matrices*)
Expand All @@ -116,13 +140,13 @@ let is_fun_soa_supported name exprs =
will be returned if the matrix or vector is accessed by single
cell indexing.
*)
let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
Expr.{pattern; _} : string Set.Poly.t =
let rec query_initial_demotable_expr (in_loop : bool) (stmt_linenum : int)
~(acc : string Set.Poly.t) Expr.{pattern; _} : string Set.Poly.t =
let query_expr (accum : string Set.Poly.t) =
query_initial_demotable_expr in_loop ~acc:accum in
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
match pattern with
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
query_initial_demotable_funs in_loop acc kind exprs
query_initial_demotable_funs in_loop stmt_linenum acc kind exprs
| Indexed ((Expr.{meta= {type_; _}; _} as expr), indexed) ->
let index_set =
Set.Poly.union_list
Expand All @@ -132,18 +156,29 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
(query_expr acc))
indexed) in
let index_demotes =
if is_uni_eigen_loop_indexing in_loop type_ indexed then
Set.union (query_var_eigen_names expr) index_set
if is_uni_eigen_loop_indexing in_loop type_ indexed then (
let single_index_set = query_var_eigen_names expr in
let failure_str = concat_set_str (Set.inter acc single_index_set) in
let msg = "Accessed by element in a for loop:" in
user_warning_op SoA stmt_linenum msg failure_str;
Set.union single_index_set index_set)
else Set.union (query_expr acc expr) index_set in
Set.union acc index_demotes
| Var (_ : string) | Lit ((_ : Expr.Pattern.litType), (_ : string)) -> acc
| Promotion (expr, _, _) -> query_expr acc expr
| TupleProjection (expr, _) -> query_expr acc expr
| TernaryIf (predicate, texpr, fexpr) ->
let predicate_demotes = query_expr acc predicate in
Set.union
(Set.union predicate_demotes (query_var_eigen_names texpr))
(query_var_eigen_names fexpr)
let full_set =
Set.union
(Set.union predicate_demotes (query_var_eigen_names texpr))
(query_var_eigen_names fexpr) in
if Set.is_empty full_set then full_set
else
let failure_str = concat_set_str (Set.inter acc full_set) in
let msg = "Used in a ternary operator which is not allowed:" in
user_warning_op SoA stmt_linenum msg failure_str;
full_set
| EAnd (lhs, rhs) | EOr (lhs, rhs) ->
(*We need to get the demotes from both sides*)
let full_lhs_rhs = Set.union (query_expr acc lhs) (query_expr acc rhs) in
Expand All @@ -166,9 +201,11 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
to the UDF.
exprs The expression list passed to the functions.
*)
and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
(kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t =
let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in
and query_initial_demotable_funs (in_loop : bool) (stmt_linenum : int)
(acc : string Set.Poly.t) (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list)
: string Set.Poly.t =
let query_expr accum =
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
let top_level_eigen_names =
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in
Expand All @@ -181,11 +218,26 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
| name -> (
match is_fun_soa_supported name exprs with
| true -> Set.union acc demoted_eigen_names
| false -> Set.union acc demoted_and_top_level_names))
| false ->
let fail_names =
concat_set_str (Set.inter acc top_level_eigen_names) in
user_warning_op SoA stmt_linenum
("Function " ^ name ^ " is not supported:")
fail_names;
Set.union acc demoted_and_top_level_names))
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec | FnMakeTuple) ->
let fail_names =
concat_set_str (Set.inter acc demoted_and_top_level_names) in
user_warning_op SoA stmt_linenum
"Used in {} make array or make row vector compiler functions:"
fail_names;
Set.union acc demoted_and_top_level_names
| CompilerInternal (_ : 'a Internal_fun.t) -> acc
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
let fail_names =
concat_set_str (Set.inter acc demoted_and_top_level_names) in
user_warning_op SoA stmt_linenum "Used in user defined function:"
fail_names;
Set.union acc demoted_and_top_level_names

(**
Expand Down Expand Up @@ -283,9 +335,10 @@ let contains_at_least_one_ad_matrix_or_all_data
[query_initial_demotable_expr] for an explanation of the logic.
*)
let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
(Stmt.{pattern; _} : Stmt.Located.t) : string Set.Poly.t =
(Stmt.{pattern; meta} : Stmt.Located.t) : string Set.Poly.t =
let linenum = meta.end_loc.line_num in
let query_expr (accum : string Set.Poly.t) =
query_initial_demotable_expr in_loop ~acc:accum in
query_initial_demotable_expr in_loop linenum ~acc:accum in
match pattern with
| Stmt.Pattern.Assignment
( lval
Expand All @@ -299,21 +352,25 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
List.fold ~init:acc
~f:(fun accum x ->
Index.folder accum
(fun acc -> query_initial_demotable_expr in_loop ~acc)
(fun acc -> query_initial_demotable_expr in_loop linenum ~acc)
x)
idx in
match is_uni_eigen_loop_indexing in_loop ut idx with
| true -> Set.add idx_list name
| true ->
user_warning_op SoA linenum "Accessed by element in a for loop:"
(if Set.mem acc name then "" else name);
Set.add idx_list name
| false -> idx_list in
let rhs_demotable_names = query_expr acc rhs in
let rhs_and_idx_demotions = Set.union idx_demotable rhs_demotable_names in
(* RHS (1)*)
let tuple_demotions =
match lval with
| LTupleProjection _, _ ->
Set.add
(Set.union rhs_and_idx_demotions (query_var_eigen_names rhs))
name
let tuple_set = query_var_eigen_names rhs in
let fail_set = concat_set_str tuple_set in
user_warning_op SoA linenum "Used in tuple:" fail_set;
Set.add (Set.union rhs_and_idx_demotions tuple_set) name
| _ -> rhs_and_idx_demotions in
let assign_demotions =
let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in
Expand All @@ -327,30 +384,52 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
(extract_nonderived_admatrix_types rhs))
| _ -> false in
(* LHS (3) rhs unsupported function*)
let is_not_supported_func =
let non_supported_func_name =
match rhs.pattern with
| FunApp (UserDefined _, _) -> true
| FunApp (CompilerInternal _, _) -> false
| FunApp (StanLib (name, _, _), exprs) ->
not
(query_stan_math_mem_pattern_support name
(List.map ~f:Expr.Typed.fun_arg exprs))
| _ -> false in
| FunApp (UserDefined (name, _), _) -> Some name
| FunApp (StanLib (name, _, _), exprs)
when not
(query_stan_math_mem_pattern_support name
(List.map ~f:Expr.Typed.fun_arg exprs)) ->
Some name
| _ -> None in
(* LHS (3) all rhs aos*)
let is_all_rhs_aos =
is_nonzero_subset
~subset:(query_var_eigen_names rhs)
~set:rhs_demotable_names in
if
is_all_rhs_aos || is_rhs_not_promoteable_to_soa
|| is_not_supported_func
then
Set.add (Set.union tuple_demotions (query_var_eigen_names rhs)) name
|| Option.is_some non_supported_func_name
then (
let rhs_set = query_var_eigen_names rhs in
let all_rhs_warn =
if is_all_rhs_aos then "Right hand side of assignment is all AoS:"
else "" in
let rhs_not_promotable_to_soa_warn =
if is_rhs_not_promoteable_to_soa then
"The right hand side of the assignment only contains data and \
scalar operations that are not promotable to SoA:"
else "" in
let not_supported_func_warn =
match non_supported_func_name with
| Some fname ->
"Function '" ^ fname
^ "' on right hand side of assignment is not supported by \
SoA:"
| None -> "" in
let rhs_name_set = Set.add rhs_set name in
let rhs_name_set_str = concat_set_str rhs_name_set in
user_warning_op SoA linenum all_rhs_warn rhs_name_set_str;
user_warning_op SoA linenum rhs_not_promotable_to_soa_warn
rhs_name_set_str;
user_warning_op SoA linenum not_supported_func_warn rhs_name_set_str;
Set.add (Set.union tuple_demotions rhs_set) name)
else tuple_demotions
else tuple_demotions in
Set.union acc assign_demotions
| NRFunApp (kind, exprs) ->
query_initial_demotable_funs in_loop acc kind exprs
query_initial_demotable_funs in_loop linenum acc kind exprs
| IfElse (predicate, true_stmt, op_false_stmt) ->
let predicate_acc = query_expr acc predicate in
Set.union acc
Expand Down Expand Up @@ -386,7 +465,10 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
| Decl {decl_type= Type.Sized st; decl_id; initialize; _} ->
let complex_name =
match SizedType.is_complex_type st with
| true -> Set.Poly.singleton decl_id
| true ->
user_warning_op SoA linenum "Complex-valued types cannot be SoA:"
decl_id;
Set.Poly.singleton decl_id
| false -> Set.Poly.empty in
let init_names =
match initialize with
Expand All @@ -408,24 +490,44 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
@param pattern The Stmt pattern to query.
*)
let query_demotable_stmt (aos_exits : string Set.Poly.t)
(pattern : (Expr.Typed.t, int) Stmt.Pattern.t) : string Set.Poly.t =
match pattern with
(stmt : Stmt.Located.Non_recursive.t) : string Set.Poly.t =
let linenum = stmt.meta.end_loc.line_num in
match stmt.pattern with
| Stmt.Pattern.Assignment (lval, (_ : UnsizedType.t), (rhs : Expr.Typed.t))
-> (
let assign_name = Stmt.Helpers.lhs_variable lval in
let all_rhs_eigen_names = query_var_eigen_names rhs in
if Set.mem aos_exits assign_name then
Set.add all_rhs_eigen_names assign_name
if Set.mem aos_exits assign_name then (
user_warning_op SoA linenum
"Right hand side contains only AoS expressions:" assign_name;
Set.add all_rhs_eigen_names assign_name)
else
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names assign_name
| true ->
let warn =
Fmt.(
str "Right hand side contains AoS expressions (%s):"
(concat_set_str (Set.inter aos_exits all_rhs_eigen_names)))
in
user_warning_op SoA linenum warn assign_name;
Set.add all_rhs_eigen_names assign_name
| false -> Set.Poly.empty)
| Decl {decl_id; initialize= Assign e; _} -> (
let all_rhs_eigen_names = query_var_eigen_names e in
if Set.mem aos_exits decl_id then Set.add all_rhs_eigen_names decl_id
if Set.mem aos_exits decl_id then (
user_warning_op SoA linenum
"Right hand side contains only AoS expressions:" decl_id;
Set.add all_rhs_eigen_names decl_id)
else
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names decl_id
| true ->
let warn =
Fmt.(
str "Right hand side contains AoS expressions (%s):"
(concat_set_str (Set.inter aos_exits all_rhs_eigen_names)))
in
user_warning_op SoA linenum warn decl_id;
Set.add all_rhs_eigen_names decl_id
| false -> Set.Poly.empty)
(* All other statements do not need logic here*)
| _ -> Set.Poly.empty
Expand Down
3 changes: 1 addition & 2 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1228,8 +1228,7 @@ let optimize_soa (mir : Program.Typed.t) =
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(l : int) (aos_variables : string Set.Poly.t) =
let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in
match (mir_node l).pattern with
| stmt -> Memory_patterns.query_demotable_stmt aos_variables stmt in
Memory_patterns.query_demotable_stmt aos_variables (mir_node l) in
let initial_variables =
List.fold ~init:Set.Poly.empty
~f:(Memory_patterns.query_initial_demotable_stmt false)
Expand Down
6 changes: 5 additions & 1 deletion src/driver/Entry.ml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ let stan2cpp model_name model (flags : Flags.t) (output : other_output -> unit)
tx_mir in
if flags.debug_settings.print_mem_patterns then
output
(Memory_patterns (Fmt.str "%a" Memory_patterns.pp_mem_patterns opt_mir));
(Memory_patterns
(Fmt.str "%a%a@\n" Memory_patterns.pp_mem_patterns opt_mir
(* TODO should be better associated with the names from above? *)
Fmt.(list string)
(Memory_patterns.get_warnings ())));
debug_output_mir output opt_mir flags.debug_settings.print_optimized_mir;
let cpp =
Lower_program.lower_program
Expand Down
Loading