Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP implement lowering for torchvision::roi_align #8362

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jtorchvision_roi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import torch
import functools
import einops
import jax
import jax.numpy as jnp
from . import ops_registry


@functools.partial(jax.vmap, in_axes=(0, None, None))
def _get_grid_per_box(box: jnp.ndarray, size: int,
sparse: bool) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Obtain a size x size meshgrid inside the given box.

Args:
box: XYXY-format boxes of shape (T, 4).
size: Resolution of the grid.
sparse: Whether to return sparse meshgrid.

Returns:
Two arrays, each has shape (T, size, size) if sparse=False, or
(T, size, 1) and (T, 1, size) if sparse=True.
"""
scale_x = size * 1.0 / (box[2] - box[0])
scale_y = size * 1.0 / (box[3] - box[1])
return jnp.meshgrid( # pytype: disable=bad-return-type # jnp-type
(jnp.arange(size, dtype=box.dtype) + 0.5) / scale_y + box[1],
(jnp.arange(size, dtype=box.dtype) + 0.5) / scale_x + box[0],
indexing="ij",
sparse=sparse)

def _roi_align_einsum(feature: jnp.ndarray, boxes: jnp.ndarray,
output_size: int, sampling_ratio: int) -> jnp.ndarray:
"""An einsum-based implementation of ROIAlign."""
height, width = feature.shape[:2]
grid_y, grid_x = _get_grid_per_box(boxes, output_size * sampling_ratio, True)
grid_y = jnp.squeeze(grid_y, axis=2) # (T, output_size * sampling_ratio)
grid_x = jnp.squeeze(grid_x, axis=1)

def _get_index_and_weights(grid):
"""Computes the 1d index & their weights to be used in interpolation."""
grid -= 0.5 # Coordinates -> Index
x0 = jnp.floor(grid)
x0x1 = jnp.stack([x0, x0 + 1], axis=-1)
# No need to handle out-of-bounds indices here, because jax.nn.one_hot
# ensures that out-of-bounds indices are encoded to all-zero vector.
# This is equivalent to interpolation with zero padding.

x1_weights = grid - x0
x0x1_weights = jnp.stack([1 - x1_weights, x1_weights], axis=-1)
return x0x1, x0x1_weights

def _get_einsum_weights(grid: jnp.ndarray, size: int) -> jnp.ndarray:
"""Combines the 1d index & their interpolation weights to do einsum.

Args:
grid: (T, output_size * sampling_ratio), 1d grid for each box.
size: the input size.

Returns:
A tensor of shape (T, output_size, size), where result[n, i] is a vector
that determines how much every input contributes to the i-th output of
boxes[n].
"""
# Each is (T, output_size * sampling_ratio, 2)
x0x1, x0x1_weights = _get_index_and_weights(grid)
x0x1 = einops.rearrange(
x0x1, "T (o s) two -> T o (s two)", s=sampling_ratio)
x0x1_weights = einops.rearrange(
x0x1_weights, "T (o s) two -> T o (s two) 1", s=sampling_ratio)
# Multiple samples defined by sampling_ratio should be averaged.
x0x1_weights = x0x1_weights / sampling_ratio

# (T, output_size, s*2, size)
x0x1 = jax.nn.one_hot(x0x1, size, dtype=grid.dtype)
# In 1d case, every output value is interpolated from sampling_ratio*2
# input values. So we sum the weights over the s*2 dimension.
return (x0x1 * x0x1_weights).sum(axis=-2)

# Bilinear interpolation can be done by two 1d interpolations.
y_weights = _get_einsum_weights(grid_y, height)
x_weights = _get_einsum_weights(grid_x, width)
return jnp.einsum( # pytype: disable=wrong-arg-types # jnp-type
"HWc,ThH,TwW->Thwc", feature, y_weights, x_weights, optimize=True)


def roi_align(feature: jnp.ndarray, boxes: jnp.ndarray, output_size: int,
sampling_ratio: int) -> jnp.ndarray:
"""ROIAlign operation that crops & resample features within the given boxes.

Args:
feature: feature of shape (H, W, C).
boxes: XYXY boxes of shape (T, 4), boxes to crop from feature.
output_size: Output resolution.
sampling_ratio: Over-sampling ratio of each output value.

Returns:
Output with shape (T, output_size, output_size, C).
"""
if len(feature.shape) != 3:
raise ValueError(f"Expect 3d feature in roi_align! Got {feature.shape}")
if len(boxes.shape) != 2:
raise ValueError(f"Expect 2d boxes in roi_align! Got {boxes.shape}")
return _roi_align_einsum(feature, boxes, output_size, sampling_ratio)

try:
import torch
import torchvision
ops_registry.register_torch_dispatch_op(torch.ops.torchvision.roi_align, roi_align)
except ImportError:
pass
172 changes: 172 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jtorchvision_roi_scratch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import jax
import jax.numpy as jnp
from jax import lax

def bilinear_interpolate(
input, # [N, C, H, W]
roi_batch_ind, # [K]
y, # [K, PH, IY]
x, # [K, PW, IX]
ymask, # [K, IY]
xmask, # [K, IX]
):
"""
Performs bilinear interpolation, respecting the provided masks.

Args:
input: Input feature map (N, C, H, W).
roi_batch_ind: Batch indices for each RoI (K).
y: Vertical sampling coordinates (K, PH, IY).
x: Horizontal sampling coordinates (K, PW, IX).
ymask: Mask for valid y coordinates (K, IY).
xmask: Mask for valid x coordinates (K, IX).

