Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Component and dataset upgrades #524

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions source/framework/core/inc/TRestDataSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class TRestDataSet : public TRestMetadata {
Bool_t fExternal = false; //<

/// The resulting RDF::RNode object after initialization
ROOT::RDF::RNode fDataSet = ROOT::RDataFrame(0); //!
ROOT::RDF::RNode fDataFrame = ROOT::RDataFrame(0); //!

/// A pointer to the generated tree
TChain* fTree = nullptr; //!
Expand All @@ -122,12 +122,14 @@ class TRestDataSet : public TRestMetadata {
protected:
virtual std::vector<std::string> FileSelection();

void RegenerateTree(std::vector<std::string> finalList = {});

public:
/// Gives access to the RDataFrame
ROOT::RDF::RNode GetDataFrame() const {
if (!fExternal && fTree == nullptr)
RESTWarning << "DataFrame has not been yet initialized" << RESTendl;
return fDataSet;
return fDataFrame;
}

void EnableMultiThreading(Bool_t enable = true) { fMT = enable; }
Expand All @@ -152,7 +154,7 @@ class TRestDataSet : public TRestMetadata {
}

/// Number of variables (or observables)
size_t GetNumberOfColumns() { return fDataSet.GetColumnNames().size(); }
size_t GetNumberOfColumns() { return fDataFrame.GetColumnNames().size(); }

/// Number of variables (or observables)
size_t GetNumberOfBranches() { return GetNumberOfColumns(); }
Expand Down Expand Up @@ -187,7 +189,7 @@ class TRestDataSet : public TRestMetadata {

void SetTotalTimeInSeconds(Double_t seconds) { fTotalDuration = seconds; }
void SetDataFrame(const ROOT::RDF::RNode& dS) {
fDataSet = dS;
fDataFrame = dS;
fExternal = true;
}

Expand All @@ -198,8 +200,12 @@ class TRestDataSet : public TRestMetadata {
void Export(const std::string& filename, std::vector<std::string> excludeColumns = {});

ROOT::RDF::RNode MakeCut(const TRestCut* cut);
ROOT::RDF::RNode ApplyRange(size_t from, size_t to);
ROOT::RDF::RNode Range(size_t from, size_t to);
ROOT::RDF::RNode DefineColumn(const std::string& columnName, const std::string& formula);

size_t GetEntries();

void PrintMetadata() override;
void Initialize() override;

Expand All @@ -209,6 +215,6 @@ class TRestDataSet : public TRestMetadata {
TRestDataSet(const char* cfgFileName, const std::string& name = "");
~TRestDataSet();

ClassDefOverride(TRestDataSet, 7);
ClassDefOverride(TRestDataSet, 8);
};
#endif
78 changes: 60 additions & 18 deletions source/framework/core/src/TRestDataSet.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -382,30 +382,40 @@ void TRestDataSet::GenerateDataSet() {
ROOT::DisableImplicitMT();

RESTInfo << "Initializing dataset" << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fFileSelection);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fFileSelection);

RESTInfo << "Making cuts" << RESTendl;
fDataSet = MakeCut(fCut);
fDataFrame = MakeCut(fCut);

// Adding new user columns added to the dataset
for (const auto& [cName, cExpression] : fColumnNameExpressions) {
RESTInfo << "Adding column to dataset: " << cName << RESTendl;
finalList.emplace_back(cName);
fDataSet = DefineColumn(cName, cExpression);
fDataFrame = DefineColumn(cName, cExpression);
}

RegenerateTree(finalList);

RESTInfo << " - Dataset generated!" << RESTendl;
}

///////////////////////////////////////////////
/// \brief It regenerates the tree so that it is an exact copy of the present DataFrame
///
void TRestDataSet::RegenerateTree(std::vector<std::string> finalList) {
RESTInfo << "Generating snapshot." << RESTendl;
std::string user = getenv("USER");
std::string fOutName = "/tmp/rest_output_" + user + ".root";
fDataSet.Snapshot("AnalysisTree", fOutName, finalList);
if (!finalList.empty())
fDataFrame.Snapshot("AnalysisTree", fOutName, finalList);
else
fDataFrame.Snapshot("AnalysisTree", fOutName);

RESTInfo << "Re-importing analysis tree." << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fOutName);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fOutName);

TFile* f = TFile::Open(fOutName.c_str());
fTree = (TChain*)f->Get("AnalysisTree");

RESTInfo << " - Dataset generated!" << RESTendl;
}

///////////////////////////////////////////////
Expand Down Expand Up @@ -517,14 +527,32 @@ std::vector<std::string> TRestDataSet::FileSelection() {
return fFileSelection;
}

///////////////////////////////////////////////
/// \brief This method returns a RDataFrame node with the number of
/// samples inside the dataset by selecting a range. It will not
/// modify internally the dataset. See ApplyRange to modify internally
/// the dataset.
///
ROOT::RDF::RNode TRestDataSet::Range(size_t from, size_t to) { return fDataFrame.Range(from, to); }

