Skip to content

Commit 5424514

Browse files
authored
[RF][HS3] RooSpline exporter importer
Adds Functionality to export and import RooSpline objects to HS3 as well as corresponding const-getters to the class itself.
1 parent 143c841 commit 5424514

File tree

4 files changed

+140
-0
lines changed

4 files changed

+140
-0
lines changed

roofit/hs3/src/JSONFactories_RooFitCore.cxx

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
#include <RooGaussModel.h>
4848
#include <RooWorkspace.h>
4949
#include <RooRealIntegral.h>
50+
#include <RooSpline.h>
51+
#include <TSpline.h>
5052

5153
#include <TF1.h>
5254
#include <TH1.h>
@@ -599,6 +601,63 @@ class ParamHistFuncFactory : public RooFit::JSONIO::Importer {
599601
}
600602
};
601603

604+
class RooSplineFactory : public RooFit::JSONIO::Importer {
605+
public:
606+
bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
607+
{
608+
const std::string name(RooJSONFactoryWSTool::name(p));
609+
610+
// Mandatory fields
611+
if (!p.has_child("x")) {
612+
RooJSONFactoryWSTool::error("no x given in '" + name + "'");
613+
}
614+
if (!p.has_child("x0") || !p.has_child("y0")) {
615+
RooJSONFactoryWSTool::error("no x0/y0 given in '" + name + "'");
616+
}
617+
618+
RooAbsReal *x = tool->requestArg<RooAbsReal>(p, "x");
619+
620+
// Optional fields (defaults follow RooSpline ctor defaults)
621+
std::string algo = p.has_child("interpolation") ? p["interpolation"].val() : "poly3";
622+
int order = 0;
623+
if (algo == "poly3")
624+
order = 3;
625+
else if (algo == "poly5")
626+
order = 5;
627+
else {
628+
RooJSONFactoryWSTool::error("unsupported algo '" + algo + "' for RooSpline in '" + name +
629+
"': allowed are 'poly3' and 'poly5'");
630+
}
631+
const bool logx = p.has_child("logx") ? p["logx"].val_bool() : false;
632+
const bool logy = p.has_child("logy") ? p["logy"].val_bool() : false;
633+
634+
// Read knots
635+
std::vector<double> x0;
636+
std::vector<double> y0;
637+
x0.reserve(p["x0"].num_children());
638+
y0.reserve(p["y0"].num_children());
639+
640+
for (const auto &v : p["x0"].children())
641+
x0.push_back(v.val_double());
642+
for (const auto &v : p["y0"].children())
643+
y0.push_back(v.val_double());
644+
645+
if (x0.size() != y0.size()) {
646+
RooJSONFactoryWSTool::error("x0/y0 size mismatch in '" + name + "': x0 has " + std::to_string(x0.size()) +
647+
", y0 has " + std::to_string(y0.size()));
648+
}
649+
if (x0.size() < 2) {
650+
RooJSONFactoryWSTool::error("need at least 2 knots in '" + name + "'");
651+
}
652+
653+
// Construct RooSpline(name,title, x, x0, y0, order, logx, logy)
654+
tool->wsEmplace<::RooSpline>(name.c_str(), *x, std::span<const double>(x0.data(), x0.size()),
655+
std::span<const double>(y0.data(), y0.size()), order, logx, logy);
656+
657+
return true;
658+
}
659+
};
660+
602661
///////////////////////////////////////////////////////////////////////////////////////////////////////
603662
// specialized exporter implementations
604663
///////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1061,6 +1120,42 @@ class ParamHistFuncStreamer : public RooFit::JSONIO::Exporter {
10611120
}
10621121
};
10631122

1123+
class RooSplineStreamer : public RooFit::JSONIO::Exporter {
1124+
public:
1125+
std::string const &key() const override;
1126+
1127+
bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, RooFit::Detail::JSONNode &elem) const override
1128+
{
1129+
auto const *rs = static_cast<RooSpline const *>(func);
1130+
1131+
elem["type"] << key();
1132+
1133+
// Independent variable
1134+
elem["x"] << rs->x().GetName();
1135+
1136+
// Spline configuration
1137+
// Canonical algo for RooSpline
1138+
elem["interpolation"] << (rs->order() == 5 ? "poly5" : "poly3");
1139+
elem["logx"] << rs->logx();
1140+
elem["logy"] << rs->logy();
1141+
1142+
// Serialize knots as primitive arrays
1143+
TSpline const &sp = rs->spline();
1144+
auto &x0 = elem["x0"].set_seq();
1145+
auto &y0 = elem["y0"].set_seq();
1146+
1147+
const int np = sp.GetNp();
1148+
for (int i = 0; i < np; ++i) {
1149+
double xk = 0.0, yk = 0.0;
1150+
sp.GetKnot(i, xk, yk);
1151+
x0.append_child() << xk;
1152+
y0.append_child() << yk;
1153+
}
1154+
1155+
return true;
1156+
}
1157+
};
1158+
10641159
#define DEFINE_EXPORTER_KEY(class_name, name) \
10651160
std::string const &class_name::key() const \
10661161
{ \
@@ -1099,6 +1194,7 @@ DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative");
10991194
DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf");
11001195
DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf");
11011196
DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "step");
1197+
DEFINE_EXPORTER_KEY(RooSplineStreamer, "spline");
11021198

