Skip to content

Commit

Permalink
Merge pull request #851 from rkansal47/regex-float-parameters
Browse files Browse the repository at this point in the history
Add regex for floatNuisances
  • Loading branch information
kcormi authored Oct 10, 2023
2 parents 6aad891 + a3fa324 commit 97b5da4
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 80 deletions.
1 change: 1 addition & 0 deletions interface/Combine.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Combine {
private:
bool mklimit(RooWorkspace *w, RooStats::ModelConfig *mc_s, RooStats::ModelConfig *mc_b, RooAbsData &data, double &limit, double &limitErr) ;

std::string parseRegex(std::string instr, const RooArgSet *nuisances, RooWorkspace *w) ;
void addDiscreteNuisances(RooWorkspace *);
void addNuisances(const RooArgSet *);
void addFloatingParameters(const RooArgSet &);
Expand Down
182 changes: 102 additions & 80 deletions src/Combine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Combine::Combine() :
("validateModel,V", "Perform some sanity checks on the model and abort if they fail.")
("saveToys", "Save results of toy MC in output file")
("floatAllNuisances", po::value<bool>(&floatAllNuisances_)->default_value(false), "Make all nuisance parameters floating")
("floatParameters", po::value<string>(&floatNuisances_)->default_value(""), "Set to floating these parameters (note freeze will take priority over float)")
("floatParameters", po::value<string>(&floatNuisances_)->default_value(""), "Set to floating these parameters (note freeze will take priority over float), also accepts regexp with syntax 'rgx{<my regexp>}' or 'var{<my regexp>}'")
("freezeAllGlobalObs", po::value<bool>(&freezeAllGlobalObs_)->default_value(true), "Make all global observables constant")
;
miscOptions_.add_options()
Expand Down Expand Up @@ -215,6 +215,62 @@ void Combine::applyOptions(const boost::program_options::variables_map &vm) {
makeToyGenSnapshot_ = (method == "FitDiagnostics" && !vm.count("justFit"));
}

std::string Combine::parseRegex(std::string instr, const RooArgSet *nuisances, RooWorkspace *w) {
// expand regexps inside the "rgx{}" option
while (instr.find("rgx{") != std::string::npos) {
size_t pos1 = instr.find("rgx{");
size_t pos2 = instr.find("}",pos1);
std::string prestr = instr.substr(0,pos1);
std::string poststr = instr.substr(pos2+1,instr.size()-pos2);
std::string reg_esp = instr.substr(pos1+4,pos2-pos1-4);

std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(nuisances->createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
const std::string &target = a->GetName();
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

instr = prestr+matchingParams+poststr;
instr = boost::replace_all_copy(instr, ",,", ",");
}

// expand regexps inside the "var{}" option
while (instr.find("var{") != std::string::npos) {
size_t pos1 = instr.find("var{");
size_t pos2 = instr.find("}",pos1);
std::string prestr = instr.substr(0,pos1);
std::string poststr = instr.substr(pos2+1,instr.size()-pos2);
std::string reg_esp = instr.substr(pos1+4,pos2-pos1-4);

std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(w->componentIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {

if ( ! (a->IsA()->InheritsFrom(RooRealVar::Class()) || a->IsA()->InheritsFrom(RooCategory::Class()))) continue;

const std::string &target = a->GetName();
// std::cout<<"var "<<target<<std::endl;
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

instr = prestr+matchingParams+poststr;
instr = boost::replace_all_copy(instr, ",,", ",");
}

return instr;
}

bool Combine::mklimit(RooWorkspace *w, RooStats::ModelConfig *mc_s, RooStats::ModelConfig *mc_b, RooAbsData &data, double &limit, double &limitErr) {
TStopwatch timer;

Expand Down Expand Up @@ -589,77 +645,43 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
}

if (floatNuisances_ != "") {
RooArgSet toFloat((floatNuisances_=="all")?*nuisances:(w->argSet(floatNuisances_.c_str())));
floatNuisances_ = parseRegex(floatNuisances_, nuisances, w);

RooArgSet toFloat;
if (floatNuisances_=="all") {
toFloat.add(*nuisances);
} else {
std::vector<std::string> nuisToFloat;
boost::split(nuisToFloat, floatNuisances_, boost::is_any_of(","), boost::token_compress_on);
for (int k=0; k<(int)nuisToFloat.size(); k++) {
if (nuisToFloat[k]=="") continue;
else if(nuisToFloat[k]=="all") {
toFloat.add(*nuisances);
continue;
}
else if (!w->fundArg(nuisToFloat[k].c_str())) {
std::cout<<"WARNING: cannot float nuisance parameter "<<nuisToFloat[k].c_str()<<" if it doesn't exist!"<<std::endl;
continue;
}
const RooAbsArg *arg = (RooAbsArg*)w->fundArg(nuisToFloat[k].c_str());
toFloat.add(*arg);
}
}

if (verbose > 0) {
std::cout << "Set floating the following parameters: "; toFloat.Print("");
Logger::instance().log(std::string(Form("Combine.cc: %d -- Set floating the following parameters: ",__LINE__)),Logger::kLogLevelInfo,__func__);
std::cout << "Floating the following parameters: "; toFloat.Print("");
Logger::instance().log(std::string(Form("Combine.cc: %d -- Floating the following parameters: ",__LINE__)),Logger::kLogLevelInfo,__func__);
std::unique_ptr<TIterator> iter(toFloat.createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
Logger::instance().log(std::string(Form("Combine.cc: %d %s ",__LINE__,a->GetName())),Logger::kLogLevelInfo,__func__);
}
}
}
utils::setAllConstant(toFloat, false);
}

if (freezeNuisances_ != "") {
freezeNuisances_ = parseRegex(freezeNuisances_, nuisances, w);

// expand regexps
while (freezeNuisances_.find("rgx{") != std::string::npos) {
size_t pos1 = freezeNuisances_.find("rgx{");
size_t pos2 = freezeNuisances_.find("}",pos1);
std::string prestr = freezeNuisances_.substr(0,pos1);
std::string poststr = freezeNuisances_.substr(pos2+1,freezeNuisances_.size()-pos2);
std::string reg_esp = freezeNuisances_.substr(pos1+4,pos2-pos1-4);

//std::cout<<"interpreting "<<reg_esp<<" as regex "<<std::endl;
std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(nuisances->createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
const std::string &target = a->GetName();
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

freezeNuisances_ = prestr+matchingParams+poststr;
freezeNuisances_ = boost::replace_all_copy(freezeNuisances_, ",,", ",");

}

// expand regexps
while (freezeNuisances_.find("var{") != std::string::npos) {
size_t pos1 = freezeNuisances_.find("var{");
size_t pos2 = freezeNuisances_.find("}",pos1);
std::string prestr = freezeNuisances_.substr(0,pos1);
std::string poststr = freezeNuisances_.substr(pos2+1,freezeNuisances_.size()-pos2);
std::string reg_esp = freezeNuisances_.substr(pos1+4,pos2-pos1-4);

// std::cout<<"interpreting "<<reg_esp<<" as regex "<<std::endl;
std::regex rgx( reg_esp, std::regex::ECMAScript);

std::string matchingParams="";
std::unique_ptr<TIterator> iter(w->componentIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {

if ( ! (a->IsA()->InheritsFrom(RooRealVar::Class()) || a->IsA()->InheritsFrom(RooCategory::Class()))) continue;

const std::string &target = a->GetName();
// std::cout<<"var "<<target<<std::endl;
std::smatch match;
if (std::regex_match(target, match, rgx)) {
matchingParams = matchingParams + target + ",";
}
}

freezeNuisances_ = prestr+matchingParams+poststr;
freezeNuisances_ = boost::replace_all_copy(freezeNuisances_, ",,", ",");

}

//RooArgSet toFreeze((freezeNuisances_=="all")?*nuisances:(w->argSet(freezeNuisances_.c_str())));
RooArgSet toFreeze;
if (freezeNuisances_=="allConstrainedNuisances") {
toFreeze.add(*nuisances);
Expand Down Expand Up @@ -687,7 +709,7 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
std::unique_ptr<TIterator> iter(toFreeze.createIterator());
for (RooAbsArg *a = (RooAbsArg*) iter->Next(); a != 0; a = (RooAbsArg*) iter->Next()) {
Logger::instance().log(std::string(Form("Combine.cc: %d %s ",__LINE__,a->GetName())),Logger::kLogLevelInfo,__func__);
}
}
}
utils::setAllConstant(toFreeze, true);
if (nuisances) {
Expand All @@ -705,24 +727,24 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
for (std::vector<string>::iterator ng_it=nuisanceGroups.begin();ng_it!=nuisanceGroups.end();ng_it++){
bool freeze_complement=false;
if (boost::algorithm::starts_with((*ng_it),"^")){
freeze_complement=true;
(*ng_it).erase(0,1);
}
freeze_complement=true;
(*ng_it).erase(0,1);
}

if (!w->set(Form("group_%s",(*ng_it).c_str()))){
std::cerr << "Unknown nuisance group: " << (*ng_it) << std::endl;
throw std::invalid_argument("Unknown nuisance group name");
}
RooArgSet groupNuisances(*(w->set(Form("group_%s",(*ng_it).c_str()))));
RooArgSet toFreeze;
if (!w->set(Form("group_%s",(*ng_it).c_str()))){
std::cerr << "Unknown nuisance group: " << (*ng_it) << std::endl;
throw std::invalid_argument("Unknown nuisance group name");
}
RooArgSet groupNuisances(*(w->set(Form("group_%s",(*ng_it).c_str()))));
RooArgSet toFreeze;

if (freeze_complement) {
RooArgSet still_floating(*mc->GetNuisanceParameters());
still_floating.remove(groupNuisances,true,true);
toFreeze.add(still_floating);
} else {
toFreeze.add(groupNuisances);
}
if (freeze_complement) {
RooArgSet still_floating(*mc->GetNuisanceParameters());
still_floating.remove(groupNuisances,true,true);
toFreeze.add(still_floating);
} else {
toFreeze.add(groupNuisances);
}

if (verbose > 0) { std::cout << "Freezing the following nuisance parameters: "; toFreeze.Print(""); }
utils::setAllConstant(toFreeze, true);
Expand All @@ -732,7 +754,7 @@ void Combine::run(TString hlfFile, const std::string &dataset, double &limit, do
mc->SetNuisanceParameters(newnuis);
if (mc_bonly) mc_bonly->SetNuisanceParameters(newnuis);
nuisances = mc->GetNuisanceParameters();
}
}
}
}

Expand Down

0 comments on commit 97b5da4

Please sign in to comment.