///////////////////////////////////////////////
/// \brief This method reduces the number of samples inside the
/// dataset by selecting a range.
///
ROOT::RDF::RNode TRestDataSet::ApplyRange(size_t from, size_t to) {
fDataFrame = fDataFrame.Range(from, to);
RegenerateTree();
return fDataFrame;
}

///////////////////////////////////////////////
/// \brief This function applies a TRestCut to the dataframe
/// and returns a dataframe with the applied cuts. Note that
/// the cuts are not applied directly to the dataframe on
/// TRestDataSet, to do so you should do fDataSet = MakeCut(fCut);
/// TRestDataSet, to do so you should do fDataFrame = MakeCut(fCut);
///
ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
auto df = fDataSet;
auto df = fDataFrame;

if (cut == nullptr) return df;

Expand Down Expand Up @@ -561,6 +589,20 @@ ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
return df;
}

///////////////////////////////////////////////
/// \brief It returns the number of entries found inside fDataFrame
/// and prints out a warning if the number of entries inside the
/// tree is not the same.
///
size_t TRestDataSet::GetEntries() {
auto nEntries = fDataFrame.Count();
if (*nEntries == (long long unsigned int)GetTree()->GetEntries()) return *nEntries;
RESTWarning << "TRestDataSet::GetEntries. Number of tree entries is not the same as RDataFrame entries."
<< RESTendl;
RESTWarning << "Returning RDataFrame entries" << RESTendl;
return *nEntries;
}

