diff --git a/include/tl2cgen/compiler_param.h b/include/tl2cgen/compiler_param.h index 99a4f8a..b26e3db 100644 --- a/include/tl2cgen/compiler_param.h +++ b/include/tl2cgen/compiler_param.h @@ -32,6 +32,8 @@ struct CompilerParam { compilation time and reduce memory consumption during compilation. */ int parallel_comp{0}; + /*! \brief Whether to interpret threshold points as integers (0: no, >0: yes) */ + bool thresh_as_int{false}; /*! \brief If >0, produce extra messages */ int verbose{0}; /*! \brief Native lib name (without extension) */ diff --git a/include/tl2cgen/detail/compiler/ast/ast.h b/include/tl2cgen/detail/compiler/ast/ast.h index e5799fd..9d3d790 100644 --- a/include/tl2cgen/detail/compiler/ast/ast.h +++ b/include/tl2cgen/detail/compiler/ast/ast.h @@ -94,15 +94,17 @@ class NumericalConditionNode : public ConditionNode { public: using ThresholdVariantT = std::variant; NumericalConditionNode(std::uint32_t split_index, bool default_left, treelite::Operator op, - ThresholdVariantT threshold, std::optional quantized_threshold) + ThresholdVariantT threshold, std::optional quantized_threshold, bool thresh_as_int) : ConditionNode(split_index, default_left), op_(op), threshold_(threshold), quantized_threshold_(quantized_threshold), + thresh_as_int_(thresh_as_int), zero_quantized_(-1) {} treelite::Operator op_; ThresholdVariantT threshold_; std::optional quantized_threshold_; + bool thresh_as_int_; int zero_quantized_; // quantized value of 0.0f (useful when convert_missing_to_zero is set) std::string GetDump() const override; }; diff --git a/include/tl2cgen/detail/compiler/ast/builder.h b/include/tl2cgen/detail/compiler/ast/builder.h index 798808b..7b9f2b2 100644 --- a/include/tl2cgen/detail/compiler/ast/builder.h +++ b/include/tl2cgen/detail/compiler/ast/builder.h @@ -33,7 +33,7 @@ class ASTBuilder { ASTBuilder() : main_node_(nullptr) {} /* \brief Initially build AST from model */ - void BuildAST(treelite::Model const& model); + void BuildAST(treelite::Model
const& model, bool thresh_as_int); /* \brief Generate is_categorical[] array, which tells whether each feature is categorical or numerical */ void GenerateIsCategoricalArray(); @@ -69,7 +69,7 @@ class ASTBuilder { template ASTNode* BuildASTFromTree(ASTNode* parent, treelite::Tree const& tree, int tree_id, - std::int32_t target_id, std::int32_t class_id, int nid); + std::int32_t target_id, std::int32_t class_id, int nid, bool thresh_as_int); // Keep tract of all nodes built so far, to prevent memory leak std::vector> nodes_; diff --git a/src/compiler/ast/build.cc b/src/compiler/ast/build.cc index 76e70db..bc7e3b5 100644 --- a/src/compiler/ast/build.cc +++ b/src/compiler/ast/build.cc @@ -56,7 +56,7 @@ std::optional> ComputeAverageFactor(treelite::Model co namespace tl2cgen::compiler::detail::ast { -void ASTBuilder::BuildAST(treelite::Model const& model) { +void ASTBuilder::BuildAST(treelite::Model const& model, bool thresh_as_int) { main_node_ = AddNode( nullptr, model.base_scores.AsVector(), ComputeAverageFactor(model), model.postprocessor); meta_.num_target_ = model.num_target; @@ -72,7 +72,7 @@ void ASTBuilder::BuildAST(treelite::Model const& model) { [&](auto&& model_preset) { for (std::size_t tree_id = 0; tree_id < model_preset.trees.size(); ++tree_id) { ASTNode* tree_head = BuildASTFromTree(func, model_preset.trees[tree_id], - static_cast(tree_id), model.target_id[tree_id], model.class_id[tree_id], 0); + static_cast(tree_id), model.target_id[tree_id], model.class_id[tree_id], 0, thresh_as_int); func->children_.push_back(tree_head); } using ModelPresetT = std::remove_const_t>; @@ -97,7 +97,7 @@ void ASTBuilder::BuildAST(treelite::Model const& model) { template ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent, treelite::Tree const& tree, int tree_id, std::int32_t target_id, - std::int32_t class_id, int nid) { + std::int32_t class_id, int nid, bool thresh_as_int) { ASTNode* ast_node = nullptr; if (tree.IsLeaf(nid)) { if (meta_.leaf_vector_shape_[0] == 1 && 
meta_.leaf_vector_shape_[1] == 1) { @@ -109,7 +109,7 @@ ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent, } else { if (tree.NodeType(nid) == treelite::TreeNodeType::kNumericalTestNode) { ast_node = AddNode(parent, tree.SplitIndex(nid), - tree.DefaultLeft(nid), tree.ComparisonOp(nid), tree.Threshold(nid), std::nullopt); + tree.DefaultLeft(nid), tree.ComparisonOp(nid), tree.Threshold(nid), std::nullopt, thresh_as_int); } else { ast_node = AddNode(parent, tree.SplitIndex(nid), tree.DefaultLeft(nid), tree.CategoryList(nid), tree.CategoryListRightChild(nid)); @@ -118,9 +118,9 @@ ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent, dynamic_cast(ast_node)->gain_ = tree.Gain(nid); } ast_node->children_.push_back( - BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.LeftChild(nid))); + BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.LeftChild(nid), thresh_as_int)); ast_node->children_.push_back( - BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.RightChild(nid))); + BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.RightChild(nid), thresh_as_int)); } ast_node->node_id_ = nid; ast_node->tree_id_ = tree_id; @@ -135,8 +135,8 @@ ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent, } template ASTNode* ASTBuilder::BuildASTFromTree( - ASTNode*, treelite::Tree const&, int, std::int32_t, std::int32_t, int); + ASTNode*, treelite::Tree const&, int, std::int32_t, std::int32_t, int, bool); template ASTNode* ASTBuilder::BuildASTFromTree( - ASTNode*, treelite::Tree const&, int, std::int32_t, std::int32_t, int); + ASTNode*, treelite::Tree const&, int, std::int32_t, std::int32_t, int, bool); } // namespace tl2cgen::compiler::detail::ast diff --git a/src/compiler/codegen/condition_node.cc b/src/compiler/codegen/condition_node.cc index 7ae92ac..572296b 100644 --- a/src/compiler/codegen/condition_node.cc +++ b/src/compiler/codegen/condition_node.cc @@ -34,15 +34,55 @@ std::string GetFabsCFunc(std::string const& 
threshold_type) {
   }
 }
 
