Skip to content

Commit

Permalink
Merge pull request #12 from skhrg/extra_utils
Browse files Browse the repository at this point in the history
Extra utils
  • Loading branch information
skhrg authored May 9, 2024
2 parents 84d6b4c + 3c7d047 commit eb4a00d
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 2 deletions.
43 changes: 43 additions & 0 deletions megham/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,49 @@ def get_affine(
return affine, shift


def get_affine_two_stage(
    src: NDArray[np.floating],
    dst: NDArray[np.floating],
    weights: NDArray[np.floating],
) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """
    Get affine transformation between two point clouds with a two stage solver.
    Stage one aligns the clouds with an unweighted SVD fit; stage two fits a
    weighted least squares correction on top of that alignment and the two
    transforms are composed. A final weighted mean shift correction is applied.
    Transformation is dst = affine@src + shift

    Parameters
    ----------
    src : NDArray[np.floating]
        A (npoints, ndim) array of source points.
    dst : NDArray[np.floating]
        A (npoints, ndim) array of destination points.
    weights : NDArray[np.floating]
        (npoints,) array of weights, used by the least squares stage.

    Returns
    -------
    affine : NDArray[np.floating]
        The (ndim, ndim) transformation matrix.
    shift : NDArray[np.floating]
        The (ndim,) shift to apply after transformation.
    """
    # Stage one: unweighted SVD alignment
    coarse_affine, coarse_shift = get_affine(src, dst, force_svd=True)
    aligned = apply_transform(src, coarse_affine, coarse_shift)

    # Stage two: weighted correction on the pre-aligned points
    fine_affine, fine_shift = get_affine(aligned, dst, weights)

    # Collapse both stages into a single transform
    affine, shift = compose_transform(
        coarse_affine, coarse_shift, fine_affine, fine_shift
    )

    # Final correction: remove any residual mean offset
    residual = apply_transform(src, affine, shift)
    shift = shift + get_shift(residual, dst, "mean", weights)

    return affine, shift


def apply_transform(
src: NDArray[np.floating],
transform: NDArray[np.floating],
Expand Down
54 changes: 52 additions & 2 deletions megham/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,69 @@ def estimate_var(
return var


def estimate_spacing(coords: NDArray[np.floating]) -> float:
    """
    Estimate the spacing between points in a point cloud.
    This is just the median distance between nearest neighbors.

    Parameters
    ----------
    coords : NDArray[np.floating]
        The point cloud to estimate spacing of.
        Should have shape (npoint, ndim).

    Returns
    -------
    spacing : float
        The spacing between points.
    """
    dists = make_edm(coords)
    # Mask zeros (self distances, and any exactly coincident points)
    # so they can't be picked as a nearest neighbor
    dists[dists == 0] = np.nan
    closest = np.nanmin(dists, axis=0)

    return np.median(closest)


def gen_weights(
    src: NDArray[np.floating],
    dst: NDArray[np.floating],
    var: Optional[NDArray[np.floating]] = None,
    pdf: bool = False,
) -> NDArray[np.floating]:
    """
    Generate weights between points in two registered point clouds.
    Each weight is the likelihood of the src/dst residual under an
    axis-aligned gaussian. Note that this is not a GMM: since the
    registration is assumed known, each weight comes from a single
    gaussian evaluated at one pair of points.

    Parameters
    ----------
    src : NDArray[np.floating]
        The set of source points to be mapped onto the target points.
        Should have shape (nsrcpoints, ndim).
    dst : NDArray[np.floating]
        The set of destination points to be mapped onto.
        Should have shape (ndstpoints, ndim).
    var : Optional[NDArray[np.floating]], default: None
        The variance along each axis.
        Should have shape (ndim,) if provided.
        If None, will be computed with estimate_var
    pdf : bool, default: False
        If True apply the 1/sqrt(2*pi*var) normalization factor.
        This makes the weights the PDF of a normal distribution.

    Returns
    -------
    weights : NDArray[np.floating]
        (npoints,) array of weights.
    """
    if var is None:
        var = estimate_var(src, dst)

    residual = src - dst
    # Per-axis gaussian likelihood of each residual
    per_axis = np.exp(-0.5 * residual**2 / var)
    if pdf:
        # Normalize each axis so the weight is a proper normal PDF
        per_axis = per_axis / np.sqrt(2 * np.pi * var)

    # Axes are treated as independent, so the joint weight is the product
    return np.prod(per_axis, axis=1)

0 comments on commit eb4a00d

Please sign in to comment.