///////////////////////////////////////////////
/// \brief This function will add a new column to the RDataFrame using
/// the same scheme as the usual RDF::Define method, but it will on top of
Expand All @@ -574,7 +616,7 @@ ROOT::RDF::RNode TRestDataSet::MakeCut(const TRestCut* cut) {
/// \endcode
///
ROOT::RDF::RNode TRestDataSet::DefineColumn(const std::string& columnName, const std::string& formula) {
auto df = fDataSet;
auto df = fDataFrame;

std::string evalFormula = formula;
for (auto const& [name, properties] : fQuantity)
Expand Down Expand Up @@ -819,7 +861,7 @@ void TRestDataSet::InitFromConfigFile() {
void TRestDataSet::Export(const std::string& filename, std::vector<std::string> excludeColumns) {
RESTInfo << "Exporting dataset" << RESTendl;

std::vector<std::string> columns = fDataSet.GetColumnNames();
std::vector<std::string> columns = fDataFrame.GetColumnNames();
if (!excludeColumns.empty()) {
columns.erase(std::remove_if(columns.begin(), columns.end(),
[&excludeColumns](std::string elem) {
Expand All @@ -831,10 +873,10 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
RESTInfo << "Re-Generating snapshot." << RESTendl;
std::string user = getenv("USER");
std::string fOutName = "/tmp/rest_output_" + user + ".root";
fDataSet.Snapshot("AnalysisTree", fOutName, columns);
fDataFrame.Snapshot("AnalysisTree", fOutName, columns);

RESTInfo << "Re-importing analysis tree." << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fOutName);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fOutName);

TFile* f = TFile::Open(fOutName.c_str());
fTree = (TChain*)f->Get("AnalysisTree");
Expand All @@ -846,7 +888,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
RESTInfo << "Re-Generating snapshot." << RESTendl;
std::string user = getenv("USER");
std::string fOutName = "/tmp/rest_output_" + user + ".root";
fDataSet.Snapshot("AnalysisTree", fOutName);
fDataFrame.Snapshot("AnalysisTree", fOutName);

TFile* f = TFile::Open(fOutName.c_str());
fTree = (TChain*)f->Get("AnalysisTree");
Expand Down Expand Up @@ -910,7 +952,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>
fprintf(f, "###\n");
fprintf(f, "### Data starts here\n");

auto obsNames = fDataSet.GetColumnNames();
auto obsNames = fDataFrame.GetColumnNames();
std::string obsListStr = "";
for (const auto& l : obsNames) {
if (!obsListStr.empty()) obsListStr += ":";
Expand Down Expand Up @@ -938,7 +980,7 @@ void TRestDataSet::Export(const std::string& filename, std::vector<std::string>

return;
} else if (TRestTools::GetFileNameExtension(filename) == "root") {
fDataSet.Snapshot("AnalysisTree", filename);
fDataFrame.Snapshot("AnalysisTree", filename);

TFile* f = TFile::Open(filename.c_str(), "UPDATE");
std::string name = this->GetName();
Expand Down Expand Up @@ -1038,7 +1080,7 @@ void TRestDataSet::Import(const std::string& fileName) {
else
ROOT::DisableImplicitMT();

fDataSet = ROOT::RDataFrame("AnalysisTree", fileName);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fileName);

fTree = (TChain*)file->Get("AnalysisTree");
}
Expand Down Expand Up @@ -1104,7 +1146,7 @@ void TRestDataSet::Import(std::vector<std::string> fileNames) {
}

RESTInfo << "Opening list of files. First file: " << fileNames[0] << RESTendl;
fDataSet = ROOT::RDataFrame("AnalysisTree", fileNames);
fDataFrame = ROOT::RDataFrame("AnalysisTree", fileNames);

if (fTree != nullptr) {
delete fTree;
Expand Down
8 changes: 7 additions & 1 deletion source/framework/sensitivity/inc/TRestComponentDataSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ class TRestComponentDataSet : public TRestComponent {
/// The dataset used to initialize the distribution
TRestDataSet fDataSet; //!

/// It helps to split large datasets when extracting the parameterization nodes
long long unsigned int fSplitEntries = 600000000;

/// It creates a sample subset using a range definition
TVector2 fDFRange = TVector2(0, 0);

/// It is true of the dataset was loaded without issues
Bool_t fDataSetLoaded = false; //!

Expand Down Expand Up @@ -84,6 +90,6 @@ class TRestComponentDataSet : public TRestComponent {
TRestComponentDataSet(const char* cfgFileName, const std::string& name);
~TRestComponentDataSet();

ClassDefOverride(TRestComponentDataSet, 3);
ClassDefOverride(TRestComponentDataSet, 4);
};
#endif
32 changes: 22 additions & 10 deletions source/framework/sensitivity/src/TRestComponentDataSet.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ void TRestComponentDataSet::PrintMetadata() {
RESTMetadata << " " << RESTendl;
}

if (fDFRange.X() != 0 || fDFRange.Y() != 0) {
RESTMetadata << " DataFrame range: ( " << fDFRange.X() << ", " << fDFRange.Y() << ")" << RESTendl;
RESTMetadata << " " << RESTendl;
}

if (!fParameter.empty() && fParameterizationNodes.empty()) {
RESTMetadata << "This component has no nodes!" << RESTendl;
RESTMetadata << " Use: LoadDataSets() to initialize the nodes" << RESTendl;
Expand Down Expand Up @@ -383,15 +388,17 @@ std::vector<Double_t> TRestComponentDataSet::ExtractParameterizationNodes() {
return vs;
}

auto parValues = fDataSet.GetDataFrame().Take<double>(fParameter);
for (const auto v : parValues) vs.push_back(v);
auto GetUniqueElements = [](const std::vector<double>& vec) {
std::set<double> uniqueSet(vec.begin(), vec.end());
return std::vector<double>(uniqueSet.begin(), uniqueSet.end());
};

std::vector<double>::iterator ip;
ip = std::unique(vs.begin(), vs.begin() + vs.size());
vs.resize(std::distance(vs.begin(), ip));
std::sort(vs.begin(), vs.end());
ip = std::unique(vs.begin(), vs.end());
vs.resize(std::distance(vs.begin(), ip));
for (size_t n = 0; n < 1 + fDataSet.GetEntries() / fSplitEntries; n++) {
auto nEn = fDataSet.Range(n * fSplitEntries, (n + 1) * fSplitEntries).Count();
auto parValues = fDataSet.Range(n * fSplitEntries, (n + 1) * fSplitEntries).Take<double>(fParameter);
std::vector<double> uniqueVec = GetUniqueElements(*parValues);
vs.insert(vs.end(), uniqueVec.begin(), uniqueVec.end());
}

return vs;
}
Expand Down Expand Up @@ -476,6 +483,9 @@ Bool_t TRestComponentDataSet::LoadDataSets() {
fDataSet.Import(fullFileNames);
fDataSetLoaded = true;

if (fDFRange.X() != 0 || fDFRange.Y() != 0)
fDataSet.ApplyRange((size_t)fDFRange.X(), (size_t)fDFRange.Y());

if (fDataSet.GetTree() == nullptr) {
RESTError << "Problem loading dataset from file list :" << RESTendl;
for (const auto& f : fDataSetFileNames) RESTError << " - " << f << RESTendl;
Expand All @@ -486,6 +496,7 @@ Bool_t TRestComponentDataSet::LoadDataSets() {

if (VariablesOk() && WeightsOk()) {
fParameterizationNodes = ExtractParameterizationNodes();
RESTInfo << "Filling histograms" << RESTendl;
FillHistograms();
return fDataSetLoaded;
}
Expand Down Expand Up @@ -515,11 +526,12 @@ Bool_t TRestComponentDataSet::WeightsOk() {
Bool_t ok = true;
std::vector cNames = fDataSet.GetDataFrame().GetColumnNames();

for (const auto& var : fWeights)
if (std::count(cNames.begin(), cNames.end(), var) == 0) {
for (const auto& var : fWeights) {
if (!isANumber(var) && std::count(cNames.begin(), cNames.end(), var) == 0) {
RESTError << "Weight ---> " << var << " <--- NOT found on dataset" << RESTendl;
ok = false;
}
}
return ok;
}

Expand Down
Loading