+/*! \brief Bit pattern of a float as a C hex literal, e.g. 1.0f -> "0x3f800000" */
+std::string float_to_bin(float number) {
+  std::uint32_t bits;
+  std::memcpy(&bits, &number, sizeof(bits));  // memcpy: type punning without aliasing UB
+  return fmt::format("0x{:08x}", bits);
+}
+
+/*! \brief Bit pattern of a double as a C hex literal (16 hex digits) */
+std::string double_to_bin(double number) {
+  std::uint64_t bits;
+  std::memcpy(&bits, &number, sizeof(bits));  // memcpy: type punning without aliasing UB
+  return fmt::format("0x{:016x}", bits);
+}
+
+/*! \brief Logical negation of an operator ("<" -> ">="); "" if unknown. NOT the mirror image */
+std::string getOppositeOperator(const std::string& op) {
+  if (op == "<" || op == ">") return op == "<" ? ">=" : "<=";
+  if (op == "<=" || op == ">=") return op == "<=" ? ">" : "<";
+  return "";  // unknown operator
+}
+
+/*! \brief Emit a C test comparing feature and threshold bit patterns as signed integers.
+ *  For t < 0 the feature's sign bit is flipped and the operator mirrored: x < t <=> -x > -t */
+std::string thresh_as_int(const std::string& threshold_type, ast::NumericalConditionNode const* node) {
+  bool const wide = (threshold_type == "double");  // double -> 64-bit pattern, else 32-bit
+  double const thr = std::visit([](auto v) { return static_cast<double>(v); }, node->threshold_);
+  double const mag = (thr < 0) ? -thr : thr;  // the literal always encodes |threshold|
+  std::string op = treelite::OperatorToString(node->op_);
+  char const* flip = "";
+  if (thr < 0) {  // mask == minimum signed value, so the XOR and the comparison stay signed
+    flip = wide ? " ^ (-9223372036854775807LL - 1)" : " ^ (-2147483647 - 1)";
+    if (!op.empty() && (op[0] == '<' || op[0] == '>')) op[0] = (op[0] == '<') ? '>' : '<';
+  }
+  return fmt::format("(*( (({0}*)(data)) + {1} ){2}){3}(({0})({4}))", wide ? "long long" : "int",
+      node->split_index_, flip, op, wide ? double_to_bin(mag) : float_to_bin(static_cast<float>(mag)));
+}
+
 inline std::string ExtractNumericalCondition(ast::NumericalConditionNode const* node) {
   std::string const threshold_type = codegen::GetThresholdCType(node);
   std::string result;
   if (node->quantized_threshold_) {  // Quantized threshold
-    std::string lhs
-        = fmt::format("data[{split_index}].qvalue", "split_index"_a = node->split_index_);
+    std::string lhs = fmt::format("data[{split_index}].qvalue", "split_index"_a = node->split_index_);
     result = fmt::format("{lhs} {opname} {threshold}", "lhs"_a = lhs,
-        "opname"_a = treelite::OperatorToString(node->op_),
-        "threshold"_a = *node->quantized_threshold_);
+        "opname"_a = treelite::OperatorToString(node->op_), "threshold"_a = *node->quantized_threshold_);
+  } else if
(node->thresh_as_int_) {  // Threshold as integer
+    if (!(threshold_type == "float" || threshold_type == "double")) {  // Only float and double are supported
+      throw std::runtime_error("Invalid threshold type.");
+    }
+    result = thresh_as_int(threshold_type, node);
   } else {
     result = std::visit(
         [&](auto&& threshold) -> std::string {
diff --git a/src/compiler/compiler.cc b/src/compiler/compiler.cc
index 6a4c0b7..baf48a1 100644
--- a/src/compiler/compiler.cc
+++ b/src/compiler/compiler.cc
@@ -26,7 +26,7 @@ detail::ast::ASTBuilder LowerToAST(
     treelite::Model const& model, tl2cgen::compiler::CompilerParam const& param) {
   /* 1. Lower the tree ensemble model into Abstract Syntax Tree (AST) */
   detail::ast::ASTBuilder builder;
-  builder.BuildAST(model);
+  builder.BuildAST(model, param.thresh_as_int);
 
   /* 2. Apply optimization passes to AST */
   if (param.annotate_in != "NULL") {
diff --git a/src/compiler/compiler_param.cc b/src/compiler/compiler_param.cc
index eec1b9c..eeac633 100644
--- a/src/compiler/compiler_param.cc
+++ b/src/compiler/compiler_param.cc
@@ -26,6 +26,10 @@ CompilerParam CompilerParam::ParseFromJSON(char const* param_json_str) {
       TL2CGEN_CHECK(e.value.IsInt()) << "Expected an integer for 'quantize'";
       param.quantize = e.value.GetInt();
       TL2CGEN_CHECK_GE(param.quantize, 0) << "'quantize' must be 0 or greater";
+    } else if (key == "thresh_as_int") {
+      TL2CGEN_CHECK(e.value.IsInt()) << "Expected an integer for 'thresh_as_int'";
+      TL2CGEN_CHECK_GE(e.value.GetInt(), 0) << "'thresh_as_int' must be 0 or greater";  // check the parsed value itself
+      param.thresh_as_int = (e.value.GetInt() > 0);  // bool field: any positive integer enables it
     } else if (key == "parallel_comp") {
       TL2CGEN_CHECK(e.value.IsInt()) << "Expected an integer for 'parallel_comp'";
       param.parallel_comp = e.value.GetInt();