Skip to content

Commit

Permalink
Fixing build errors after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
keenon committed Dec 2, 2024
1 parent 63599c5 commit 73619fb
Show file tree
Hide file tree
Showing 11 changed files with 951 additions and 359 deletions.
20 changes: 14 additions & 6 deletions dart/neural/WithRespectToDamping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ WithRespectToDamping::WithRespectToDamping()

//==============================================================================
WrtDampingJointEntry::WrtDampingJointEntry(
std::string jointName, WrtDampingJointEntryType type, int dofs, Eigen::VectorXi worldDofs)
std::string jointName,
WrtDampingJointEntryType type,
int dofs,
Eigen::VectorXi worldDofs)
: jointName(jointName), type(type), mDofs(dofs), mWorldDofs(worldDofs)
{
}
Expand All @@ -31,21 +34,19 @@ void WrtDampingJointEntry::set(
dynamics::Skeleton* skel, const Eigen::Ref<Eigen::VectorXs>& value)
{
dynamics::Joint* joint = skel->getJoint(jointName);
for(int i = 0; i < joint->getNumDofs(); i++)
for (int i = 0; i < joint->getNumDofs(); i++)
{
joint->setDampingCoefficient(i, value(i));
}
return;


}

//==============================================================================
void WrtDampingJointEntry::get(
dynamics::Skeleton* skel, Eigen::Ref<Eigen::VectorXs> out)
{
dynamics::Joint* joint = skel->getJoint(jointName);
for(int i = 0; i < joint->getNumDofs(); i++)
for (int i = 0; i < joint->getNumDofs(); i++)
{
out(i) = joint->getDampingCoefficient(i);
}
Expand All @@ -69,7 +70,8 @@ WrtDampingJointEntry& WithRespectToDamping::registerJoint(
{
std::string skelName = joint->getSkeleton()->getName();
std::vector<WrtDampingJointEntry>& skelEntries = mEntries[skelName];
skelEntries.emplace_back(joint->getName(), type, joint->getNumDofs(), dofs_index);
skelEntries.emplace_back(
joint->getName(), type, joint->getNumDofs(), dofs_index);

WrtDampingJointEntry& entry = skelEntries[skelEntries.size() - 1];

Expand Down Expand Up @@ -121,6 +123,12 @@ WrtDampingJointEntry& WithRespectToDamping::getJoint(dynamics::Joint* joint)
throw std::runtime_error{"Execution should never reach this point"};
}

//==============================================================================
std::string WithRespectToDamping::name()
{
return "DAMPING";
}

//==============================================================================
/// This returns this WRT from the world as a vector
Eigen::VectorXs WithRespectToDamping::get(simulation::World* world)
Expand Down
10 changes: 8 additions & 2 deletions dart/neural/WithRespectToDamping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ struct WrtDampingJointEntry
int mDofs;
Eigen::VectorXi mWorldDofs;

WrtDampingJointEntry(std::string jointName, WrtDampingJointEntryType type, int dofs, Eigen::VectorXi worldDofs);
WrtDampingJointEntry(
std::string jointName,
WrtDampingJointEntryType type,
int dofs,
Eigen::VectorXi worldDofs);

int dim();

Expand All @@ -55,7 +59,7 @@ class WithRespectToDamping : public WithRespectTo
/// way in this differentiation
/// Need to set the dofs
WrtDampingJointEntry& registerJoint(
dynamics::Joint* joint,
dynamics::Joint* joint,
WrtDampingJointEntryType type,
Eigen::VectorXi dofs_index,
Eigen::VectorXs upperBound,
Expand All @@ -69,6 +73,8 @@ class WithRespectToDamping : public WithRespectTo
// Implement all the methods we need
//////////////////////////////////////////////////////////////

std::string name() override;

/// This returns this WRT from the world as a vector
Eigen::VectorXs get(simulation::World* world) override;

Expand Down
20 changes: 14 additions & 6 deletions dart/neural/WithRespectToSpring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ WithRespectToSpring::WithRespectToSpring()

//==============================================================================
WrtSpringJointEntry::WrtSpringJointEntry(
std::string jointName, WrtSpringJointEntryType type, int dofs, Eigen::VectorXi worldDofs)
std::string jointName,
WrtSpringJointEntryType type,
int dofs,
Eigen::VectorXi worldDofs)
: jointName(jointName), type(type), mDofs(dofs), mWorldDofs(worldDofs)
{
}
Expand All @@ -31,21 +34,19 @@ void WrtSpringJointEntry::set(
dynamics::Skeleton* skel, const Eigen::Ref<Eigen::VectorXs>& value)
{
dynamics::Joint* joint = skel->getJoint(jointName);
for(int i = 0; i < joint->getNumDofs(); i++)
for (int i = 0; i < joint->getNumDofs(); i++)
{
joint->setSpringStiffness(i, value(i));
}
return;


}

//==============================================================================
void WrtSpringJointEntry::get(
dynamics::Skeleton* skel, Eigen::Ref<Eigen::VectorXs> out)
{
dynamics::Joint* joint = skel->getJoint(jointName);
for(int i = 0; i < joint->getNumDofs(); i++)
for (int i = 0; i < joint->getNumDofs(); i++)
{
out(i) = joint->getSpringStiffness(i);
}
Expand All @@ -69,7 +70,8 @@ WrtSpringJointEntry& WithRespectToSpring::registerJoint(
{
std::string skelName = joint->getSkeleton()->getName();
std::vector<WrtSpringJointEntry>& skelEntries = mEntries[skelName];
skelEntries.emplace_back(joint->getName(), type, joint->getNumDofs(), dofs_index);
skelEntries.emplace_back(
joint->getName(), type, joint->getNumDofs(), dofs_index);

WrtSpringJointEntry& entry = skelEntries[skelEntries.size() - 1];

Expand Down Expand Up @@ -121,6 +123,12 @@ WrtSpringJointEntry& WithRespectToSpring::getJoint(dynamics::Joint* joint)
throw std::runtime_error{"Execution should never reach this point"};
}

//==============================================================================
std::string WithRespectToSpring::name()
{
return "SPRING";
}

//==============================================================================
/// This returns this WRT from the world as a vector
Eigen::VectorXs WithRespectToSpring::get(simulation::World* world)
Expand Down
10 changes: 8 additions & 2 deletions dart/neural/WithRespectToSpring.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ struct WrtSpringJointEntry
int mDofs;
Eigen::VectorXi mWorldDofs;

WrtSpringJointEntry(std::string jointName, WrtSpringJointEntryType type, int dofs, Eigen::VectorXi worldDofs);
WrtSpringJointEntry(
std::string jointName,
WrtSpringJointEntryType type,
int dofs,
Eigen::VectorXi worldDofs);

int dim();

Expand All @@ -55,7 +59,7 @@ class WithRespectToSpring : public WithRespectTo
/// way in this differentiation
/// Need to set the dofs
WrtSpringJointEntry& registerJoint(
dynamics::Joint* joint,
dynamics::Joint* joint,
WrtSpringJointEntryType type,
Eigen::VectorXi dofs_index,
Eigen::VectorXs upperBound,
Expand All @@ -69,6 +73,8 @@ class WithRespectToSpring : public WithRespectTo
// Implement all the methods we need
//////////////////////////////////////////////////////////////

std::string name() override;

/// This returns this WRT from the world as a vector
Eigen::VectorXs get(simulation::World* world) override;

Expand Down
16 changes: 3 additions & 13 deletions dart/realtime/SSID.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ void SSID::runInference(long startTime)
registerLock();
// Eigen::MatrixXs forceHistory = mControlLog.getValues(
// startTime - mPlanningHistoryMillis, steps, millisPerStep);
Eigen::MatrixXs forceHistory
= mControlLog.getRecentValuesBefore(startTime, steps + 1);
Eigen::MatrixXs forceHistory
= mControlLog.getRecentValuesBefore(startTime, steps + 1);
for (int i = 0; i < steps; i++)
Expand All @@ -230,10 +228,6 @@ void SSID::runInference(long startTime)
// Eigen::MatrixXs velHistory = mSensorLogs[1].getValues(
// startTime - mPlanningHistoryMillis, steps, millisPerStep
// );
Eigen::MatrixXs poseHistory
= mSensorLogs[0].getRecentValuesBefore(startTime, steps + 1);
Eigen::MatrixXs velHistory
= mSensorLogs[1].getRecentValuesBefore(startTime, steps + 1);
Eigen::MatrixXs poseHistory
= mSensorLogs[0].getRecentValuesBefore(startTime, steps + 1);
Eigen::MatrixXs velHistory
Expand Down Expand Up @@ -278,12 +272,10 @@ void SSID::runInference(long startTime)
}
}

Eigen::VectorXs SSID::runPlotting(
// Run plotting function should be idenpotent
// Here it only support 1 body node plotting, but that make sense
std::pair<Eigen::VectorXs, Eigen::MatrixXs> SSID::runPlotting(
long startTime, s_t upper, s_t lower, int samples)
// Run plotting function should be idenpotent
// Here it only support 1 body node plotting, but that make sense
std::pair<Eigen::VectorXs, Eigen::MatrixXs> SSID::runPlotting(
long startTime, s_t upper, s_t lower, int samples)
{
int millisPerStep
= static_cast<int>(ceil(mScale * mWorld->getTimeStep() * 1000.0));
Expand Down Expand Up @@ -341,8 +333,6 @@ Eigen::VectorXs SSID::runPlotting(
losses(0) = loss;
}

return losses;

std::pair<Eigen::VectorXs, Eigen::MatrixXs> result;
result.first = losses;
result.second = poseHistory;
Expand Down
Loading

0 comments on commit 73619fb

Please sign in to comment.