From 07d2e7571085f8539453bd48d7490c255fcdad16 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Wed, 20 Mar 2024 17:41:51 +0100 Subject: [PATCH 1/4] Draft draft Signed-off-by: Krzysztof Lecki --- dali/pipeline/operator/error_reporting.h | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/dali/pipeline/operator/error_reporting.h b/dali/pipeline/operator/error_reporting.h index d84bdf0f6f9..d9f259f1e85 100644 --- a/dali/pipeline/operator/error_reporting.h +++ b/dali/pipeline/operator/error_reporting.h @@ -186,6 +186,29 @@ class DaliStopIteration : public DaliError { */ std::string GetErrorContextMessage(const OpSpec &spec); +namespace validate { + +//TODO: additional message +//Dim, Type, axis, dtype +//Matching shape reference, matching other input. +void CheckInputType(OpSpec &spec, Workspace &ws, int input_idx, DALIDataType allowed_type); +void CheckInputType(OpSpec &spec, Workspace &ws, int input_idx, + const std::vector &allowed_types); +void CheckInputType(OpSpec &spec, Workspace &ws, const std::string &argument_name, + DALIDataType allowed_type); +void CheckInputType(OpSpec &spec, Workspace &ws, const std::string &argument_name, + const std::vector &allowed_types); + +template +void CheckInputType(OpSpec &spec, Workspace &ws, int input_idx); +template +void CheckInputType(OpSpec &spec, Workspace &ws, const std::string &argument_name); + + + +} + + } // namespace dali #endif // DALI_PIPELINE_OPERATOR_ERROR_REPORTING_H_ From 7b6a8e3119a3a8af6c791d0a770ff0751a2479bb Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Mon, 25 Mar 2024 10:59:40 +0100 Subject: [PATCH 2/4] Validation + third_party Signed-off-by: Krzysztof Lecki --- .gitmodules | 3 + cmake/Dependencies.common.cmake | 8 ++ dali/pipeline/operator/error_reporting.cc | 142 +++++++++++++++++++++- dali/pipeline/operator/error_reporting.h | 87 +++++++++++-- third_party/fmt | 1 + 5 files changed, 227 insertions(+), 14 deletions(-) create mode 160000 third_party/fmt diff --git a/.gitmodules b/.gitmodules index 41f89cab274..4abcd184b85 100644 --- a/.gitmodules +++ b/.gitmodules @@ -29,3 +29,6 @@ [submodule "third_party/cvcuda"] path = third_party/cvcuda url = https://github.com/CVCUDA/CV-CUDA.git +[submodule "third_party/fmt"] + path = third_party/fmt + url = git@github.com:fmtlib/fmt.git diff --git a/cmake/Dependencies.common.cmake b/cmake/Dependencies.common.cmake index 46cbda225e3..8c56ab4f536 100644 --- a/cmake/Dependencies.common.cmake +++ b/cmake/Dependencies.common.cmake @@ -380,3 +380,11 @@ if(BUILD_NVIMAGECODEC) endif() endif() endif() + +################################################################## +# {fmt} +################################################################## +check_and_add_cmake_submodule(${PROJECT_SOURCE_DIR}/third_party/fmt EXCLUDE_FROM_ALL) +set_target_properties(fmt PROPERTIES POSITION_INDEPENDENT_CODE ON) +list(APPEND DALI_LIBS fmt) +list(APPEND DALI_EXCLUDES libfmt.a) diff --git a/dali/pipeline/operator/error_reporting.cc b/dali/pipeline/operator/error_reporting.cc index 6f4e1d1a67c..9ac572076ab 100644 --- a/dali/pipeline/operator/error_reporting.cc +++ b/dali/pipeline/operator/error_reporting.cc @@ -17,13 +17,28 @@ #include #include #include +#include +// for fmt::join +#include +// for ostream support +#include #include "dali/core/error_handling.h" +#include "dali/pipeline/data/backend.h" +#include "dali/pipeline/data/types.h" #include "dali/pipeline/operator/error_reporting.h" #include "dali/pipeline/operator/op_spec.h" +// template <> struct fmt::formatter : fmt::ostream_formatter {}; +// template <> struct fmt::formatter : fmt::tostring_formatter {}; + + namespace dali { +auto format_as(dali::DALIDataType type) { + return dali::to_string(type); +} + std::vector GetOperatorOriginInfo(const OpSpec &spec) { auto origin_stack_filename = spec.GetRepeatedArgument("_origin_stack_filename"); auto origin_stack_lineno = spec.GetRepeatedArgument("_origin_stack_lineno"); @@ -67,8 +82,7 @@ void PropagateError(ErrorInfo error) { catch (DaliError &e) { e.UpdateMessage(make_string(error.context_info, e.what(), error.additional_message)); throw; - } - catch (DALIException &e) { + } catch (DALIException &e) { // We drop the C++ stack trace at this point and go back to runtime_error. throw std::runtime_error( make_string(error.context_info, e.what(), @@ -109,9 +123,129 @@ std::string GetErrorContextMessage(const OpSpec &spec) { formatted_origin_stack + "\n") : " "; // we need space before "encountered" - return make_string("Error in ", device, " operator `", op_name, "`", - optional_stack_mention, "encountered:\n\n"); + return make_string("Error in ", device, " operator `", op_name, "`", optional_stack_mention, + "encountered:\n\n"); +} + + +namespace validate { + +std::string SepIfNotEmpty(const std::string &str, const std::string &sep = " ") { + if (str.empty()) { + return ""; + } + return sep; +} + +void Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, + const std::string &additional_msg) { + if (actual_type == expected_type) { + return; + } + + throw DaliTypeError(fmt::format("Unexpected type for {}. Got type: `{}` but expected: `{}`.{}{}", name, actual_type, + expected_type, SepIfNotEmpty(additional_msg), + additional_msg)); +} + +void Type(DALIDataType actual_type, span &expected_types, + const std::string &name, const std::string &additional_msg) { + if (std::size(expected_types) == 1) { + Type(actual_type, expected_types[0], name, additional_msg); + return; + } + for (auto expected_type : expected_types) { + if (actual_type == expected_type) { + return; + } + } + + throw DaliTypeError(fmt::format("Unexpected type for {}. Got type: `{}` but expected one of: `{}`.{}{}", + name, actual_type, fmt::join(expected_types, "`, `"), + SepIfNotEmpty(additional_msg), additional_msg)); +} + + +std::string FormatInput(const OpSpec &spec, const Workspace &ws, int input_idx) { + if (spec.GetSchema().HasInputDox()) { + return fmt::format("input `{}` (`__{}`)", input_idx, spec.GetSchema().GetInputName(input_idx)); + } + + return fmt::format("input `{}`", input_idx); } +std::string FormatOutput(const OpSpec &spec, const Workspace &ws, int output_idx) { + return fmt::format("output `{}`", output_idx); +} + +void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, DALIDataType allowed_type, + const std::string &additional_msg) { + DALIDataType dtype = ws.GetInputDataType(input_idx); + Type(dtype, allowed_type, FormatInput(spec, ws, input_idx), additional_msg); +} + +void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + span &allowed_types, const std::string &additional_msg) { + DALIDataType dtype = ws.GetInputDataType(input_idx); + Type(dtype, allowed_types, FormatInput(spec, ws, input_idx), additional_msg); +} + +void Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified, + const std::string &additional_msg) { + if (allow_unspecified && !spec.HasArgument("dtype")) { + return; + } else if (!allow_unspecified &&!spec.HasArgument("dtype")) { + throw DaliValueError(fmt::format("Argument `dtype` was not specified.{}{}", SepIfNotEmpty(additional_msg), + additional_msg)); + } + } + +void Dtype(const OpSpec &spec, span &allowed_types, bool allow_unspecified, + const std::string &additional_msg) { + + } + +void Dim(int actual_dim, int expected_dim, const std::string &name, + const std::string &additional_msg) { + if (actual_dim == expected_dim) { + return; + } + throw DaliValueError(fmt::format("Got dim: `{}` for {}, but expected: `{}`.{}{}", actual_dim, + name, expected_dim, SepIfNotEmpty(additional_msg), + additional_msg)); +} + +void Dim(int actual_dim, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg) { + + if (expected_from <= actual_dim && actual_dim < expected_to) { + return; + } + throw DaliValueError(fmt::format( + "Got dim: `{}` for {}, but expected value in `[{}, {})` range.{}{}", actual_dim, name, + expected_from, expected_to, SepIfNotEmpty(additional_msg), additional_msg)); +} + +void Dim(int actual_dim, span &expected, const std::string &name, const std::string &additional_msg) { + if (size(expected) == 1) { + Dim(actual_dim, expected[0], name, additional_msg); + return; + } + for (auto expected_dim : expected) { + if (actual_dim == expected_dim) { + return; + } + } + throw DaliValueError(fmt::format("Got dim: `{}` for {}, but expected one of: `{}`.{}{}", + actual_dim, name, fmt::join(expected, "`, `"), + SepIfNotEmpty(additional_msg), additional_msg)); +} + +// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, DALIDataType allowed_type, +// const std::string &additional_msg); +// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// span &allowed_types, const std::string &additional_msg); + +} // namespace validate + } // namespace dali diff --git a/dali/pipeline/operator/error_reporting.h b/dali/pipeline/operator/error_reporting.h index d9f259f1e85..e2d6db0f63b 100644 --- a/dali/pipeline/operator/error_reporting.h +++ b/dali/pipeline/operator/error_reporting.h @@ -23,9 +23,11 @@ #include #include "dali/core/api_helper.h" +#include "dali/core/span.h" #include "dali/pipeline/data/types.h" #include "dali/pipeline/operator/op_spec.h" #include "dali/pipeline/operator/name_utils.h" +#include "dali/pipeline/workspace/workspace.h" namespace dali { @@ -188,26 +190,91 @@ std::string GetErrorContextMessage(const OpSpec &spec); namespace validate { +void Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, + const std::string &additional_message = ""); + +void Type(DALIDataType actual_type, span &expected_types, + const std::string &name, const std::string &additional_message = ""); + //TODO: additional message //Dim, Type, axis, dtype //Matching shape reference, matching other input. -void CheckInputType(OpSpec &spec, Workspace &ws, int input_idx, DALIDataType allowed_type); -void CheckInputType(OpSpec &spec, Workspace &ws, int input_idx, - const std::vector &allowed_types); -void CheckInputType(OpSpec &spec, Workspace &ws, const std::string &argument_name, - DALIDataType allowed_type); -void CheckInputType(OpSpec &spec, Workspace &ws, const std::string &argument_name, - const std::vector &allowed_types); +void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, DALIDataType allowed_type, + const std::string &additional_msg = ""); + +void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + span &allowed_types, + const std::string &additional_msg = ""); template -void CheckInputType(OpSpec &spec, Workspace &ws, int input_idx); -template -void CheckInputType(OpSpec &spec, Workspace &ws, const std::string &argument_name); +void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + const std::string &additional_msg = "") { + static std::array allowed_types = { + type2id::value...}; + InputType(spec, ws, input_idx, make_cspan(allowed_types), additional_msg); +} + +void Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified = false, + const std::string &additional_msg = ""); + +void Dtype(const OpSpec &spec, span &allowed_types, bool allow_unspecified = false, + const std::string &additional_msg = ""); + +// template +// void Dtype(const OpSpec &spec, const Workspace &ws, int input_idx, +// const std::string &additional_msg = "") { +// static std::array allowed_types = { +// type2id::value...}; +// InputType(spec, ws, input_idx, make_cspan(allowed_types), additional_msg); +// } + +void OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, DALIDataType allowed_type, + const std::string &additional_msg = ""); +void OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + span &allowed_types, + const std::string &additional_msg = ""); +template +void OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + const std::string &additional_msg = "") { + static std::array allowed_types = { + type2id::value...}; + OutputType(spec, ws, output_idx, make_cspan(allowed_types), additional_msg); } +void Dim(int actual_dim, int expected_dim, const std::string &name, const std::string &additional_msg = ""); +void Dim(int actual_dim, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg = ""); +void Dim(int actual_dim, span &expected, const std::string &name, const std::string &additional_msg = ""); + +void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_dim, const std::string &name, const std::string &additional_msg = ""); +void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg = ""); +void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, span &expected, const std::string &name, const std::string &additional_msg = ""); + +void InputDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, int expected_dim, const std::string &name, const std::string &additional_msg = ""); +void InputDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg = ""); +void InputDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, span &expected, const std::string &name, const std::string &additional_msg = ""); + +void Shape(const TensorListShape<> &actual_shape, const TensorListShape<> &expected_shape, const std::string &name, const std::string &additional_msg = ""); +void UniformShape(const TensorListShape<> &actual_shape, const std::string &name, const std::string &additional_msg = ""); + + +// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, DALIDataType +// allowed_type, +// const std::string &additional_msg = ""); +// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// span &allowed_types, const std::string &additional_msg = +// ""); + + +// template +// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// const std::string &additional_msg = "") { +// Type(spec, ws, argument_name, {type2id::value...}, additional_msg); +// } +} // namespace validate + } // namespace dali diff --git a/third_party/fmt b/third_party/fmt new file mode 160000 index 00000000000..e69e5f977d4 --- /dev/null +++ b/third_party/fmt @@ -0,0 +1 @@ +Subproject commit e69e5f977d458f2650bb346dadf2ad30c5320281 From 81bcc8ab6df9395fba31629587dad648c3639367 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Wed, 27 Mar 2024 23:30:04 +0100 Subject: [PATCH 3/4] Some examples, cleanup Signed-off-by: Krzysztof Lecki --- dali/pipeline/operator/error_reporting.cc | 133 ++++++++-------- dali/pipeline/operator/error_reporting.h | 180 ++++++++++++++-------- dali/pipeline/operator/name_utils.cc | 19 ++- dali/pipeline/operator/name_utils.h | 29 ++++ 4 files changed, 226 insertions(+), 135 deletions(-) diff --git a/dali/pipeline/operator/error_reporting.cc b/dali/pipeline/operator/error_reporting.cc index 9ac572076ab..566c399c6e3 100644 --- a/dali/pipeline/operator/error_reporting.cc +++ b/dali/pipeline/operator/error_reporting.cc @@ -27,6 +27,7 @@ #include "dali/pipeline/data/backend.h" #include "dali/pipeline/data/types.h" #include "dali/pipeline/operator/error_reporting.h" +#include "dali/pipeline/operator/name_utils.h" #include "dali/pipeline/operator/op_spec.h" // template <> struct fmt::formatter : fmt::ostream_formatter {}; @@ -137,73 +138,70 @@ std::string SepIfNotEmpty(const std::string &str, const std::string &sep = " ") return sep; } -void Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, - const std::string &additional_msg) { +DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, + const std::string &additional_msg) { if (actual_type == expected_type) { - return; + return actual_type; } - throw DaliTypeError(fmt::format("Unexpected type for {}. Got type: `{}` but expected: `{}`.{}{}", name, actual_type, - expected_type, SepIfNotEmpty(additional_msg), + throw DaliTypeError(fmt::format("Unexpected type for {}. Got type: `{}` but expected: `{}`.{}{}", + name, actual_type, expected_type, SepIfNotEmpty(additional_msg), additional_msg)); } -void Type(DALIDataType actual_type, span &expected_types, - const std::string &name, const std::string &additional_msg) { +DALIDataType Type(DALIDataType actual_type, span &expected_types, + const std::string &name, const std::string &additional_msg) { if (std::size(expected_types) == 1) { - Type(actual_type, expected_types[0], name, additional_msg); - return; + return Type(actual_type, expected_types[0], name, additional_msg); } for (auto expected_type : expected_types) { if (actual_type == expected_type) { - return; + return actual_type; } } - throw DaliTypeError(fmt::format("Unexpected type for {}. Got type: `{}` but expected one of: `{}`.{}{}", - name, actual_type, fmt::join(expected_types, "`, `"), - SepIfNotEmpty(additional_msg), additional_msg)); -} - - -std::string FormatInput(const OpSpec &spec, const Workspace &ws, int input_idx) { - if (spec.GetSchema().HasInputDox()) { - return fmt::format("input `{}` (`__{}`)", input_idx, spec.GetSchema().GetInputName(input_idx)); - } - - return fmt::format("input `{}`", input_idx); + throw DaliTypeError(fmt::format( + "Unexpected type for {}. Got type: `{}` but expected one of: `{}`.{}{}", name, actual_type, + fmt::join(expected_types, "`, `"), SepIfNotEmpty(additional_msg), additional_msg)); } -std::string FormatOutput(const OpSpec &spec, const Workspace &ws, int output_idx) { - return fmt::format("output `{}`", output_idx); -} - -void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, DALIDataType allowed_type, - const std::string &additional_msg) { +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + DALIDataType allowed_type, const std::string &additional_msg) { DALIDataType dtype = ws.GetInputDataType(input_idx); - Type(dtype, allowed_type, FormatInput(spec, ws, input_idx), additional_msg); + return Type(dtype, allowed_type, FormatInput(spec, input_idx), additional_msg); } -void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, - span &allowed_types, const std::string &additional_msg) { +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + span &allowed_types, const std::string &additional_msg) { DALIDataType dtype = ws.GetInputDataType(input_idx); - Type(dtype, allowed_types, FormatInput(spec, ws, input_idx), additional_msg); + return Type(dtype, allowed_types, FormatInput(spec, input_idx), additional_msg); } -void Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified, - const std::string &additional_msg) { +DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified, + const std::string &additional_msg) { if (allow_unspecified && !spec.HasArgument("dtype")) { - return; - } else if (!allow_unspecified &&!spec.HasArgument("dtype")) { - throw DaliValueError(fmt::format("Argument `dtype` was not specified.{}{}", SepIfNotEmpty(additional_msg), - additional_msg)); - } - } - -void Dtype(const OpSpec &spec, span &allowed_types, bool allow_unspecified, - const std::string &additional_msg) { + return DALI_NO_TYPE; + } else if (!allow_unspecified && !spec.HasArgument("dtype")) { + throw DaliValueError(fmt::format("{} was not specified.{}{}", + FormatArgument(spec, "dtype", true), + SepIfNotEmpty(additional_msg), additional_msg)); + } + return Type(spec.GetArgument("dtype"), allowed_type, FormatArgument(spec, "dtype"), + additional_msg); +} - } +DALIDataType Dtype(const OpSpec &spec, span &allowed_types, bool allow_unspecified, + const std::string &additional_msg) { + if (allow_unspecified && !spec.HasArgument("dtype")) { + return DALI_NO_TYPE; + } else if (!allow_unspecified && !spec.HasArgument("dtype")) { + throw DaliValueError(fmt::format("{} was not specified.{}{}", + FormatArgument(spec, "dtype", true), + SepIfNotEmpty(additional_msg), additional_msg)); + } + return Type(spec.GetArgument("dtype"), allowed_types, FormatArgument(spec, "dtype"), + additional_msg); +} void Dim(int actual_dim, int expected_dim, const std::string &name, const std::string &additional_msg) { @@ -215,37 +213,32 @@ void Dim(int actual_dim, int expected_dim, const std::string &name, additional_msg)); } -void Dim(int actual_dim, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg) { +DALIDataType Dtype(const OpSpec &spec, const Workspace &ws, bool (*is_valid)(DALIDataType), + const std::string &explanation) { + return DALI_NO_TYPE; // TODO(klecki): implement +} - if (expected_from <= actual_dim && actual_dim < expected_to) { - return; - } - throw DaliValueError(fmt::format( - "Got dim: `{}` for {}, but expected value in `[{}, {})` range.{}{}", actual_dim, name, - expected_from, expected_to, SepIfNotEmpty(additional_msg), additional_msg)); +DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + DALIDataType allowed_type, const std::string &additional_msg) { + DALIDataType dtype = ws.GetOutputDataType(output_idx); + return Type(dtype, allowed_type, FormatOutput(spec, output_idx), additional_msg); } -void Dim(int actual_dim, span &expected, const std::string &name, const std::string &additional_msg) { - if (size(expected) == 1) { - Dim(actual_dim, expected[0], name, additional_msg); - return; - } - for (auto expected_dim : expected) { - if (actual_dim == expected_dim) { - return; - } - } - throw DaliValueError(fmt::format("Got dim: `{}` for {}, but expected one of: `{}`.{}{}", - actual_dim, name, fmt::join(expected, "`, `"), - SepIfNotEmpty(additional_msg), additional_msg)); +DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + span &allowed_types, const std::string &additional_msg) { + DALIDataType dtype = ws.GetOutputDataType(output_idx); + return Type(dtype, allowed_types, FormatOutput(spec, output_idx), additional_msg); } -// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, DALIDataType allowed_type, -// const std::string &additional_msg); -// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, -// span &allowed_types, const std::string &additional_msg); +DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, const std::string &arg_name, + const std::string &additional_msg) { + DALIDataType expected_type = spec.GetSchema().GetArgumentType(arg_name); + if (!spec.HasTensorArgument(arg_name)) { + return expected_type; + } + return Type(ws.ArgumentInput(arg_name).type(), expected_type, FormatArgument(spec, arg_name), + additional_msg); +} } // namespace validate - - } // namespace dali diff --git a/dali/pipeline/operator/error_reporting.h b/dali/pipeline/operator/error_reporting.h index e2d6db0f63b..1ef379d7fc6 100644 --- a/dali/pipeline/operator/error_reporting.h +++ b/dali/pipeline/operator/error_reporting.h @@ -190,92 +190,144 @@ std::string GetErrorContextMessage(const OpSpec &spec); namespace validate { -void Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, - const std::string &additional_message = ""); +/** @defgroup TypeValidation Validation of types. + * Input/Output type validation - based on the workspace contents and the index. + * Argument type validation - internal check if the argument matches the spec before operator + * code is able to access it. + * Argument 'dtype' validation - based on the operator spec, checks the contents. + * Type() variant provides generic validation. + * @{ + */ -void Type(DALIDataType actual_type, span &expected_types, - const std::string &name, const std::string &additional_message = ""); +/** + * @brief Check if the actual_type matches the expected_type. If not, throws TypeError + * containing the message: + * + * Unexpected type for : Got type: , but expected: . + * + * + * @param actual_type + * @param expected_type + * @param name + * @param additional_message + * @return DALIDataType + */ +DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, + const std::string &additional_message = ""); -//TODO: additional message -//Dim, Type, axis, dtype -//Matching shape reference, matching other input. -void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, DALIDataType allowed_type, - const std::string &additional_msg = ""); +/** + * @brief Check if the actual_type matches the expected_type. If not, throws TypeError + * containing the message: + * + * Unexpected type for : Got type: , but expected one of: . + * + * + * @param actual_type + * @param expected_types + * @param name + * @param additional_message + * @return DALIDataType + */ +DALIDataType Type(DALIDataType actual_type, span &expected_types, + const std::string &name, const std::string &additional_message = ""); + +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + DALIDataType allowed_type, const std::string &additional_msg = ""); -void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, - span &allowed_types, - const std::string &additional_msg = ""); +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + span &allowed_types, const std::string &additional_msg = ""); template -void InputType(const OpSpec &spec, const Workspace &ws, int input_idx, - const std::string &additional_msg = "") { +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + const std::string &additional_msg = "") { static std::array allowed_types = { type2id::value...}; - InputType(spec, ws, input_idx, make_cspan(allowed_types), additional_msg); + return InputType(spec, ws, input_idx, make_cspan(allowed_types), additional_msg); } -void Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified = false, - const std::string &additional_msg = ""); +DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified = false, + const std::string &additional_msg = ""); -void Dtype(const OpSpec &spec, span &allowed_types, bool allow_unspecified = false, - const std::string &additional_msg = ""); +DALIDataType Dtype(const OpSpec &spec, span &allowed_types, + bool allow_unspecified = false, const std::string &additional_msg = ""); -// template -// void Dtype(const OpSpec &spec, const Workspace &ws, int input_idx, -// const std::string &additional_msg = "") { -// static std::array allowed_types = { -// type2id::value...}; -// InputType(spec, ws, input_idx, make_cspan(allowed_types), additional_msg); -// } +DALIDataType Dtype(const OpSpec &spec, const Workspace &ws, bool (*is_valid)(DALIDataType), + const std::string &explanation); -void OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, DALIDataType allowed_type, - const std::string &additional_msg = ""); +DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + DALIDataType allowed_type, const std::string &additional_msg = ""); -void OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, - span &allowed_types, - const std::string &additional_msg = ""); +DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + span &allowed_types, const std::string &additional_msg = ""); template -void OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, - const std::string &additional_msg = "") { +DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + const std::string &additional_msg = "") { static std::array allowed_types = { type2id::value...}; - OutputType(spec, ws, output_idx, make_cspan(allowed_types), additional_msg); + return OutputType(spec, ws, output_idx, make_cspan(allowed_types), additional_msg); } -void Dim(int actual_dim, int expected_dim, const std::string &name, const std::string &additional_msg = ""); -void Dim(int actual_dim, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg = ""); -void Dim(int actual_dim, span &expected, const std::string &name, const std::string &additional_msg = ""); - -void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_dim, const std::string &name, const std::string &additional_msg = ""); -void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg = ""); -void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, span &expected, const std::string &name, const std::string &additional_msg = ""); - -void InputDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, int expected_dim, const std::string &name, const std::string &additional_msg = ""); -void InputDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, int expected_from, int expected_to, const std::string &name, const std::string &additional_msg = ""); -void InputDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, span &expected, const std::string &name, const std::string &additional_msg = ""); - -void Shape(const TensorListShape<> &actual_shape, const TensorListShape<> &expected_shape, const std::string &name, const std::string &additional_msg = ""); -void UniformShape(const TensorListShape<> &actual_shape, const std::string &name, const std::string &additional_msg = ""); - - -// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, DALIDataType -// allowed_type, -// const std::string &additional_msg = ""); -// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, -// span &allowed_types, const std::string &additional_msg = -// ""); - +/** + * @brief Verifies if given argument input has a correct backing type. + * + * @param spec + * @param ws + * @param arg_name + * @param additional_msg + * @return DALIDataType + */ +DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, const std::string &arg_name, + const std::string &additional_msg = ""); + +/** @} */ // end of TypeValidation + +// TODO(klecki): Same convention can be applied to checks for inputs and arguments: +// * dimensionality +// * shape: matching, uniform, non-empty ... +// * layout (we already have some of this) +// Also, there are groups of arguments similar to 'dtype', that can have specific checks: +// * relative coordinates [0, 1] in tensor +// * absolute coordinates in tensor shape. +// * axis within dim. +// Examples: Roi, crop window, anchor points, bboxes... +// +// Mutually exclusive arguments are also popular. +// +// The downside is, this generates a lot of boilerplate code, maybe we should just stick with +// the main variant + the FormatInput/Output/Argument variants. +// +// +// void Dim(int actual_dim, int expected_dim, const std::string &name, +// const std::string &additional_msg = ""); +// void Dim(int actual_dim, int expected_from, int expected_to, const std::string &name, +// const std::string &additional_msg = ""); +// void Dim(int actual_dim, span expected, const std::string &name, +// const std::string &additional_msg = ""); + +// void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_dim, +// const std::string &name, const std::string &additional_msg = ""); +// void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_from, +// int expected_to, const std::string &name, const std::string &additional_msg = ""); +// void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, span &expected, +// const std::string &name, const std::string &additional_msg = ""); + +// void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// int expected_dim, const std::string &name, const std::string &additional_msg = ""); +// void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// int expected_from, int expected_to, const std::string &name, +// const std::string &additional_msg = ""); +// void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// span &expected, const std::string &name, +// const std::string &additional_msg = ""); + +// void Shape(const TensorListShape<> &actual_shape, const TensorListShape<> &expected_shape, +// const std::string &name, const std::string &additional_msg = ""); +// void UniformShape(const TensorListShape<> &actual_shape, const std::string &name, +// const std::string &additional_msg = ""); -// template -// void Type(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, -// const std::string &additional_msg = "") { -// Type(spec, ws, argument_name, {type2id::value...}, additional_msg); -// } } // namespace validate - - } // namespace dali #endif // DALI_PIPELINE_OPERATOR_ERROR_REPORTING_H_ diff --git a/dali/pipeline/operator/name_utils.cc b/dali/pipeline/operator/name_utils.cc index 7460478cbc2..dfa89b8b966 100644 --- a/dali/pipeline/operator/name_utils.cc +++ b/dali/pipeline/operator/name_utils.cc @@ -15,8 +15,8 @@ #include #include +#include -#include "dali/core/error_handling.h" #include "dali/pipeline/operator/name_utils.h" #include "dali/pipeline/operator/op_spec.h" @@ -39,4 +39,21 @@ std::string GetOpDisplayName(const OpSpec &spec, bool include_module_path) { } } +std::string FormatInput(const OpSpec &spec, int input_idx, bool capitalize) { + if (spec.GetSchema().HasInputDox()) { + return fmt::format("{}nput `{}` ('__{}')", capitalize ? "I" : "i", input_idx, + spec.GetSchema().GetInputName(input_idx)); + } + + return fmt::format("{}nput `{}`", capitalize ? "I" : "i", input_idx); +} + +std::string FormatOutput(const OpSpec &spec, int output_idx, bool capitalize) { + return fmt::format("{}utput `{}`", capitalize ? "O" : "o", output_idx); +} + +std::string FormatArgument(const OpSpec &spec, const std::string &argument, bool capitalize) { + return fmt::format("{}rgument '{}'", capitalize ? "A" : "a", argument); +} + } // namespace dali diff --git a/dali/pipeline/operator/name_utils.h b/dali/pipeline/operator/name_utils.h index bcfcbfa54e5..48759be4e0d 100644 --- a/dali/pipeline/operator/name_utils.h +++ b/dali/pipeline/operator/name_utils.h @@ -42,6 +42,35 @@ DLL_PUBLIC std::string GetOpModule(const OpSpec &spec); */ DLL_PUBLIC std::string GetOpDisplayName(const OpSpec &spec, bool include_module_path = false); +/** + * @brief Uniformly format the display of the operator input index, optionally including the name + * if provided in schema doc. + * + * @param input_idx Index of the input + * @param capitalize should be true if the output should start with capital letter (used at the + * start of the sentence) + */ +DLL_PUBLIC std::string FormatInput(const OpSpec &spec, int input_idx, bool capitalize = false); + +/** + * @brief Uniformly format the display of the operator output index. + * + * @param input_idx Index of the output + * @param capitalize should be true if the output should start with capital letter (used at the + * start of the sentence) + */ +DLL_PUBLIC std::string FormatOutput(const OpSpec &spec, int output_idx, bool capitalize = false); + +/** + * @brief Uniformly format the display of the operator argument name + * + * @param argument string representing the name of the argument (without additional quotes) + * @param capitalize should be true if the output should start with capital letter (used at the + * start of the sentence) + */ +DLL_PUBLIC std::string FormatArgument(const OpSpec &spec, const std::string &argument, + bool capitalize = false); + } // namespace dali #endif // DALI_PIPELINE_OPERATOR_NAME_UTILS_H_ From e392ac6d1f32eac41c2c9c906d365e7836e20fb2 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Wed, 27 Mar 2024 23:59:50 +0100 Subject: [PATCH 4/4] Cleanup and examples of usage Signed-off-by: Krzysztof Lecki --- dali/operators/audio/preemphasis_filter_op.cc | 9 +- dali/operators/audio/preemphasis_filter_op.cu | 4 +- dali/operators/audio/preemphasis_filter_op.h | 5 +- dali/pipeline/executor/executor.cc | 7 ++ dali/pipeline/operator/error_reporting.cc | 15 +-- dali/pipeline/operator/error_reporting.h | 91 ++++++++++++------- 6 files changed, 84 insertions(+), 47 deletions(-) diff --git a/dali/operators/audio/preemphasis_filter_op.cc b/dali/operators/audio/preemphasis_filter_op.cc index 335414a30b1..da4c8d8df47 100644 --- a/dali/operators/audio/preemphasis_filter_op.cc +++ b/dali/operators/audio/preemphasis_filter_op.cc @@ -17,6 +17,7 @@ #include #include #include "dali/operators/audio/preemphasis_filter_op.h" +#include "dali/pipeline/operator/error_reporting.h" namespace dali { @@ -93,11 +94,11 @@ void PreemphasisFilterCPU::RunImplTyped(Workspace &ws) { void PreemphasisFilterCPU::RunImpl(Workspace &ws) { const auto &input = ws.Input(0); - TYPE_SWITCH(input.type(), type2id, InputType, PREEMPH_TYPES, ( - TYPE_SWITCH(output_type_, type2id, OutputType, PREEMPH_TYPES, ( + TYPE_SWITCH(input.type(), type2id, InputType, (PREEMPH_TYPES), ( + TYPE_SWITCH(output_type_, type2id, OutputType, (PREEMPH_TYPES), ( RunImplTyped(ws); - ), DALI_FAIL(make_string("Unsupported output type: ", output_type_))); // NOLINT - ), DALI_FAIL(make_string("Unsupported input type: ", input.type()))); // NOLINT + ), (validate::OutputType(spec_, ws, 0))); // NOLINT + ), (validate::InputType(spec_, ws, 0))); // NOLINT } DALI_REGISTER_OPERATOR(PreemphasisFilter, PreemphasisFilterCPU, CPU); diff --git a/dali/operators/audio/preemphasis_filter_op.cu b/dali/operators/audio/preemphasis_filter_op.cu index c6cf80c7950..6ef2cff1a00 100644 --- a/dali/operators/audio/preemphasis_filter_op.cu +++ b/dali/operators/audio/preemphasis_filter_op.cu @@ -108,8 +108,8 @@ void PreemphasisFilterGPU::RunImplTyped(Workspace &ws) { void PreemphasisFilterGPU::RunImpl(Workspace &ws) { const auto &input = ws.Input(0); - TYPE_SWITCH(input.type(), type2id, InputType, PREEMPH_TYPES, ( - TYPE_SWITCH(output_type_, type2id, OutputType, PREEMPH_TYPES, ( + TYPE_SWITCH(input.type(), type2id, InputType, (PREEMPH_TYPES), ( + TYPE_SWITCH(output_type_, type2id, OutputType, (PREEMPH_TYPES), ( RunImplTyped(ws); ), DALI_FAIL(make_string("Unsupported output type: ", output_type_))); // NOLINT ), DALI_FAIL(make_string("Unsupported input type: ", input.type()))); // NOLINT diff --git a/dali/operators/audio/preemphasis_filter_op.h b/dali/operators/audio/preemphasis_filter_op.h index 471e83473fe..ea9a79e2a69 100644 --- a/dali/operators/audio/preemphasis_filter_op.h +++ b/dali/operators/audio/preemphasis_filter_op.h @@ -22,9 +22,10 @@ #include "dali/pipeline/data/types.h" #include "dali/pipeline/operator/checkpointing/stateless_operator.h" #include "dali/pipeline/operator/operator.h" +#include "dali/pipeline/operator/error_reporting.h" #define PREEMPH_TYPES \ - (uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, double) + uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, double namespace dali { namespace detail { @@ -46,7 +47,7 @@ class PreemphasisFilter : public StatelessOperator { explicit PreemphasisFilter(const OpSpec &spec) : StatelessOperator(spec), - output_type_(spec.GetArgument(arg_names::kDtype)) { + output_type_(validate::Dtype(spec)) { auto border_str = spec.GetArgument(detail::kBorder); if (border_str == "zero") { border_type_ = BorderType::Zero; diff --git a/dali/pipeline/executor/executor.cc b/dali/pipeline/executor/executor.cc index 8781326ba1d..b901f0f86c9 100644 --- a/dali/pipeline/executor/executor.cc +++ b/dali/pipeline/executor/executor.cc @@ -27,6 +27,7 @@ #include "dali/pipeline/graph/op_graph_storage.h" #include "dali/pipeline/operator/builtin/conditional/split_merge.h" #include "dali/pipeline/operator/common.h" +#include "dali/pipeline/operator/error_reporting.h" #include "dali/pipeline/workspace/workspace.h" #include "dali/pipeline/workspace/workspace_data_factory.h" @@ -487,6 +488,12 @@ void Executor::RunHelper(OpNode &op_node, Workspac if (had_empty_layout) empty_layout_in_idxs.push_back(i); } + // TODO(klecki): Extract this to a separate function, this is just an example. + for (auto &argument_input : ws.ArgumentInputs()) { + // Check the types of argument inputs before they are accessed + validate::ArgumentType(spec, ws, argument_input.name); + } + bool should_allocate = false; { DomainTimeRange tr("[DALI][Executor] Setup"); diff --git a/dali/pipeline/operator/error_reporting.cc b/dali/pipeline/operator/error_reporting.cc index 566c399c6e3..88be5bc6128 100644 --- a/dali/pipeline/operator/error_reporting.cc +++ b/dali/pipeline/operator/error_reporting.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include #include #include -#include // for fmt::join #include // for ostream support @@ -31,7 +31,8 @@ #include "dali/pipeline/operator/op_spec.h" // template <> struct fmt::formatter : fmt::ostream_formatter {}; -// template <> struct fmt::formatter : fmt::tostring_formatter {}; +// template <> struct fmt::formatter : +// fmt::tostring_formatter {}; namespace dali { @@ -149,7 +150,7 @@ DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, const st additional_msg)); } -DALIDataType Type(DALIDataType actual_type, span &expected_types, +DALIDataType Type(DALIDataType actual_type, span expected_types, const std::string &name, const std::string &additional_msg) { if (std::size(expected_types) == 1) { return Type(actual_type, expected_types[0], name, additional_msg); @@ -172,7 +173,7 @@ DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, } DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, - span &allowed_types, const std::string &additional_msg) { + span allowed_types, const std::string &additional_msg) { DALIDataType dtype = ws.GetInputDataType(input_idx); return Type(dtype, allowed_types, FormatInput(spec, input_idx), additional_msg); } @@ -190,8 +191,8 @@ DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_uns additional_msg); } -DALIDataType Dtype(const OpSpec &spec, span &allowed_types, bool allow_unspecified, - const std::string &additional_msg) { +DALIDataType Dtype(const OpSpec &spec, span allowed_types, + bool allow_unspecified, const std::string &additional_msg) { if (allow_unspecified && !spec.HasArgument("dtype")) { return DALI_NO_TYPE; } else if (!allow_unspecified && !spec.HasArgument("dtype")) { @@ -225,7 +226,7 @@ DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, } DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, - span &allowed_types, const std::string &additional_msg) { + span allowed_types, const std::string &additional_msg) { DALIDataType dtype = ws.GetOutputDataType(output_idx); return Type(dtype, allowed_types, FormatOutput(spec, output_idx), additional_msg); } diff --git a/dali/pipeline/operator/error_reporting.h b/dali/pipeline/operator/error_reporting.h index 1ef379d7fc6..f4c96ba4894 100644 --- a/dali/pipeline/operator/error_reporting.h +++ b/dali/pipeline/operator/error_reporting.h @@ -25,8 +25,8 @@ #include "dali/core/api_helper.h" #include "dali/core/span.h" #include "dali/pipeline/data/types.h" -#include "dali/pipeline/operator/op_spec.h" #include "dali/pipeline/operator/name_utils.h" +#include "dali/pipeline/operator/op_spec.h" #include "dali/pipeline/workspace/workspace.h" namespace dali { @@ -212,8 +212,8 @@ namespace validate { * @param additional_message * @return DALIDataType */ -DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, - const std::string &additional_message = ""); +DLL_PUBLIC DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, + const std::string &name, const std::string &additional_message = ""); /** * @brief Check if the actual_type matches the expected_type. If not, throws TypeError @@ -228,44 +228,69 @@ DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, const st * @param additional_message * @return DALIDataType */ -DALIDataType Type(DALIDataType actual_type, span &expected_types, - const std::string &name, const std::string &additional_message = ""); +DLL_PUBLIC DALIDataType Type(DALIDataType actual_type, span expected_types, + const std::string &name, const std::string &additional_message = ""); -DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, - DALIDataType allowed_type, const std::string &additional_msg = ""); +DLL_PUBLIC DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + DALIDataType allowed_type, + const std::string &additional_msg = ""); -DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, - span &allowed_types, const std::string &additional_msg = ""); +DLL_PUBLIC DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + span allowed_types, + const std::string &additional_msg = ""); -template +template DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, const std::string &additional_msg = "") { - static std::array allowed_types = { - type2id::value...}; + static constexpr std::array allowed_types = { + type2id::value...}; return InputType(spec, ws, input_idx, make_cspan(allowed_types), additional_msg); } -DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified = false, - const std::string &additional_msg = ""); - -DALIDataType Dtype(const OpSpec &spec, span &allowed_types, - bool allow_unspecified = false, const std::string &additional_msg = ""); - -DALIDataType Dtype(const OpSpec &spec, const Workspace &ws, bool (*is_valid)(DALIDataType), - const std::string &explanation); +/** + * @brief Check if provided argument 'dtype' has a valid value. If not, throws TypeError + * + * @param spec + * @param allowed_type + * @param allow_unspecified - if true, allows dtype to be left empty, in that case DALI_NO_TYPE is + * returned + * @param additional_msg + * @return DALIDataType - if type is valid, it will be returned. + */ +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, + bool allow_unspecified = false, + const std::string &additional_msg = ""); + +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, span allowed_types, + bool allow_unspecified = false, + const std::string &additional_msg = ""); + +template +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, bool allow_unspecified = false, + const std::string &additional_msg = "") { + static constexpr std::array allowed_types = { + type2id::value...}; + return Dtype(spec, make_cspan(allowed_types), allow_unspecified, additional_msg); +} +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, const Workspace &ws, + bool (*is_valid)(DALIDataType), const std::string &explanation); -DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, - DALIDataType allowed_type, const std::string &additional_msg = ""); +// Note that the output type checks are valid only if the output is already allocated by the +// executor. It may be tricky to use. +DLL_PUBLIC DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + DALIDataType allowed_type, + const std::string &additional_msg = ""); -DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, - span &allowed_types, const std::string &additional_msg = ""); +DLL_PUBLIC DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + span allowed_types, + const std::string &additional_msg = ""); -template -DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, - const std::string &additional_msg = "") { - static std::array allowed_types = { - type2id::value...}; +template +DLL_PUBLIC DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + const std::string &additional_msg = "") { + static constexpr std::array allowed_types = { + type2id::value...}; return OutputType(spec, ws, output_idx, make_cspan(allowed_types), additional_msg); } @@ -278,8 +303,9 @@ DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, * @param additional_msg * @return DALIDataType */ -DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, const std::string &arg_name, - const std::string &additional_msg = ""); +DLL_PUBLIC DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, + const std::string &arg_name, + const std::string &additional_msg = ""); /** @} */ // end of TypeValidation @@ -314,7 +340,8 @@ DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, const std::st // const std::string &name, const std::string &additional_msg = ""); // void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, -// int expected_dim, const std::string &name, const std::string &additional_msg = ""); +// int expected_dim, const std::string &name, const std::string &additional_msg = +// ""); // void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, // int expected_from, int expected_to, const std::string &name, // const std::string &additional_msg = "");