diff --git a/.gitmodules b/.gitmodules index 41f89cab274..4abcd184b85 100644 --- a/.gitmodules +++ b/.gitmodules @@ -29,3 +29,6 @@ [submodule "third_party/cvcuda"] path = third_party/cvcuda url = https://github.com/CVCUDA/CV-CUDA.git +[submodule "third_party/fmt"] + path = third_party/fmt + url = https://github.com/fmtlib/fmt.git diff --git a/cmake/Dependencies.common.cmake b/cmake/Dependencies.common.cmake index 46cbda225e3..8c56ab4f536 100644 --- a/cmake/Dependencies.common.cmake +++ b/cmake/Dependencies.common.cmake @@ -380,3 +380,11 @@ if(BUILD_NVIMAGECODEC) endif() endif() endif() + +################################################################## +# {fmt} +################################################################## +check_and_add_cmake_submodule(${PROJECT_SOURCE_DIR}/third_party/fmt EXCLUDE_FROM_ALL) +set_target_properties(fmt PROPERTIES POSITION_INDEPENDENT_CODE ON) +list(APPEND DALI_LIBS fmt) +list(APPEND DALI_EXCLUDES libfmt.a) diff --git a/dali/operators/audio/preemphasis_filter_op.cc b/dali/operators/audio/preemphasis_filter_op.cc index 335414a30b1..da4c8d8df47 100644 --- a/dali/operators/audio/preemphasis_filter_op.cc +++ b/dali/operators/audio/preemphasis_filter_op.cc @@ -17,6 +17,7 @@ #include #include #include "dali/operators/audio/preemphasis_filter_op.h" +#include "dali/pipeline/operator/error_reporting.h" namespace dali { @@ -93,11 +94,11 @@ void PreemphasisFilterCPU::RunImplTyped(Workspace &ws) { void PreemphasisFilterCPU::RunImpl(Workspace &ws) { const auto &input = ws.Input(0); - TYPE_SWITCH(input.type(), type2id, InputType, PREEMPH_TYPES, ( - TYPE_SWITCH(output_type_, type2id, OutputType, PREEMPH_TYPES, ( + TYPE_SWITCH(input.type(), type2id, InputType, (PREEMPH_TYPES), ( + TYPE_SWITCH(output_type_, type2id, OutputType, (PREEMPH_TYPES), ( RunImplTyped(ws); - ), DALI_FAIL(make_string("Unsupported output type: ", output_type_))); // NOLINT - ), DALI_FAIL(make_string("Unsupported input type: ", input.type()))); // NOLINT + ), 
(validate::OutputType(spec_, ws, 0))); // NOLINT + ), (validate::InputType(spec_, ws, 0))); // NOLINT } DALI_REGISTER_OPERATOR(PreemphasisFilter, PreemphasisFilterCPU, CPU); diff --git a/dali/operators/audio/preemphasis_filter_op.cu b/dali/operators/audio/preemphasis_filter_op.cu index c6cf80c7950..6ef2cff1a00 100644 --- a/dali/operators/audio/preemphasis_filter_op.cu +++ b/dali/operators/audio/preemphasis_filter_op.cu @@ -108,8 +108,8 @@ void PreemphasisFilterGPU::RunImplTyped(Workspace &ws) { void PreemphasisFilterGPU::RunImpl(Workspace &ws) { const auto &input = ws.Input(0); - TYPE_SWITCH(input.type(), type2id, InputType, PREEMPH_TYPES, ( - TYPE_SWITCH(output_type_, type2id, OutputType, PREEMPH_TYPES, ( + TYPE_SWITCH(input.type(), type2id, InputType, (PREEMPH_TYPES), ( + TYPE_SWITCH(output_type_, type2id, OutputType, (PREEMPH_TYPES), ( RunImplTyped(ws); ), DALI_FAIL(make_string("Unsupported output type: ", output_type_))); // NOLINT ), DALI_FAIL(make_string("Unsupported input type: ", input.type()))); // NOLINT diff --git a/dali/operators/audio/preemphasis_filter_op.h b/dali/operators/audio/preemphasis_filter_op.h index 471e83473fe..ea9a79e2a69 100644 --- a/dali/operators/audio/preemphasis_filter_op.h +++ b/dali/operators/audio/preemphasis_filter_op.h @@ -22,9 +22,10 @@ #include "dali/pipeline/data/types.h" #include "dali/pipeline/operator/checkpointing/stateless_operator.h" #include "dali/pipeline/operator/operator.h" +#include "dali/pipeline/operator/error_reporting.h" #define PREEMPH_TYPES \ - (uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, double) + uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, double namespace dali { namespace detail { @@ -46,7 +47,7 @@ class PreemphasisFilter : public StatelessOperator { explicit PreemphasisFilter(const OpSpec &spec) : StatelessOperator(spec), - output_type_(spec.GetArgument(arg_names::kDtype)) { + output_type_(validate::Dtype(spec)) { auto border_str = 
spec.GetArgument(detail::kBorder); if (border_str == "zero") { border_type_ = BorderType::Zero; diff --git a/dali/pipeline/executor/executor.cc b/dali/pipeline/executor/executor.cc index 8781326ba1d..b901f0f86c9 100644 --- a/dali/pipeline/executor/executor.cc +++ b/dali/pipeline/executor/executor.cc @@ -27,6 +27,7 @@ #include "dali/pipeline/graph/op_graph_storage.h" #include "dali/pipeline/operator/builtin/conditional/split_merge.h" #include "dali/pipeline/operator/common.h" +#include "dali/pipeline/operator/error_reporting.h" #include "dali/pipeline/workspace/workspace.h" #include "dali/pipeline/workspace/workspace_data_factory.h" @@ -487,6 +488,12 @@ void Executor::RunHelper(OpNode &op_node, Workspac if (had_empty_layout) empty_layout_in_idxs.push_back(i); } + // TODO(klecki): Extract this to a separate function, this is just an example. + for (auto &argument_input : ws.ArgumentInputs()) { + // Check the types of argument inputs before they are accessed + validate::ArgumentType(spec, ws, argument_input.name); + } + bool should_allocate = false; { DomainTimeRange tr("[DALI][Executor] Setup"); diff --git a/dali/pipeline/operator/error_reporting.cc b/dali/pipeline/operator/error_reporting.cc index 6f4e1d1a67c..88be5bc6128 100644 --- a/dali/pipeline/operator/error_reporting.cc +++ b/dali/pipeline/operator/error_reporting.cc @@ -12,18 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include #include #include #include +// for fmt::join +#include +// for ostream support +#include #include "dali/core/error_handling.h" +#include "dali/pipeline/data/backend.h" +#include "dali/pipeline/data/types.h" #include "dali/pipeline/operator/error_reporting.h" +#include "dali/pipeline/operator/name_utils.h" #include "dali/pipeline/operator/op_spec.h" +// template <> struct fmt::formatter : fmt::ostream_formatter {}; +// template <> struct fmt::formatter : +// fmt::tostring_formatter {}; + + namespace dali { +auto format_as(dali::DALIDataType type) { + return dali::to_string(type); +} + std::vector GetOperatorOriginInfo(const OpSpec &spec) { auto origin_stack_filename = spec.GetRepeatedArgument("_origin_stack_filename"); auto origin_stack_lineno = spec.GetRepeatedArgument("_origin_stack_lineno"); @@ -67,8 +84,7 @@ void PropagateError(ErrorInfo error) { catch (DaliError &e) { e.UpdateMessage(make_string(error.context_info, e.what(), error.additional_message)); throw; - } - catch (DALIException &e) { + } catch (DALIException &e) { // We drop the C++ stack trace at this point and go back to runtime_error. 
throw std::runtime_error( make_string(error.context_info, e.what(), @@ -109,9 +125,121 @@ std::string GetErrorContextMessage(const OpSpec &spec) { formatted_origin_stack + "\n") : " "; // we need space before "encountered" - return make_string("Error in ", device, " operator `", op_name, "`", - optional_stack_mention, "encountered:\n\n"); + return make_string("Error in ", device, " operator `", op_name, "`", optional_stack_mention, + "encountered:\n\n"); } +namespace validate { + +std::string SepIfNotEmpty(const std::string &str, const std::string &sep = " ") { + if (str.empty()) { + return ""; + } + return sep; +} + +DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name, + const std::string &additional_msg) { + if (actual_type == expected_type) { + return actual_type; + } + + throw DaliTypeError(fmt::format("Unexpected type for {}. Got type: `{}` but expected: `{}`.{}{}", + name, actual_type, expected_type, SepIfNotEmpty(additional_msg), + additional_msg)); +} + +DALIDataType Type(DALIDataType actual_type, span expected_types, + const std::string &name, const std::string &additional_msg) { + if (std::size(expected_types) == 1) { + return Type(actual_type, expected_types[0], name, additional_msg); + } + for (auto expected_type : expected_types) { + if (actual_type == expected_type) { + return actual_type; + } + } + + throw DaliTypeError(fmt::format( + "Unexpected type for {}. 
Got type: `{}` but expected one of: `{}`.{}{}", name, actual_type, + fmt::join(expected_types, "`, `"), SepIfNotEmpty(additional_msg), additional_msg)); +} + +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + DALIDataType allowed_type, const std::string &additional_msg) { + DALIDataType dtype = ws.GetInputDataType(input_idx); + return Type(dtype, allowed_type, FormatInput(spec, input_idx), additional_msg); +} + +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + span allowed_types, const std::string &additional_msg) { + DALIDataType dtype = ws.GetInputDataType(input_idx); + return Type(dtype, allowed_types, FormatInput(spec, input_idx), additional_msg); +} + +DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified, + const std::string &additional_msg) { + if (allow_unspecified && !spec.HasArgument("dtype")) { + return DALI_NO_TYPE; + } else if (!allow_unspecified && !spec.HasArgument("dtype")) { + throw DaliValueError(fmt::format("{} was not specified.{}{}", + FormatArgument(spec, "dtype", true), + SepIfNotEmpty(additional_msg), additional_msg)); + } + return Type(spec.GetArgument("dtype"), allowed_type, FormatArgument(spec, "dtype"), + additional_msg); +} + +DALIDataType Dtype(const OpSpec &spec, span allowed_types, + bool allow_unspecified, const std::string &additional_msg) { + if (allow_unspecified && !spec.HasArgument("dtype")) { + return DALI_NO_TYPE; + } else if (!allow_unspecified && !spec.HasArgument("dtype")) { + throw DaliValueError(fmt::format("{} was not specified.{}{}", + FormatArgument(spec, "dtype", true), + SepIfNotEmpty(additional_msg), additional_msg)); + } + return Type(spec.GetArgument("dtype"), allowed_types, FormatArgument(spec, "dtype"), + additional_msg); +} + +void Dim(int actual_dim, int expected_dim, const std::string &name, + const std::string &additional_msg) { + if (actual_dim == expected_dim) { + return; + } + throw 
DaliValueError(fmt::format("Got dim: `{}` for {}, but expected: `{}`.{}{}", actual_dim, + name, expected_dim, SepIfNotEmpty(additional_msg), + additional_msg)); +} + +DALIDataType Dtype(const OpSpec &spec, const Workspace &ws, bool (*is_valid)(DALIDataType), + const std::string &explanation) { + return DALI_NO_TYPE; // TODO(klecki): implement +} + +DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + DALIDataType allowed_type, const std::string &additional_msg) { + DALIDataType dtype = ws.GetOutputDataType(output_idx); + return Type(dtype, allowed_type, FormatOutput(spec, output_idx), additional_msg); +} + +DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + span allowed_types, const std::string &additional_msg) { + DALIDataType dtype = ws.GetOutputDataType(output_idx); + return Type(dtype, allowed_types, FormatOutput(spec, output_idx), additional_msg); +} + +DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, const std::string &arg_name, + const std::string &additional_msg) { + DALIDataType expected_type = spec.GetSchema().GetArgumentType(arg_name); + if (!spec.HasTensorArgument(arg_name)) { + return expected_type; + } + return Type(ws.ArgumentInput(arg_name).type(), expected_type, FormatArgument(spec, arg_name), + additional_msg); +} + +} // namespace validate } // namespace dali diff --git a/dali/pipeline/operator/error_reporting.h b/dali/pipeline/operator/error_reporting.h index d84bdf0f6f9..f4c96ba4894 100644 --- a/dali/pipeline/operator/error_reporting.h +++ b/dali/pipeline/operator/error_reporting.h @@ -23,9 +23,11 @@ #include #include "dali/core/api_helper.h" +#include "dali/core/span.h" #include "dali/pipeline/data/types.h" -#include "dali/pipeline/operator/op_spec.h" #include "dali/pipeline/operator/name_utils.h" +#include "dali/pipeline/operator/op_spec.h" +#include "dali/pipeline/workspace/workspace.h" namespace dali { @@ -186,6 +188,173 @@ class DaliStopIteration : public 
DaliError { */ std::string GetErrorContextMessage(const OpSpec &spec); +namespace validate { + +/** @defgroup TypeValidation Validation of types. + * Input/Output type validation - based on the workspace contents and the index. + * Argument type validation - internal check if the argument matches the spec before operator + * code is able to access it. + * Argument 'dtype' validation - based on the operator spec, checks the contents. + * Type() variant provides generic validation. + * @{ + */ + +/** + * @brief Check if the actual_type matches the expected_type. If not, throws DaliTypeError + * containing the message: + * + * Unexpected type for <name>. Got type: `<actual_type>` but expected: `<expected_type>`. + * <additional_message> + * + * @param actual_type + * @param expected_type + * @param name + * @param additional_message + * @return DALIDataType + */ +DLL_PUBLIC DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, + const std::string &name, const std::string &additional_message = ""); + +/** + * @brief Check if the actual_type matches one of the expected_types. If not, throws DaliTypeError + * containing the message: + * + * Unexpected type for <name>. Got type: `<actual_type>` but expected one of: `<expected_types>`. 
+ * + * + * @param actual_type + * @param expected_types + * @param name + * @param additional_message + * @return DALIDataType + */ +DLL_PUBLIC DALIDataType Type(DALIDataType actual_type, span expected_types, + const std::string &name, const std::string &additional_message = ""); + +DLL_PUBLIC DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + DALIDataType allowed_type, + const std::string &additional_msg = ""); + +DLL_PUBLIC DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + span allowed_types, + const std::string &additional_msg = ""); + +template +DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx, + const std::string &additional_msg = "") { + static constexpr std::array allowed_types = { + type2id::value...}; + return InputType(spec, ws, input_idx, make_cspan(allowed_types), additional_msg); +} + +/** + * @brief Check if provided argument 'dtype' has a valid value. If not, throws TypeError + * + * @param spec + * @param allowed_type + * @param allow_unspecified - if true, allows dtype to be left empty, in that case DALI_NO_TYPE is + * returned + * @param additional_msg + * @return DALIDataType - if type is valid, it will be returned. 
+ */ +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, + bool allow_unspecified = false, + const std::string &additional_msg = ""); + +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, span allowed_types, + bool allow_unspecified = false, + const std::string &additional_msg = ""); + +template +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, bool allow_unspecified = false, + const std::string &additional_msg = "") { + static constexpr std::array allowed_types = { + type2id::value...}; + return Dtype(spec, make_cspan(allowed_types), allow_unspecified, additional_msg); +} + +DLL_PUBLIC DALIDataType Dtype(const OpSpec &spec, const Workspace &ws, + bool (*is_valid)(DALIDataType), const std::string &explanation); + +// Note that the output type checks are valid only if the output is already allocated by the +// executor. It may be tricky to use. +DLL_PUBLIC DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + DALIDataType allowed_type, + const std::string &additional_msg = ""); + +DLL_PUBLIC DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + span allowed_types, + const std::string &additional_msg = ""); + +template +DLL_PUBLIC DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx, + const std::string &additional_msg = "") { + static constexpr std::array allowed_types = { + type2id::value...}; + return OutputType(spec, ws, output_idx, make_cspan(allowed_types), additional_msg); +} + +/** + * @brief Verifies if given argument input has a correct backing type. 
+ * + * @param spec + * @param ws + * @param arg_name + * @param additional_msg + * @return DALIDataType + */ +DLL_PUBLIC DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, + const std::string &arg_name, + const std::string &additional_msg = ""); + +/** @} */ // end of TypeValidation + +// TODO(klecki): Same convention can be applied to checks for inputs and arguments: +// * dimensionality +// * shape: matching, uniform, non-empty ... +// * layout (we already have some of this) +// Also, there are groups of arguments similar to 'dtype', that can have specific checks: +// * relative coordinates [0, 1] in tensor +// * absolute coordinates in tensor shape. +// * axis within dim. +// Examples: Roi, crop window, anchor points, bboxes... +// +// Mutually exclusive arguments are also popular. +// +// The downside is, this generates a lot of boilerplate code, maybe we should just stick with +// the main variant + the FormatInput/Output/Argument variants. +// +// +// void Dim(int actual_dim, int expected_dim, const std::string &name, +// const std::string &additional_msg = ""); +// void Dim(int actual_dim, int expected_from, int expected_to, const std::string &name, +// const std::string &additional_msg = ""); +// void Dim(int actual_dim, span expected, const std::string &name, +// const std::string &additional_msg = ""); + +// void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_dim, +// const std::string &name, const std::string &additional_msg = ""); +// void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, int expected_from, +// int expected_to, const std::string &name, const std::string &additional_msg = ""); +// void InputDim(const OpSpec &spec, const Workspace &ws, int input_idx, span &expected, +// const std::string &name, const std::string &additional_msg = ""); + +// void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// int expected_dim, const std::string &name, 
const std::string &additional_msg = +// ""); +// void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// int expected_from, int expected_to, const std::string &name, +// const std::string &additional_msg = ""); +// void ArgumentDim(const OpSpec &spec, const Workspace &ws, const std::string &argument_name, +// span &expected, const std::string &name, +// const std::string &additional_msg = ""); + +// void Shape(const TensorListShape<> &actual_shape, const TensorListShape<> &expected_shape, +// const std::string &name, const std::string &additional_msg = ""); +// void UniformShape(const TensorListShape<> &actual_shape, const std::string &name, +// const std::string &additional_msg = ""); + +} // namespace validate } // namespace dali #endif // DALI_PIPELINE_OPERATOR_ERROR_REPORTING_H_ diff --git a/dali/pipeline/operator/name_utils.cc b/dali/pipeline/operator/name_utils.cc index 7460478cbc2..dfa89b8b966 100644 --- a/dali/pipeline/operator/name_utils.cc +++ b/dali/pipeline/operator/name_utils.cc @@ -15,8 +15,8 @@ #include #include +#include -#include "dali/core/error_handling.h" #include "dali/pipeline/operator/name_utils.h" #include "dali/pipeline/operator/op_spec.h" @@ -39,4 +39,21 @@ std::string GetOpDisplayName(const OpSpec &spec, bool include_module_path) { } } +std::string FormatInput(const OpSpec &spec, int input_idx, bool capitalize) { + if (spec.GetSchema().HasInputDox()) { + return fmt::format("{}nput `{}` ('__{}')", capitalize ? "I" : "i", input_idx, + spec.GetSchema().GetInputName(input_idx)); + } + + return fmt::format("{}nput `{}`", capitalize ? "I" : "i", input_idx); +} + +std::string FormatOutput(const OpSpec &spec, int output_idx, bool capitalize) { + return fmt::format("{}utput `{}`", capitalize ? "O" : "o", output_idx); +} + +std::string FormatArgument(const OpSpec &spec, const std::string &argument, bool capitalize) { + return fmt::format("{}rgument '{}'", capitalize ? 
"A" : "a", argument); +} + } // namespace dali diff --git a/dali/pipeline/operator/name_utils.h b/dali/pipeline/operator/name_utils.h index bcfcbfa54e5..48759be4e0d 100644 --- a/dali/pipeline/operator/name_utils.h +++ b/dali/pipeline/operator/name_utils.h @@ -42,6 +42,35 @@ DLL_PUBLIC std::string GetOpModule(const OpSpec &spec); */ DLL_PUBLIC std::string GetOpDisplayName(const OpSpec &spec, bool include_module_path = false); +/** + * @brief Uniformly format the display of the operator input index, optionally including the name + * if provided in schema doc. + * + * @param input_idx Index of the input + * @param capitalize should be true if the output should start with capital letter (used at the + * start of the sentence) + */ +DLL_PUBLIC std::string FormatInput(const OpSpec &spec, int input_idx, bool capitalize = false); + +/** + * @brief Uniformly format the display of the operator output index. + * + * @param input_idx Index of the output + * @param capitalize should be true if the output should start with capital letter (used at the + * start of the sentence) + */ +DLL_PUBLIC std::string FormatOutput(const OpSpec &spec, int output_idx, bool capitalize = false); + +/** + * @brief Uniformly format the display of the operator argument name + * + * @param argument string representing the name of the argument (without additional quotes) + * @param capitalize should be true if the output should start with capital letter (used at the + * start of the sentence) + */ +DLL_PUBLIC std::string FormatArgument(const OpSpec &spec, const std::string &argument, + bool capitalize = false); + } // namespace dali #endif // DALI_PIPELINE_OPERATOR_NAME_UTILS_H_ diff --git a/third_party/fmt b/third_party/fmt new file mode 160000 index 00000000000..e69e5f977d4 --- /dev/null +++ b/third_party/fmt @@ -0,0 +1 @@ +Subproject commit e69e5f977d458f2650bb346dadf2ad30c5320281