Skip to content

Commit

Permalink
Adding a data quality enum to SubjectOnDisk
Browse files Browse the repository at this point in the history
  • Loading branch information
keenon committed Sep 28, 2024
1 parent 46eaef0 commit 69ff638
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 0 deletions.
54 changes: 54 additions & 0 deletions dart/biomechanics/SubjectOnDisk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,40 @@ proto::BasicTrialType basicTrialTypeToProto(BasicTrialType type)
return proto::BasicTrialType::other;
}

DataQuality dataQualityFromProto(proto::DataQuality quality)
{
switch (quality)
{
case proto::DataQuality::pilotData:
return pilotData;
case proto::DataQuality::experimentalData:
return experimentalData;
case proto::DataQuality::internetData:
return internetData;
case proto::DataQuality_INT_MIN_SENTINEL_DO_NOT_USE_:
return internetData;
break;
case proto::DataQuality_INT_MAX_SENTINEL_DO_NOT_USE_:
return internetData;
break;
}
return internetData;
}

proto::DataQuality dataQualityToProto(DataQuality quality)
{
switch (quality)
{
case pilotData:
return proto::DataQuality::pilotData;
case experimentalData:
return proto::DataQuality::experimentalData;
case internetData:
return proto::DataQuality::internetData;
}
return proto::DataQuality::pilotData;
}

DetectedTrialFeature detectedTrialFeatureFromProto(
proto::DetectedTrialFeature feature)
{
Expand Down Expand Up @@ -1568,6 +1602,12 @@ std::vector<MissingGRFReason> SubjectOnDisk::getMissingGRF(int trial)
return mHeader->mTrials[trial]->mMissingGRFReason;
}

/// This returns the user supplied enum of type 'DataQuality'
DataQuality SubjectOnDisk::getQuality()
{
return mHeader->getQuality();
}

int SubjectOnDisk::getNumProcessingPasses()
{
return mHeader->mPasses.size();
Expand Down Expand Up @@ -3680,6 +3720,17 @@ SubjectOnDiskHeader& SubjectOnDiskHeader::setNotes(const std::string& notes)
return *this;
}

SubjectOnDiskHeader& SubjectOnDiskHeader::setQuality(DataQuality quality)
{
mDataQuality = quality;
return *this;
}

DataQuality SubjectOnDiskHeader::getQuality()
{
return mDataQuality;
}

