Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions src/stan_math_backend/Stan_math_code_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ let pp_located ppf _ =
throw std::runtime_error("*** IF YOU SEE THIS, PLEASE REPORT A BUG ***"); |}

(** Detect if argument requires C++ template *)
let arg_needs_template = function
let arg_needs_template arg =
match arg with
| UnsizedType.DataOnly, _, t -> UnsizedType.is_eigen_type t
| _, _, t when UnsizedType.contains_int t -> false
| _ -> true
Expand All @@ -66,10 +67,12 @@ let arg_needs_template = function
@return A list of arguments with template parameter names added.
*)
let maybe_templated_arg_types (args : Program.fun_arg_decl) =
List.mapi args ~f:(fun i a ->
match arg_needs_template a with
| true -> Some (sprintf "T%d__" i)
| false -> None )
List.mapi args ~f:(fun i (adtype, _, ut) ->
match ut with
| UMatrix | UVector | URowVector -> Some [sprintf "T%d__" i]
| UReal when adtype = AutoDiffable -> Some [sprintf "T%d__" i]
| UArray _ -> Some [sprintf "T%d__" i; sprintf "Alloc%d__" i]
| UInt | UReal | UMathLibraryFunction | UFun _ -> None )

let return_arg_types (args : Program.fun_arg_decl) =
List.mapi args ~f:(fun i ((_, _, ut) as a) ->
Expand All @@ -80,8 +83,8 @@ let return_arg_types (args : Program.fun_arg_decl) =

let%expect_test "arg types templated correctly" =
[(AutoDiffable, "xreal", UReal); (DataOnly, "yint", UInt)]
|> maybe_templated_arg_types |> List.filter_opt |> String.concat ~sep:","
|> print_endline ;
|> maybe_templated_arg_types |> List.filter_opt |> List.concat
|> String.concat ~sep:"," |> print_endline ;
[%expect {| T0__ |}]

(** Print the code for promoting stan real types
Expand Down Expand Up @@ -137,6 +140,22 @@ let pp_located_error ppf (pp_body_block, body) =
string ppf " catch (const std::exception& e) " ;
pp_block ppf (pp_located, ())

(**
* Print the types used in the C++ function signature.
* For most types we'll simply use the template typename given
* such as `T{id}__, but for std::vector's we will specialize
* the function by wrapping the joint template parameters
* (`T{id}__, Alloc{id}__`) around `std::vector<{Templates}>.
*)
let pp_arg_types ppf (scalar, ut) =
match ut with
| UnsizedType.UInt | UReal | UMatrix | URowVector | UVector ->
string ppf scalar
| UArray _ ->
(* Expressions are not accepted for arrays of Eigen::Matrix *)
pf ppf "std::vector<%s>" scalar
| x -> raise_s [%message (x : UnsizedType.t) "not implemented yet"]

(** Print the type of an object.
@param ppf A pretty printer
@param custom_scalar_opt A string representing a types inner scalar value.
Expand All @@ -150,8 +169,7 @@ let pp_arg ppf (custom_scalar_opt, (_, name, ut)) =
| None -> stantype_prim_str ut
in
(* we add the _arg suffix for any Eigen types *)
pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut)
name
pf ppf "const %a& %s" pp_arg_types (scalar, ut) name

let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) =
let scalar =
Expand All @@ -163,8 +181,7 @@ let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) =
let opt_arg_suffix =
if UnsizedType.is_eigen_type ut then name ^ "_arg__" else name
in
pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut)
opt_arg_suffix
pf ppf "const %a& %s" pp_arg_types (scalar, ut) opt_arg_suffix

(** [pp_located_error_b] automatically adds a Block wrapper *)
let pp_located_error_b ppf body_stmts =
Expand All @@ -178,16 +195,17 @@ let typename = ( ^ ) "typename "
@param fdargs A sexp list of strings representing C++ types.
*)
let get_templates_and_args exprs fdargs =
let argtypetemplates = maybe_templated_arg_types fdargs in
( List.filter_opt argtypetemplates
let argtype_templates = maybe_templated_arg_types fdargs in
let templates =
List.map ~f:(Option.map ~f:(String.concat ~sep:", ")) argtype_templates
in
( List.concat (List.filter_opt argtype_templates)
, if not exprs then
List.map
~f:(fun a -> strf "%a" pp_arg a)
(List.zip_exn argtypetemplates fdargs)
List.map ~f:(fun a -> strf "%a" pp_arg a) (List.zip_exn templates fdargs)
else
List.map
~f:(fun a -> strf "%a" pp_arg_eigen_suffix a)
(List.zip_exn argtypetemplates fdargs) )
(List.zip_exn templates fdargs) )

(** Print the C++ template parameter decleration before a function.
@param ppf A pretty printer.
Expand Down
Loading