11031199
///////////////////////////////////////////////////////////////////////////////////////////////////////
11041200
// instantiate all importers and exporters
@@ -1132,6 +1228,7 @@ STATIC_EXECUTE([]() {
11321228
registerImporter<RooFFTConvPdfFactory>("fft_conv_pdf", false);
11331229
registerImporter<RooExtendPdfFactory>("extend_pdf", false);
11341230
registerImporter<ParamHistFuncFactory>("step", false);
1231+
registerImporter<RooSplineFactory>("spline", false);
11351232

11361233
registerExporter<RooAddPdfStreamer<RooAddPdf>>(RooAddPdf::Class(), false);
11371234
registerExporter<RooAddPdfStreamer<RooAddModel>>(RooAddModel::Class(), false);
@@ -1159,6 +1256,7 @@ STATIC_EXECUTE([]() {
11591256
registerExporter<RooFFTConvPdfStreamer>(RooFFTConvPdf::Class(), false);
11601257
registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false);
11611258
registerExporter<ParamHistFuncStreamer>(ParamHistFunc::Class(), false);
1259+
registerExporter<RooSplineStreamer>(RooSpline::Class(), false);
11621260
});
11631261

11641262
} // namespace

roofit/hs3/test/testRooFitHS3.cxx

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <RooHelpers.h>
1616
#include <RooHistFunc.h>
1717
#include <RooHistPdf.h>
18+
#include <RooSpline.h>
1819
#include <RooLognormal.h>
1920
#include <RooMultiVarGaussian.h>
2021
#include <RooPoisson.h>
@@ -533,3 +534,26 @@ TEST(RooFitHS3, ModelConfigWithMultiVarGaussian)
533534
int status = validate(ws1, "mc");
534535
EXPECT_EQ(status, 0);
535536
}
537+
538+
TEST(RooFitHS3, RooSpline)
539+
{
540+
// Observable must be called "x" because validate() assumes that convention.
541+
RooWorkspace ws;
542+
543+
// Use an observable with bins to enable the per-bin closure check.
544+
auto *x = ws.factory("x[0,10]");
545+
ASSERT_NE(x, nullptr);
546+
ws.var("x")->setBins(50);
547+
548+
// Define knots. Keep it simple but nontrivial (nonlinear).
549+
const std::vector<double> x0{0.0, 1.5, 3.0, 6.0, 10.0};
550+
const std::vector<double> y0{1.0, 2.0, 1.0, 4.0, 3.0};
551+
552+
RooSpline spline{"spline", "spline", *ws.var("x"), x0, y0, /*order=*/3, /*logx=*/false, /*logy=*/false};
553+
554+
// Import the object into the workspace and validate JSON IO.
555+
ws.import(spline, RooFit::Silence());
556+
557+
const int status = validate(ws, "spline", /*exact=*/true);
558+
EXPECT_EQ(status, 0);
559+
}

roofit/roofit/inc/RooSpline.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ class RooSpline : public RooAbsReal {
3737
/// \param[in] newname The name of the cloned object (optional).
3838
TObject *clone(const char *newname) const override { return new RooSpline(*this, newname); }
3939

40+
RooAbsReal const &x() const { return static_cast<RooAbsReal const &>(*_x.absArg()); }
41+
int order() const;
42+
TSpline const &spline() const { return *_spline; }
43+
bool logx() const { return _logx; }
44+
bool logy() const { return _logy; }
45+
4046
protected:
4147
double evaluate() const override;
4248

roofit/roofit/src/RooSpline.cxx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,15 @@ double RooSpline::evaluate() const
129129
const double x_val = (!_logx) ? _x : std::exp(_x);
130130
return (!_logy) ? _spline->Eval(x_val) : std::exp(_spline->Eval(x_val));
131131
}
132+
133+
/// Return the order of the spline
134+
int RooSpline::order() const
135+
{
136+
// RooSpline currently doesn’t store the order explicitly, so infer from TSpline dynamic type.
137+
// (Constructor uses TSpline3 for order=3 and TSpline5 for order=5.)
138+
if (dynamic_cast<TSpline5 const *>(_spline.get()))
139+
return 5;
140+
if (dynamic_cast<TSpline3 const *>(_spline.get()))
141+
return 3;
142+
return 3; // reasonable default / forward compatibility
143+
}

0 commit comments

Comments
 (0)