Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wpimath] Simplify pose estimator #6705

Merged
merged 5 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 127 additions & 101 deletions wpimath/src/main/java/edu/wpi/first/math/estimator/PoseEstimator.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@
package edu.wpi.first.math.estimator;

import edu.wpi.first.math.MathSharedStore;
import edu.wpi.first.math.MathUtil;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.VecBuilder;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Rotation2d;
import edu.wpi.first.math.geometry.Twist2d;
import edu.wpi.first.math.interpolation.Interpolatable;
import edu.wpi.first.math.interpolation.TimeInterpolatableBuffer;
import edu.wpi.first.math.kinematics.Kinematics;
import edu.wpi.first.math.kinematics.Odometry;
import edu.wpi.first.math.numbers.N1;
import edu.wpi.first.math.numbers.N3;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.NavigableMap;
import java.util.Optional;
import java.util.TreeMap;

/**
* This class wraps {@link Odometry} to fuse latency-compensated vision measurements with encoder
Expand All @@ -38,14 +37,20 @@
* @param <T> Wheel positions type.
*/
public class PoseEstimator<T> {
private final Kinematics<?, T> m_kinematics;
private final Odometry<T> m_odometry;
private final Matrix<N3, N1> m_q = new Matrix<>(Nat.N3(), Nat.N1());
private final Matrix<N3, N3> m_visionK = new Matrix<>(Nat.N3(), Nat.N3());

private static final double kBufferDuration = 1.5;
private final TimeInterpolatableBuffer<InterpolationRecord> m_poseBuffer =
// Maps timestamps to odometry-only pose estimates
private final TimeInterpolatableBuffer<Pose2d> m_odometryPoseBuffer =
TimeInterpolatableBuffer.createBuffer(kBufferDuration);
// Maps timestamps to vision updates
// Always contains one entry before the oldest entry in m_odometryPoseBuffer, unless there have
// been no vision measurements after the last reset
private final NavigableMap<Double, VisionUpdate> m_visionUpdates = new TreeMap<>();

private Pose2d m_poseEstimate;

/**
* Constructs a PoseEstimator.
Expand All @@ -59,14 +64,16 @@ public class PoseEstimator<T> {
* in meters, y position in meters, and heading in radians). Increase these numbers to trust
* the vision pose measurement less.
*/
@SuppressWarnings("PMD.UnusedFormalParameter")
public PoseEstimator(
Kinematics<?, T> kinematics,
Odometry<T> odometry,
Matrix<N3, N1> stateStdDevs,
Matrix<N3, N1> visionMeasurementStdDevs) {
m_kinematics = kinematics;
m_odometry = odometry;

m_poseEstimate = m_odometry.getPoseMeters();

for (int i = 0; i < 3; ++i) {
m_q.set(i, 0, stateStdDevs.get(i, 0) * stateStdDevs.get(i, 0));
}
Expand Down Expand Up @@ -113,7 +120,9 @@ public final void setVisionMeasurementStdDevs(Matrix<N3, N1> visionMeasurementSt
public void resetPosition(Rotation2d gyroAngle, T wheelPositions, Pose2d poseMeters) {
// Reset state estimate and error covariance
m_odometry.resetPosition(gyroAngle, wheelPositions, poseMeters);
m_poseBuffer.clear();
m_odometryPoseBuffer.clear();
m_visionUpdates.clear();
m_poseEstimate = m_odometry.getPoseMeters();
}

/**
Expand All @@ -122,7 +131,7 @@ public void resetPosition(Rotation2d gyroAngle, T wheelPositions, Pose2d poseMet
* @return The estimated robot pose in meters.
*/
public Pose2d getEstimatedPosition() {
return m_odometry.getPoseMeters();
return m_poseEstimate;
}

/**
Expand All @@ -132,7 +141,54 @@ public Pose2d getEstimatedPosition() {
* @return The pose at the given timestamp (or Optional.empty() if the buffer is empty).
*/
public Optional<Pose2d> sampleAt(double timestampSeconds) {
return m_poseBuffer.getSample(timestampSeconds).map(record -> record.poseMeters);
// Step 0: If there are no odometry updates to sample, skip.
if (m_odometryPoseBuffer.getInternalBuffer().isEmpty()) {
return Optional.empty();
}

// Step 1: Make sure timestamp matches the sample from the odometry pose buffer. (When sampling,
// the buffer will always use a timestamp between the first and last timestamps)
double oldestOdometryTimestamp = m_odometryPoseBuffer.getInternalBuffer().firstKey();
double newestOdometryTimestamp = m_odometryPoseBuffer.getInternalBuffer().lastKey();
timestampSeconds =
MathUtil.clamp(timestampSeconds, oldestOdometryTimestamp, newestOdometryTimestamp);

// Step 2: If there are no applicable vision updates, use the odometry-only information.
if (m_visionUpdates.isEmpty() || timestampSeconds < m_visionUpdates.firstKey()) {
return m_odometryPoseBuffer.getSample(timestampSeconds);
}

// Step 3: Get the latest vision update from before or at the timestamp to sample at.
double floorTimestamp = m_visionUpdates.floorKey(timestampSeconds);
var visionUpdate = m_visionUpdates.get(floorTimestamp);

// Step 4: Get the pose measured by odometry at the time of the sample.
var odometryEstimate = m_odometryPoseBuffer.getSample(timestampSeconds);

// Step 5: Apply the vision compensation to the odometry pose.
return odometryEstimate.map(odometryPose -> visionUpdate.compensate(odometryPose));
}

/** Removes stale vision updates that won't affect sampling. */
private void cleanUpVisionUpdates() {
// Step 0: If there are no odometry samples, skip.
if (m_odometryPoseBuffer.getInternalBuffer().isEmpty()) {
return;
}

// Step 1: Find the oldest timestamp that needs a vision update.
double oldestOdometryTimestamp = m_odometryPoseBuffer.getInternalBuffer().firstKey();

// Step 2: If there are no vision updates before that timestamp, skip.
if (m_visionUpdates.isEmpty() || oldestOdometryTimestamp < m_visionUpdates.firstKey()) {
return;
}

// Step 3: Find the newest vision update timestamp before or at the oldest timestamp.
double newestNeededVisionUpdateTimestamp = m_visionUpdates.floorKey(oldestOdometryTimestamp);

// Step 4: Remove all entries strictly before the newest timestamp we need.
m_visionUpdates.headMap(newestNeededVisionUpdateTimestamp, false).clear();
}

/**
Expand All @@ -156,50 +212,51 @@ public Optional<Pose2d> sampleAt(double timestampSeconds) {
*/
public void addVisionMeasurement(Pose2d visionRobotPoseMeters, double timestampSeconds) {
// Step 0: If this measurement is old enough to be outside the pose buffer's timespan, skip.
try {
if (m_poseBuffer.getInternalBuffer().lastKey() - kBufferDuration > timestampSeconds) {
return;
}
} catch (NoSuchElementException ex) {
if (m_odometryPoseBuffer.getInternalBuffer().isEmpty()
|| m_odometryPoseBuffer.getInternalBuffer().lastKey() - kBufferDuration
> timestampSeconds) {
return;
}

// Step 1: Get the pose odometry measured at the moment the vision measurement was made.
var sample = m_poseBuffer.getSample(timestampSeconds);
// Step 1: Clean up any old entries
cleanUpVisionUpdates();

// Step 2: Get the pose measured by odometry at the moment the vision measurement was made.
var odometrySample = m_odometryPoseBuffer.getSample(timestampSeconds);

if (sample.isEmpty()) {
if (odometrySample.isEmpty()) {
return;
}

// Step 2: Measure the twist between the odometry pose and the vision pose.
var twist = sample.get().poseMeters.log(visionRobotPoseMeters);
// Step 3: Get the vision-compensated pose estimate at the moment the vision measurement was
// made.
var visionSample = sampleAt(timestampSeconds);

// Step 3: We should not trust the twist entirely, so instead we scale this twist by a Kalman
if (visionSample.isEmpty()) {
return;
}

// Step 4: Measure the twist between the old pose estimate and the vision pose.
var twist = visionSample.get().log(visionRobotPoseMeters);

// Step 5: We should not trust the twist entirely, so instead we scale this twist by a Kalman
// gain matrix representing how much we trust vision measurements compared to our current pose.
var k_times_twist = m_visionK.times(VecBuilder.fill(twist.dx, twist.dy, twist.dtheta));

// Step 4: Convert back to Twist2d.
// Step 6: Convert back to Twist2d.
var scaledTwist =
new Twist2d(k_times_twist.get(0, 0), k_times_twist.get(1, 0), k_times_twist.get(2, 0));

// Step 5: Reset Odometry to state at sample with vision adjustment.
m_odometry.resetPosition(
sample.get().gyroAngle,
sample.get().wheelPositions,
sample.get().poseMeters.exp(scaledTwist));

// Step 6: Record the current pose to allow multiple measurements from the same timestamp
m_poseBuffer.addSample(
timestampSeconds,
new InterpolationRecord(
getEstimatedPosition(), sample.get().gyroAngle, sample.get().wheelPositions));

// Step 7: Replay odometry inputs between sample time and latest recorded sample to update the
// pose buffer and correct odometry.
for (Map.Entry<Double, InterpolationRecord> entry :
m_poseBuffer.getInternalBuffer().tailMap(timestampSeconds).entrySet()) {
updateWithTime(entry.getKey(), entry.getValue().gyroAngle, entry.getValue().wheelPositions);
}
// Step 7: Calculate and record the vision update.
var visionUpdate = new VisionUpdate(visionSample.get().exp(scaledTwist), odometrySample.get());
m_visionUpdates.put(timestampSeconds, visionUpdate);

// Step 8: Remove later vision measurements. (Matches previous behavior)
m_visionUpdates.tailMap(timestampSeconds, false).entrySet().clear();

// Step 9: Update latest pose estimate. Since we cleared all updates after this vision update,
// it's guaranteed to be the latest vision update.
m_poseEstimate = visionUpdate.compensate(m_odometry.getPoseMeters());
}

/**
Expand Down Expand Up @@ -258,83 +315,52 @@ public Pose2d update(Rotation2d gyroAngle, T wheelPositions) {
* @return The estimated pose of the robot in meters.
*/
public Pose2d updateWithTime(double currentTimeSeconds, Rotation2d gyroAngle, T wheelPositions) {
m_odometry.update(gyroAngle, wheelPositions);
m_poseBuffer.addSample(
currentTimeSeconds,
new InterpolationRecord(
getEstimatedPosition(), gyroAngle, m_kinematics.copy(wheelPositions)));
var odometryEstimate = m_odometry.update(gyroAngle, wheelPositions);

m_odometryPoseBuffer.addSample(currentTimeSeconds, odometryEstimate);

if (m_visionUpdates.isEmpty()) {
m_poseEstimate = odometryEstimate;
} else {
var visionUpdate = m_visionUpdates.get(m_visionUpdates.lastKey());
m_poseEstimate = visionUpdate.compensate(odometryEstimate);
}

return getEstimatedPosition();
}

/**
* Represents an odometry record. The record contains the inputs provided as well as the pose that
* was observed based on these inputs, as well as the previous record and its inputs.
* Represents a vision update record. The record contains the vision-compensated pose estimate as
* well as the corresponding odometry pose estimate.
*/
private final class InterpolationRecord implements Interpolatable<InterpolationRecord> {
// The pose observed given the current sensor inputs and the previous pose.
private final Pose2d poseMeters;
private static final class VisionUpdate {
// The vision-compensated pose estimate.
private final Pose2d visionPose;

// The current gyro angle.
private final Rotation2d gyroAngle;

// The current encoder readings.
private final T wheelPositions;
// The pose estimated based solely on odometry.
private final Pose2d odometryPose;

/**
* Constructs an Interpolation Record with the specified parameters.
* Constructs a vision update record with the specified parameters.
*
* @param poseMeters The pose observed given the current sensor inputs and the previous pose.
* @param gyro The current gyro angle.
* @param wheelPositions The current encoder readings.
* @param visionPose The vision-compensated pose estimate.
* @param odometryPose The pose estimate based solely on odometry.
*/
private InterpolationRecord(Pose2d poseMeters, Rotation2d gyro, T wheelPositions) {
this.poseMeters = poseMeters;
this.gyroAngle = gyro;
this.wheelPositions = wheelPositions;
private VisionUpdate(Pose2d visionPose, Pose2d odometryPose) {
this.visionPose = visionPose;
this.odometryPose = odometryPose;
}

/**
* Return the interpolated record. This object is assumed to be the starting position, or lower
* bound.
* Returns the vision-compensated version of the pose. Specifically, changes the pose from being
* relative to this record's odometry pose to being relative to this record's vision pose.
*
* @param endValue The upper bound, or end.
* @param t How far between the lower and upper bound we are. This should be bounded in [0, 1].
* @return The interpolated value.
* @param pose The pose to compensate.
* @return The compensated pose.
*/
@Override
public InterpolationRecord interpolate(InterpolationRecord endValue, double t) {
if (t < 0) {
return this;
} else if (t >= 1) {
return endValue;
} else {
// Find the new wheel distances.
var wheelLerp = m_kinematics.interpolate(wheelPositions, endValue.wheelPositions, t);

// Find the new gyro angle.
var gyroLerp = gyroAngle.interpolate(endValue.gyroAngle, t);

// Create a twist to represent the change based on the interpolated sensor inputs.
Twist2d twist = m_kinematics.toTwist2d(wheelPositions, wheelLerp);
twist.dtheta = gyroLerp.minus(gyroAngle).getRadians();

return new InterpolationRecord(poseMeters.exp(twist), gyroLerp, wheelLerp);
}
}

@Override
public boolean equals(Object obj) {
return this == obj
|| obj instanceof PoseEstimator<?>.InterpolationRecord record
&& Objects.equals(gyroAngle, record.gyroAngle)
&& Objects.equals(wheelPositions, record.wheelPositions)
&& Objects.equals(poseMeters, record.poseMeters);
}

@Override
public int hashCode() {
return Objects.hash(gyroAngle, wheelPositions, poseMeters);
public Pose2d compensate(Pose2d pose) {
var delta = pose.minus(this.odometryPose);
return this.visionPose.plus(delta);
}
}
}
Loading
Loading