|
47 | 47 | #include <RooGaussModel.h> |
48 | 48 | #include <RooWorkspace.h> |
49 | 49 | #include <RooRealIntegral.h> |
| 50 | +#include <RooSpline.h> |
| 51 | +#include <TSpline.h> |
50 | 52 |
|
51 | 53 | #include <TF1.h> |
52 | 54 | #include <TH1.h> |
@@ -599,6 +601,63 @@ class ParamHistFuncFactory : public RooFit::JSONIO::Importer { |
599 | 601 | } |
600 | 602 | }; |
601 | 603 |
|
| 604 | +class RooSplineFactory : public RooFit::JSONIO::Importer { |
| 605 | +public: |
| 606 | + bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override |
| 607 | + { |
| 608 | + const std::string name(RooJSONFactoryWSTool::name(p)); |
| 609 | + |
| 610 | + // Mandatory fields |
| 611 | + if (!p.has_child("x")) { |
| 612 | + RooJSONFactoryWSTool::error("no x given in '" + name + "'"); |
| 613 | + } |
| 614 | + if (!p.has_child("x0") || !p.has_child("y0")) { |
| 615 | + RooJSONFactoryWSTool::error("no x0/y0 given in '" + name + "'"); |
| 616 | + } |
| 617 | + |
| 618 | + RooAbsReal *x = tool->requestArg<RooAbsReal>(p, "x"); |
| 619 | + |
| 620 | + // Optional fields (defaults follow RooSpline ctor defaults) |
| 621 | + std::string algo = p.has_child("interpolation") ? p["interpolation"].val() : "poly3"; |
| 622 | + int order = 0; |
| 623 | + if (algo == "poly3") |
| 624 | + order = 3; |
| 625 | + else if (algo == "poly5") |
| 626 | + order = 5; |
| 627 | + else { |
| 628 | + RooJSONFactoryWSTool::error("unsupported algo '" + algo + "' for RooSpline in '" + name + |
| 629 | + "': allowed are 'poly3' and 'poly5'"); |
| 630 | + } |
| 631 | + const bool logx = p.has_child("logx") ? p["logx"].val_bool() : false; |
| 632 | + const bool logy = p.has_child("logy") ? p["logy"].val_bool() : false; |
| 633 | + |
| 634 | + // Read knots |
| 635 | + std::vector<double> x0; |
| 636 | + std::vector<double> y0; |
| 637 | + x0.reserve(p["x0"].num_children()); |
| 638 | + y0.reserve(p["y0"].num_children()); |
| 639 | + |
| 640 | + for (const auto &v : p["x0"].children()) |
| 641 | + x0.push_back(v.val_double()); |
| 642 | + for (const auto &v : p["y0"].children()) |
| 643 | + y0.push_back(v.val_double()); |
| 644 | + |
| 645 | + if (x0.size() != y0.size()) { |
| 646 | + RooJSONFactoryWSTool::error("x0/y0 size mismatch in '" + name + "': x0 has " + std::to_string(x0.size()) + |
| 647 | + ", y0 has " + std::to_string(y0.size())); |
| 648 | + } |
| 649 | + if (x0.size() < 2) { |
| 650 | + RooJSONFactoryWSTool::error("need at least 2 knots in '" + name + "'"); |
| 651 | + } |
| 652 | + |
| 653 | + // Construct RooSpline(name,title, x, x0, y0, order, logx, logy) |
| 654 | + tool->wsEmplace<::RooSpline>(name.c_str(), *x, std::span<const double>(x0.data(), x0.size()), |
| 655 | + std::span<const double>(y0.data(), y0.size()), order, logx, logy); |
| 656 | + |
| 657 | + return true; |
| 658 | + } |
| 659 | +}; |
| 660 | + |
602 | 661 | /////////////////////////////////////////////////////////////////////////////////////////////////////// |
603 | 662 | // specialized exporter implementations |
604 | 663 | /////////////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -1061,6 +1120,42 @@ class ParamHistFuncStreamer : public RooFit::JSONIO::Exporter { |
1061 | 1120 | } |
1062 | 1121 | }; |
1063 | 1122 |
|
| 1123 | +class RooSplineStreamer : public RooFit::JSONIO::Exporter { |
| 1124 | +public: |
| 1125 | + std::string const &key() const override; |
| 1126 | + |
| 1127 | + bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, RooFit::Detail::JSONNode &elem) const override |
| 1128 | + { |
| 1129 | + auto const *rs = static_cast<RooSpline const *>(func); |
| 1130 | + |
| 1131 | + elem["type"] << key(); |
| 1132 | + |
| 1133 | + // Independent variable |
| 1134 | + elem["x"] << rs->x().GetName(); |
| 1135 | + |
| 1136 | + // Spline configuration |
| 1137 | + // Canonical algo for RooSpline |
| 1138 | + elem["interpolation"] << (rs->order() == 5 ? "poly5" : "poly3"); |
| 1139 | + elem["logx"] << rs->logx(); |
| 1140 | + elem["logy"] << rs->logy(); |
| 1141 | + |
| 1142 | + // Serialize knots as primitive arrays |
| 1143 | + TSpline const &sp = rs->spline(); |
| 1144 | + auto &x0 = elem["x0"].set_seq(); |
| 1145 | + auto &y0 = elem["y0"].set_seq(); |
| 1146 | + |
| 1147 | + const int np = sp.GetNp(); |
| 1148 | + for (int i = 0; i < np; ++i) { |
| 1149 | + double xk = 0.0, yk = 0.0; |
| 1150 | + sp.GetKnot(i, xk, yk); |
| 1151 | + x0.append_child() << xk; |
| 1152 | + y0.append_child() << yk; |
| 1153 | + } |
| 1154 | + |
| 1155 | + return true; |
| 1156 | + } |
| 1157 | +}; |
| 1158 | + |
1064 | 1159 | #define DEFINE_EXPORTER_KEY(class_name, name) \ |
1065 | 1160 | std::string const &class_name::key() const \ |
1066 | 1161 | { \ |
@@ -1099,6 +1194,7 @@ DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative"); |
1099 | 1194 | DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf"); |
1100 | 1195 | DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf"); |
1101 | 1196 | DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "step"); |
| 1197 | +DEFINE_EXPORTER_KEY(RooSplineStreamer, "spline"); |
1102 | 1198 |
|
1103 | 1199 | /////////////////////////////////////////////////////////////////////////////////////////////////////// |
1104 | 1200 | // instantiate all importers and exporters |
@@ -1132,6 +1228,7 @@ STATIC_EXECUTE([]() { |
1132 | 1228 | registerImporter<RooFFTConvPdfFactory>("fft_conv_pdf", false); |
1133 | 1229 | registerImporter<RooExtendPdfFactory>("extend_pdf", false); |
1134 | 1230 | registerImporter<ParamHistFuncFactory>("step", false); |
| 1231 | + registerImporter<RooSplineFactory>("spline", false); |
1135 | 1232 |
|
1136 | 1233 | registerExporter<RooAddPdfStreamer<RooAddPdf>>(RooAddPdf::Class(), false); |
1137 | 1234 | registerExporter<RooAddPdfStreamer<RooAddModel>>(RooAddModel::Class(), false); |
@@ -1159,6 +1256,7 @@ STATIC_EXECUTE([]() { |
1159 | 1256 | registerExporter<RooFFTConvPdfStreamer>(RooFFTConvPdf::Class(), false); |
1160 | 1257 | registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false); |
1161 | 1258 | registerExporter<ParamHistFuncStreamer>(ParamHistFunc::Class(), false); |
| 1259 | + registerExporter<RooSplineStreamer>(RooSpline::Class(), false); |
1162 | 1260 | }); |
1163 | 1261 |
|
1164 | 1262 | } // namespace |
0 commit comments