Returns:
Interpolated values (K, C, PH, PW, IY, IX).
"""
_, channels, height, width = input.shape

# Clamp coordinates to be within the feature map boundaries
y = jnp.clip(y, 0)
x = jnp.clip(x, 0)
y_low = jnp.floor(y).astype(int)
x_low = jnp.floor(x).astype(int)
y_high = jnp.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = jnp.where(y_low >= height - 1, height - 1, y_low)
y = jnp.where(y_low >= height - 1, y.astype(input.dtype), y)

x_high = jnp.where(x_low >= width - 1, width - 1, x_low + 1)
x_low = jnp.where(x_low >= width - 1, width - 1, x_low)
x = jnp.where(x_low >= width - 1, x.astype(input.dtype), x)

ly = y - y_low
lx = x - x_low
hy = 1.0 - ly
hx = 1.0 - lx

def masked_index(y, x):
"""Indexes the input tensor, respecting the masks."""
if ymask is not None:
assert xmask is not None
y = jnp.where(ymask[:, None, :], y, 0)
x = jnp.where(xmask[:, None, :], x, 0)
return input[
roi_batch_ind[:, None, None, None, None, None],
jnp.arange(channels)[None, :, None, None, None, None],
y[:, None, :, None, :, None], # prev [K, PH, IY]
x[:, None, None, :, None, :], # prev [K, PW, IX]
] # [K, C, PH, PW, IY, IX]

v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)

def outer_prod(y, x):
return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

w1 = outer_prod(hy, hx)
w2 = outer_prod(hy, lx)
w3 = outer_prod(ly, hx)
w4 = outer_prod(ly, lx)

val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val


def roi_align(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio=-1,
aligned=False,
):
"""
Performs RoI Align operation in JAX.

Args:
input: Input feature map (N, C, H, W).
rois: RoIs (K, 5) with batch index in the first column.
spatial_scale: Spatial scale to map RoI coordinates to input coordinates.
pooled_height: Height of the output RoI.
pooled_width: Width of the output RoI.
sampling_ratio: Number of sampling points.
aligned: Whether to align the RoI coordinates.

Returns:
Pooled RoIs (K, C, pooled_height, pooled_width).
"""
orig_dtype = input.dtype

_, _, height, width = input.shape

ph = jnp.arange(pooled_height) # [PH]
pw = jnp.arange(pooled_width) # [PW]

roi_batch_ind = rois[:, 0].astype(int) # [K]
offset = 0.5 if aligned else 0.0
roi_start_w = rois[:, 1] * spatial_scale - offset # [K]
roi_start_h = rois[:, 2] * spatial_scale - offset # [K]
roi_end_w = rois[:, 3] * spatial_scale - offset # [K]
roi_end_h = rois[:, 4] * spatial_scale - offset # [K]

roi_width = roi_end_w - roi_start_w # [K]
roi_height = roi_end_h - roi_start_h # [K]
if not aligned:
roi_width = jnp.clip(roi_width, a_min=1.0) # [K]
roi_height = jnp.clip(roi_height, a_min=1.0) # [K]

bin_size_h = roi_height / pooled_height # [K]
bin_size_w = roi_width / pooled_width # [K]

exact_sampling = sampling_ratio > 0

roi_bin_grid_h = sampling_ratio if exact_sampling else jnp.ceil(
roi_height / pooled_height) # scalar or [K]
roi_bin_grid_w = sampling_ratio if exact_sampling else jnp.ceil(
roi_width / pooled_width) # scalar or [K]

if exact_sampling:
count = max(roi_bin_grid_h * roi_bin_grid_w, 1) # scalar
iy = jnp.arange(roi_bin_grid_h) # [IY]
ix = jnp.arange(roi_bin_grid_w) # [IX]
ymask = None
xmask = None
else:
count = jnp.clip(roi_bin_grid_h * roi_bin_grid_w, a_min=1) # [K]
iy = jnp.arange(height) # [IY]
ix = jnp.arange(width) # [IX]
ymask = iy[None, :] < roi_bin_grid_h[:, None] # [K, IY]
xmask = ix[None, :] < roi_bin_grid_w[:, None] # [K, IX]

def from_K(t):
return t[:, None, None]

y = (
from_K(roi_start_h)
+ ph[None, :, None] * from_K(bin_size_h)
+ (iy[None, None, :] + 0.5).astype(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
) # [K, PH, IY]
x = (
from_K(roi_start_w)
+ pw[None, :, None] * from_K(bin_size_w)
+ (ix[None, None, :] + 0.5).astype(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
) # [K, PW, IX]
val = bilinear_interpolate(input, roi_batch_ind, y, x, ymask,
xmask) # [K, C, PH, PW, IY, IX]

if not exact_sampling:
val = jnp.where(ymask[:, None, None, None, :, None], val, 0)
val = jnp.where(xmask[:, None, None, None, None, :], val, 0)

output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW]
if isinstance(count, jnp.ndarray):
output = output / count[:, None, None, None]
else:
output = output / count

output = output.astype(orig_dtype)

return output
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def get_as_jax_device(self, device: Any):


def load_ops(self):
from torch_xla2.ops import jaten, jtorch, jc10d, jtorchvision_nms, ops_registry
from torch_xla2.ops import jaten, jtorch, jc10d, jtorchvision_nms, jtorchvision_roi, ops_registry
self._ops.update(ops_registry.all_aten_ops)
self._ops.update(ops_registry.all_torch_functions)

Expand Down
Loading