Skip to content

Commit

Permalink
Add PyTorch to TVM decomposition for Masked Scatter op (#62)
Browse files Browse the repository at this point in the history
* masked_scatter op support

* masked_scatter op support
  • Loading branch information
ashokkumarkannan1 authored Feb 25, 2025
1 parent 93f9fb6 commit 690fb99
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3409,6 +3409,96 @@ def masked_select(self, inputs, input_types):
indices = self.nonzero([mask], input_types, is_numpy_style=True)
return _op.adv_index([inputs[0]] + [indices[i] for i in range(indices.size)])

def masked_scatter(self, inputs, input_types):
    """Convert ``aten::masked_scatter`` into Relay ops.

    Replaces the elements of ``data`` at positions where ``mask`` is True with
    consecutive elements of ``source`` (taken in flattened order).

    The lowering follows PyTorch's own decomposition of the op:
    https://github.com/pytorch/pytorch/blob/f95bdf5e6c8ea482ba6f64d655513b6a191ac142/torch/_inductor/decomposition.py#L830
    which relies on ``aten::unsafe_masked_index``:
    https://github.com/pytorch/pytorch/blob/f95bdf5e6c8ea482ba6f64d655513b6a191ac142/aten/src/ATen/native/TensorAdvancedIndexing.cpp#L773

    Parameters
    ----------
    inputs : list
        ``[data, mask, source]`` where ``data`` is the tensor to update,
        ``mask`` selects the positions to overwrite, and ``source`` supplies
        the replacement values.
    input_types : list
        Types of the input tensors (unused by this converter).

    Returns
    -------
    relay.Expr
        ``data`` (broadcast against ``mask``) with the masked positions
        replaced by values gathered from ``source``.

    Raises
    ------
    ValueError
        If the static shapes of ``mask`` and ``data`` cannot be broadcast.
    """
    data = inputs[0]
    mask = inputs[1]
    source = inputs[2]

    # NOTE: the number of True entries in ``mask`` is only known at run time, so
    # "source has enough elements" cannot be verified while building the graph
    # (a Python ``assert`` on a symbolic Relay expression does not evaluate it).
    # Gather indices are clamped below instead, so reads from ``source`` always
    # stay in bounds.

    def broadcast_tensors(mask, data):
        # Right-align both static shapes and resolve each dimension pair-wise.
        mask_shape = _infer_shape(mask)
        data_shape = _infer_shape(data)
        shape1 = [1] * (len(data_shape) - len(mask_shape)) + list(mask_shape)
        shape2 = [1] * (len(mask_shape) - len(data_shape)) + list(data_shape)
        common_shape = []
        for dim1, dim2 in zip(shape1, shape2):
            if dim1 == dim2:
                common_shape.append(dim1)
            elif dim1 == 1:
                common_shape.append(dim2)
            elif dim2 == 1:
                common_shape.append(dim1)
            else:
                raise ValueError(f"Cannot broadcast shapes: {shape1} and {shape2}")
        common_shape = tuple(common_shape)
        mask = _op.broadcast_to(mask, common_shape)
        data = _op.broadcast_to(data, common_shape)
        return mask, data

    mask, data = broadcast_tensors(mask, data)
    data_shape = _infer_shape(data)

    # Flatten everything so positions can be addressed with a single index.
    flattened_mask = _op.cast(_op.reshape(mask, newshape=(-1,)), dtype="bool")
    data_flat = _op.reshape(data, newshape=(-1,))
    source_flat = _op.reshape(source, newshape=(-1,))
    source_len = _infer_shape(source_flat)[0]

    # The inclusive running count of True entries, minus one, is the index into
    # ``source`` for each masked position. Integer arithmetic keeps the count
    # exact for large tensors (float32 is only exact up to 2**24).
    running_count = _op.cumsum(
        _op.cast(flattened_mask, dtype="int32"), axis=0, dtype="int32", exclusive=False
    )
    source_idx = _op.subtract(running_count, _expr.const(1, dtype="int32"))

    # Clamp so that unmasked leading positions (index -1) and any overflow past
    # the end of ``source`` still produce a valid gather index; the values read
    # at unmasked positions are discarded by the ``where`` below.
    clamped_indices = _op.minimum(
        _op.maximum(source_idx, _expr.const(0, dtype="int32")),
        _expr.const(source_len - 1, dtype="int32"),
    )

    gathered = _op.transform.take(source_flat, clamped_indices, axis=0, mode="wrap")
    result = _op.where(flattened_mask, gathered, data_flat)

    # Restore the (broadcast) data shape.
    return _op.reshape(result, newshape=data_shape)


def sort(self, inputs, input_types):
data = inputs[0]
dim = inputs[1]
Expand Down Expand Up @@ -4941,6 +5031,7 @@ def create_convert_map(self):
"aten::cumsum": self.cumsum,
"aten::masked_fill": self.masked_fill,
"aten::masked_select": self.masked_select,
"aten::masked_scatter": self.masked_scatter,
"aten::argsort": self.argsort,
"aten::sort": self.sort,
"aten::_unique2": self.unique,
Expand Down

0 comments on commit 690fb99

Please sign in to comment.