std::shared_ptr<SubjectOnDiskPassHeader>
SubjectOnDiskHeader::addProcessingPass()
{
Expand Down Expand Up @@ -3915,6 +3966,7 @@ void SubjectOnDiskHeader::write(dart::proto::SubjectOnDiskHeader* header)
{
header->add_exo_dof_index(index);
}
header->set_data_quality(dataQualityToProto(mDataQuality));

if (!header->IsInitialized())
{
Expand Down Expand Up @@ -4059,6 +4111,8 @@ void SubjectOnDiskHeader::read(const dart::proto::SubjectOnDiskHeader& proto)
{
mExoDofIndices.push_back(proto.exo_dof_index(i));
}

mDataQuality = dataQualityFromProto(proto.data_quality());
}

void SubjectOnDiskHeader::writeSensorsFrame(
Expand Down
8 changes: 8 additions & 0 deletions dart/biomechanics/SubjectOnDisk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ class SubjectOnDiskHeader
SubjectOnDiskHeader& setSubjectTags(std::vector<std::string> subjectTags);
SubjectOnDiskHeader& setHref(const std::string& sourceHref);
SubjectOnDiskHeader& setNotes(const std::string& notes);
SubjectOnDiskHeader& setQuality(DataQuality quality);
DataQuality getQuality();
std::shared_ptr<SubjectOnDiskPassHeader> addProcessingPass();
std::vector<std::shared_ptr<SubjectOnDiskPassHeader>> getProcessingPasses();
std::shared_ptr<SubjectOnDiskTrial> addTrial();
Expand Down Expand Up @@ -522,6 +524,9 @@ class SubjectOnDiskHeader
// This is exoskeleton data
std::vector<int> mExoDofIndices;

// This is the user supplied quality of the data
DataQuality mDataQuality;

friend class SubjectOnDisk;
friend struct Frame;
friend struct FramePass;
Expand Down Expand Up @@ -637,6 +642,9 @@ class SubjectOnDisk
/// include `notMissingGRF`.
std::vector<MissingGRFReason> getMissingGRF(int trial);

/// This returns the user supplied enum of type 'DataQuality'
DataQuality getQuality();

int getNumProcessingPasses();

ProcessingPassType getProcessingPassType(int processingPass);
Expand Down
7 changes: 7 additions & 0 deletions dart/biomechanics/enums.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ enum DetectedTrialFeature
flatTerrain
};

enum DataQuality
{
pilotData,
experimentalData,
internetData
};

enum MissingGRFStatus
{
no = 0, // no will cast to `false`
Expand Down
8 changes: 8 additions & 0 deletions dart/proto/SubjectOnDisk.proto
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ enum DetectedTrialFeature {
flatTerrain = 3;
}

enum DataQuality {
pilotData = 0;
experimentalData = 1;
internetData = 2;
}

// Many of the ML tasks we want to support from SubjectOnDisk data include
// effectively predicting the results of a downstream processing task from
// an upstream processing task. Trivially, that's predicting physics from
Expand Down Expand Up @@ -166,6 +172,8 @@ message SubjectOnDiskHeader {
repeated int32 exo_dof_index = 22;
// Details about the subject tags provided on the AddBiomechanics platform
repeated string subject_tag = 23;
// This is what the user has tagged this subject as, in terms of data quality
DataQuality data_quality = 25;
}

message SubjectOnDiskProcessingPassFrame {
Expand Down
27 changes: 27 additions & 0 deletions python/_nimblephysics/biomechanics/SubjectOnDisk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,21 @@ void SubjectOnDisk(py::module& m)
"This is a trial that doesn't fit into any of the "
"other categories.");

auto dataQuality
= ::py::enum_<dart::biomechanics::DataQuality>(m, "DataQuality")
.value(
"PILOT_DATA",
dart::biomechanics::DataQuality::pilotData,
"This is data that was collected as part of a pilot study.")
.value(
"EXPERIMENTAL_DATA",
dart::biomechanics::DataQuality::experimentalData,
"This is data that was collected as part of an experiment.")
.value(
"INTERNET_DATA",
dart::biomechanics::DataQuality::internetData,
"This is data that was collected from the internet.");

auto detectedTrialFeature
= ::py::enum_<dart::biomechanics::DetectedTrialFeature>(
m, "DetectedTrialFeature")
Expand Down Expand Up @@ -1089,6 +1104,13 @@ Note that these are specified in the local body frame, acting on the body at its
"setNotes",
&dart::biomechanics::SubjectOnDiskHeader::setNotes,
::py::arg("notes"))
.def(
"setQuality",
&dart::biomechanics::SubjectOnDiskHeader::setQuality,
::py::arg("quality"))
.def(
"getQuality",
&dart::biomechanics::SubjectOnDiskHeader::getQuality)
.def(
"addProcessingPass",
&dart::biomechanics::SubjectOnDiskHeader::addProcessingPass)
Expand Down Expand Up @@ -1323,6 +1345,11 @@ Note that these are specified in the local body frame, acting on the body at its
This method is provided to give a cheaper way to filter out frames we want to ignore for training, without having to call
the more expensive :code:`loadFrames()` and examine frames individually.
)doc")
.def(
"getQuality",
&dart::biomechanics::SubjectOnDisk::getQuality,
"This returns the user-supplied quality of the data in this "
"subject")
// int getNumProcessingPasses();
.def(
"getNumProcessingPasses",
Expand Down

0 comments on commit 69ff638

Please sign in to comment.