Skip to content

Commit

Permalink
First RooFit AD integration
Browse files Browse the repository at this point in the history
These are the remaining developments from the RooFit ICHEP 2024 AD
branch.
  • Loading branch information
guitargeek committed Nov 19, 2024
1 parent 6980f6f commit 58b6ea4
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 2 deletions.
5 changes: 5 additions & 0 deletions interface/AsymPow.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <RooAbsReal.h>
#include <RooRealProxy.h>

#include "CombineCodegenImpl.h"

//_________________________________________________
/*
Expand All @@ -28,6 +29,8 @@ class AsymPow : public RooAbsReal {

TObject * clone(const char *newname) const override { return new AsymPow(*this, newname); }

COMBINE_DECLARE_TRANSLATE;

RooAbsReal const &kappaLow() const { return kappaLow_.arg(); }
RooAbsReal const &kappaHigh() const { return kappaHigh_.arg(); }
RooAbsReal const &theta() const { return theta_.arg(); }
Expand All @@ -42,4 +45,6 @@ class AsymPow : public RooAbsReal {
ClassDefOverride(AsymPow,1) // Asymmetric power
};

COMBINE_DECLARE_CODEGEN_IMPL(AsymPow);

#endif
18 changes: 18 additions & 0 deletions interface/CombineCodegenImpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef HiggsAnalysis_CombinedLimit_CombineCodegenImpl_h
#define HiggsAnalysis_CombinedLimit_CombineCodegenImpl_h

#include <ROOT/RConfig.hxx> // for ROOT_VERSION

#if ROOT_VERSION_CODE >= ROOT_VERSION(6,35,0)
# define COMBINE_DECLARE_CODEGEN_IMPL(CLASS_NAME) \
namespace RooFit { namespace Experimental { void codegenImpl(CLASS_NAME &arg, CodegenContext &ctx); }}
# define COMBINE_DECLARE_TRANSLATE
#elif ROOT_VERSION_CODE >= ROOT_VERSION(6,32,0)
# define COMBINE_DECLARE_CODEGEN_IMPL(CLASS_NAME)
# define COMBINE_DECLARE_TRANSLATE void translate(RooFit::Detail::CodeSquashContext &ctx) const override;
#else
# define COMBINE_DECLARE_CODEGEN_IMPL
# define COMBINE_DECLARE_TRANSLATE
#endif

#endif
4 changes: 4 additions & 0 deletions interface/HZZ4LRooPdfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ class RooqqZZPdf_v2 : public RooAbsPdf {
RooqqZZPdf_v2(const RooqqZZPdf_v2& other, const char* name=0) ;
TObject* clone(const char* newname) const override { return new RooqqZZPdf_v2(*this,newname); }
inline ~RooqqZZPdf_v2() override { }

std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;

protected:

Expand Down Expand Up @@ -299,6 +301,8 @@ class RooggZZPdf_v2 : public RooAbsPdf {
RooggZZPdf_v2(const RooggZZPdf_v2& other, const char* name=0) ;
TObject* clone(const char* newname) const override { return new RooggZZPdf_v2(*this,newname); }
inline ~RooggZZPdf_v2() override { }

std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;

protected:

Expand Down
6 changes: 6 additions & 0 deletions interface/ProcessNormalization.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <RooAbsReal.h>
#include "RooListProxy.h"

#include "CombineCodegenImpl.h"

//_________________________________________________
/*
BEGIN_HTML
Expand All @@ -28,6 +30,8 @@ class ProcessNormalization : public RooAbsReal {
void addOtherFactor(RooAbsReal &factor) ;
void dump() const ;

COMBINE_DECLARE_TRANSLATE;

double nominalValue() const { return nominalValue_; }
std::vector<double> const &logKappa() const { return logKappa_; }
RooArgList const &thetaList() const { return thetaList_; }
Expand Down Expand Up @@ -57,4 +61,6 @@ class ProcessNormalization : public RooAbsReal {
ClassDefOverride(ProcessNormalization,1) // Process normalization interpolator
};

COMBINE_DECLARE_CODEGEN_IMPL(ProcessNormalization);

#endif
9 changes: 8 additions & 1 deletion interface/VerticalInterpHistPdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
#include "TH1.h"
#include "SimpleCacheSentry.h"
#include "FastTemplate_Old.h"
#include <cmath>

#include "CombineCodegenImpl.h"

class FastVerticalInterpHistPdf;
class FastVerticalInterpHistPdf2Base;
Expand Down Expand Up @@ -328,6 +329,8 @@ class FastVerticalInterpHistPdf2 : public FastVerticalInterpHistPdf2Base {

FastHisto const& cache() const { return _cache; }

COMBINE_DECLARE_TRANSLATE;

FastHisto const &cacheNominal() const { return _cacheNominal; }

friend class FastVerticalInterpHistPdf2V;
Expand Down Expand Up @@ -387,6 +390,8 @@ class FastVerticalInterpHistPdf2D2 : public FastVerticalInterpHistPdf2Base {

Double_t evaluate() const override ;

COMBINE_DECLARE_TRANSLATE;

FastHisto2D const &cacheNominal() const { return _cacheNominal; }

protected:
Expand Down Expand Up @@ -459,5 +464,7 @@ class FastVerticalInterpHistPdf3D : public FastVerticalInterpHistPdfBase {
};


COMBINE_DECLARE_CODEGEN_IMPL(FastVerticalInterpHistPdf2);
COMBINE_DECLARE_CODEGEN_IMPL(FastVerticalInterpHistPdf2D2);

#endif
38 changes: 38 additions & 0 deletions scripts/fitRooFitAD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import ROOT
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", help="input ws file")
parser.add_argument("--backend", help='set evaluation backend, default is "combine"')

args = parser.parse_args()

with ROOT.TFile.Open(args.input) as f:
ws = f.Get("w")

global_observables = ws.set("globalObservables")
constrain = ws.set("nuisances")

pdf = ws["model_s"]
data = ws["data_obs"]

# To use the evaluation backends of RooFit
ROOT.gInterpreter.Declare("#include <HiggsAnalysis/CombinedLimit/interface/Combine.h>")

ROOT.Combine.nllBackend_ = args.backend

# Change verbosity
ROOT.RooMsgService.instance().getStream(1).removeTopic(ROOT.RooFit.Minimization)
ROOT.RooMsgService.instance().getStream(1).removeTopic(ROOT.RooFit.Fitting)

ROOT.gInterpreter.Declare("#include <HiggsAnalysis/CombinedLimit/interface/CombineMathFuncs.h>")

nll = pdf.createNLL(data, Constrain=constrain, GlobalObservables=global_observables, Offset="initial")

cfg = ROOT.RooMinimizer.Config()
minim = ROOT.RooMinimizer(nll, cfg)
minim.setEps(1.0)
minim.setStrategy(0)
minim.minimize("Minuit2", "")
minim.save().Print()
213 changes: 213 additions & 0 deletions src/CombineCodegenImpl.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
#include "../interface/CombineCodegenImpl.h"

#if ROOT_VERSION_CODE >= ROOT_VERSION(6,32,0)

#include "../interface/AsymPow.h"
#include "../interface/ProcessNormalization.h"
#include "../interface/VerticalInterpHistPdf.h"

#include <RooUniformBinning.h>

#if ROOT_VERSION_CODE >= ROOT_VERSION(6,35,0)
namespace RooFit {
namespace Experimental {
# define CODEGEN_IMPL(CLASS_NAME) void codegenImpl(CLASS_NAME &arg0, CodegenContext &ctx)
# define ARG_VAR auto &arg = arg0;
#else
# define CODEGEN_IMPL(CLASS_NAME) void CLASS_NAME::translate(RooFit::Detail::CodeSquashContext &ctx) const
# define ARG_VAR auto &arg = *this;
#endif


CODEGEN_IMPL(AsymPow) {
ARG_VAR;
ctx.addResult(&arg,
ctx.buildCall("RooFit::Detail::MathFuncs::asymPow", arg.theta(), arg.kappaLow(), arg.kappaHigh()));
}

CODEGEN_IMPL(ProcessNormalization) {
ARG_VAR;

std::vector<double> logAsymmKappaLow;
std::vector<double> logAsymmKappaHigh;
logAsymmKappaLow.reserve(arg.logAsymmKappa().size());
logAsymmKappaHigh.reserve(arg.logAsymmKappa().size());
for (auto [lo, hi] : arg.logAsymmKappa()) {
logAsymmKappaLow.push_back(lo);
logAsymmKappaHigh.push_back(hi);
}

ctx.addResult(&arg,
ctx.buildCall("RooFit::Detail::MathFuncs::processNormalization",
arg.nominalValue(),
arg.thetaList().size(),
arg.asymmThetaList().size(),
arg.otherFactorList().size(),
arg.thetaList(),
arg.logKappa(),
arg.asymmThetaList(),
logAsymmKappaLow,
logAsymmKappaHigh,
arg.otherFactorList()));
}

CODEGEN_IMPL(FastVerticalInterpHistPdf2) {
ARG_VAR;

if (arg.smoothAlgo() < 0) {
throw std::runtime_error("We only support smoothAlgo >= 0");
}

RooRealVar const &xVar = arg.x();

int numBins = xVar.numBins();

std::vector<double> nominalVec(numBins);
std::vector<double> widthVec(numBins);
std::vector<double> morphsVecSum;
std::vector<double> morphsVecDiff;

auto const &cacheNominal = arg.cacheNominal();

for (int i = 0; i < numBins; ++i) {
nominalVec[i] = cacheNominal.GetBinContent(i);
widthVec[i] = cacheNominal.GetWidth(i);
}

std::size_t nCoefs = arg.coefList().size();

morphsVecSum.reserve(numBins * nCoefs);
morphsVecDiff.reserve(numBins * nCoefs);
auto const &morphs = arg.morphs();
for (unsigned int j = 0; j < nCoefs; ++j) {
for (int i = 0; i < numBins; ++i) {
morphsVecSum.push_back(morphs[j].sum[i]);
morphsVecDiff.push_back(morphs[j].diff[i]);
}
}

for (int i = 0; i < numBins; ++i) {
nominalVec[i] = cacheNominal.GetBinContent(i);
}

// The bin index part
// We also have to assert that x is uniformely binned!
if (!dynamic_cast<RooUniformBinning const *>(&xVar.getBinning())) {
throw std::runtime_error("We only support uniform binning!");
}
double xLow = xVar.getMin();
double xHigh = xVar.getMax();
std::string binIdx = ctx.buildCall("RooFit::Detail::MathFuncs::getUniformBinning", xLow, xHigh, xVar, numBins);

std::string arrName = ctx.getTmpVarName();

std::stringstream code;
code << "double " << arrName << "[" << numBins << "];\n";
code << ctx.buildCall("RooFit::Detail::MathFuncs::fastVerticalInterpHistPdf2",
numBins,
nCoefs,
arg.coefList(),
nominalVec,
widthVec,
morphsVecSum,
morphsVecDiff,
arg.smoothRegion(),
arrName) +
";\n";

ctx.addToCodeBody(code.str(), true);
ctx.addResult(&arg, arrName + "[" + binIdx + "]");
}

CODEGEN_IMPL(FastVerticalInterpHistPdf2D2) {
ARG_VAR;

if (arg.smoothAlgo() < 0) {
throw std::runtime_error("We only support smoothAlgo >= 0");
}

if (!arg.conditional()) {
throw std::runtime_error("We only support conditional == true");
}

RooRealVar const &xVar = arg.x();
RooRealVar const &yVar = arg.y();

// We also have to assert that x and y are uniformely binned!
if (!dynamic_cast<RooUniformBinning const *>(&xVar.getBinning())) {
throw std::runtime_error("We only support uniform binning!");
}
if (!dynamic_cast<RooUniformBinning const *>(&yVar.getBinning())) {
throw std::runtime_error("We only support uniform binning!");
}

auto const &cacheNominal = arg.cacheNominal();

int numBinsX = cacheNominal.binX();
int numBinsY = cacheNominal.binY();
int numBins = numBinsY * numBinsY;

std::vector<double> nominalVec(numBins);
std::vector<double> widthVec(numBins);
std::vector<double> morphsVecSum;
std::vector<double> morphsVecDiff;

for (int i = 0; i < numBins; ++i) {
nominalVec[i] = cacheNominal.GetBinContent(i);
widthVec[i] = cacheNominal.GetWidth(i);
}

std::size_t nCoefs = arg.coefList().size();

morphsVecSum.reserve(numBins * nCoefs);
morphsVecDiff.reserve(numBins * nCoefs);
auto const &morphs = arg.morphs();
for (unsigned int j = 0; j < nCoefs; ++j) {
for (int i = 0; i < numBins; ++i) {
morphsVecSum.push_back(morphs[j].sum[i]);
morphsVecDiff.push_back(morphs[j].diff[i]);
}
}

for (int i = 0; i < numBins; ++i) {
nominalVec[i] = cacheNominal.GetBinContent(i);
}

// The bin index part
double xLow = xVar.getMin();
double xHigh = xVar.getMax();
std::string binIdxX = ctx.buildCall("RooFit::Detail::MathFuncs::getUniformBinning", xLow, xHigh, arg.x(), numBinsX);
double yLow = yVar.getMin();
double yHigh = yVar.getMax();
std::string binIdxY = ctx.buildCall("RooFit::Detail::MathFuncs::getUniformBinning", yLow, yHigh, arg.y(), numBinsY);

std::stringstream binIdx;
binIdx << "(" << binIdxY << " + " << yVar.numBins() << " * " << binIdxX << ")";

std::string arrName = ctx.getTmpVarName();

std::stringstream code;
code << "double " << arrName << "[" << (numBinsX * numBinsY) << "];\n";
code << ctx.buildCall("RooFit::Detail::MathFuncs::fastVerticalInterpHistPdf2D2",
numBinsX,
numBinsY,
nCoefs,
arg.coefList(),
nominalVec,
widthVec,
morphsVecSum,
morphsVecDiff,
arg.smoothRegion(),
arrName) +
";\n";

ctx.addToCodeBody(code.str(), true);
ctx.addResult(&arg, arrName + "[" + binIdx.str() + "]");
}

#if ROOT_VERSION_CODE >= ROOT_VERSION(6,35,0)
} // namespace RooFit
} // namespace Experimental
#endif

#endif // ROOT_VERSION_CODE >= ROOT_VERSION(6,32,0)
Loading

0 comments on commit 58b6ea4

Please sign in to comment.