[torchlib] Implement type promotion #2010

Draft. Wants to merge 9 commits into base: main.
68 changes: 68 additions & 0 deletions onnxscript/function_libs/torch_lib/_type_promotion.py
@@ -0,0 +1,68 @@
"""Type promotion functions for op implementations."""

Check warning (Code scanning / lintrunner) RUFF/CPY001: Missing copyright notice at top of file. See https://docs.astral.sh/ruff/rules/missing-copyright-notice
Check warning (Code scanning / lintrunner) RUFF/format: Run lintrunner -a to apply this patch.
Check warning (Code scanning / lintrunner) RUFF-FORMAT/format: Run lintrunner -a to apply this patch.

from typing import Sequence

Check warning (Code scanning / lintrunner) RUFF/I001: Import block is un-sorted or un-formatted. See https://docs.astral.sh/ruff/rules/unsorted-imports
from onnxscript import ir

def _get_higher_dtype(a: ir.DataType, b: ir.DataType) -> ir.DataType:
"""Get the higher dtype of two dtypes."""
# Reference: https://github.com/pytorch/pytorch/blob/bdd942efd76e74baa5dd0a262f7c843ddfe2e11b/torch/_prims_common/__init__.py#L1160
if a == b:
return a

if a is None:
return b

if b is None:
return a

ordered_datatypes = (
(ir.DataType.BOOL,),
(ir.DataType.UINT8, ir.DataType.INT8),
(ir.DataType.INT16,),
(ir.DataType.INT32,),
(ir.DataType.INT64,),
(ir.DataType.FLOAT16, ir.DataType.BFLOAT16),
(ir.DataType.FLOAT,),
(ir.DataType.DOUBLE,),
(ir.DataType.COMPLEX64,),
(ir.DataType.COMPLEX128,),
)

for idx, dtypes in enumerate(ordered_datatypes):
if a in dtypes and b in dtypes:
return ordered_datatypes[idx + 1][0]
if a in dtypes:
return b
if b in dtypes:
return a

raise ValueError(f"Unexpected data types: {a}, {b}")


def promote_types(op, values: Sequence[ir.Value]) -> Sequence[ir.Value]:
"""Promote the types of the given values."""
if not values:
return ()

for value in values:
if value.dtype is None:
raise ValueError(f"Value {value} does not have dtype information and cannot be promoted.")

promoted = values[0].dtype
assert promoted is not None
for value in values[1:]:
dtype = value.dtype
assert dtype is not None
promoted = _get_higher_dtype(promoted, dtype)

results = []
for value in values:
if value.dtype != promoted:
new_val = op.Cast(value, to=promoted)
new_val.dtype = promoted
new_val.shape = value.shape
results.append(new_val)
else:
results.append(value)

return results
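To make the promotion lattice above concrete, here is a small sketch of the results it produces for a few dtype pairs. It assumes this PR's branch is installed so the module is importable; calling the private helper `_get_higher_dtype` directly is for illustration only.

from onnxscript import ir
from onnxscript.function_libs.torch_lib import _type_promotion as tp

# Same category, different dtypes (e.g. UINT8 vs INT8): promote to the first
# dtype of the next category.
assert tp._get_higher_dtype(ir.DataType.UINT8, ir.DataType.INT8) == ir.DataType.INT16
assert tp._get_higher_dtype(ir.DataType.FLOAT16, ir.DataType.BFLOAT16) == ir.DataType.FLOAT

# Different categories: the dtype from the higher category wins.
assert tp._get_higher_dtype(ir.DataType.INT32, ir.DataType.FLOAT16) == ir.DataType.FLOAT16
assert tp._get_higher_dtype(ir.DataType.BOOL, ir.DataType.INT64) == ir.DataType.INT64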
45 changes: 29 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -48,6 +48,7 @@
     TTensor2,
     TTensorOrString,
 )
+from onnxscript.function_libs.torch_lib import _type_promotion
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType

@@ -159,19 +160,20 @@ def aten_acosh(self: TFloat) -> TFloat:


 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
+    self, other = _type_promotion.promote_types(op, [self, other])
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
     return op.Add(self, other)


 @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
-def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add_complex(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""

+    self, other = _type_promotion.promote_types(op, [self, other])
     return aten_add(self, other, alpha=alpha)
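For context, the promotion that promote_types applies in aten_add mirrors PyTorch's own result-type rules: when the operand dtypes differ, the lower-typed input is cast up before the Add. A quick reference check with PyTorch (assuming torch is installed; this snippet is not part of the PR):

import torch

# PyTorch's eager-mode promotion that the exported graph should now match.
a = torch.tensor([1, 2], dtype=torch.int32)
b = torch.tensor([0.5, 1.5], dtype=torch.float16)
print(torch.result_type(a, b))  # torch.float16
print((a + b).dtype)            # torch.float16; the int32 operand is promoted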


@@ -199,33 +201,43 @@ def aten_addbmm(
     return op.Add(scaled_self, op.Mul(reduced_batches, alpha))


-@torch_op("aten::addcdiv")
-def aten_addcdiv(self: TFloat, tensor1: TFloat, tensor2: TFloat, value: float = 1.0) -> TFloat:
+@torch_op("aten::addcdiv", trace_only=True)
+def aten_addcdiv(
+    self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0
+) -> TensorType:
     """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor

     Performs the element-wise division of tensor1 by tensor2, multiplies the result
     by the scalar value and adds it to self.
     """
-    # FIXME(justinchuby): Int to float promotion
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
+    quotient = op.Div(tensor1, tensor2)
+    if value == 1.0:
+        quotient_scaled = quotient
+    else:
+        quotient_scaled = op.Mul(quotient, op.CastLike(value, tensor1))

-    return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value))
+    return op.Add(self, quotient_scaled)


-@torch_op("aten::addcmul")
+@torch_op("aten::addcmul", trace_only=True)
 def aten_addcmul(
-    self: TReal,
-    tensor1: TReal,
-    tensor2: TReal,
-    value: float = 1.0,
-) -> TReal:
+    self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0
+) -> TensorType:
     """addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor

     Performs the element-wise multiplication of tensor1 by tensor2, multiplies the
     result by the scalar value and adds it to self.
     """
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
+
     # Follow the order in https://github.com/pytorch/pytorch/blob/29e3fddb082b5a14262a7246bc62381a55199d45/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp#L47
     # TODO(#811): Understand fp16 accuracy issue
-    return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))
+    if value == 1.0:
+        tensor_1_scaled = tensor1
+    else:
+        tensor_1_scaled = op.Mul(op.CastLike(value, tensor1), tensor1)
+    return op.Add(self, op.Mul(tensor_1_scaled, tensor2))


 @torch_op("aten::addmm", trace_only=True)
@@ -255,12 +267,13 @@ def aten_addmv(

 @torch_op("aten::addr", trace_only=True)
 def aten_addr(
-    self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TensorType, vec1: TensorType, vec2: TensorType, beta: float = 1.0, alpha: float = 1.0
+) -> TensorType:
     """addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor

     Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
     """
+    self, vec1, vec2 = _type_promotion.promote_types(op, [self, vec1, vec2])
     vec1_shape = op.Constant(value_ints=[-1, 1])
     vec2_shape = op.Constant(value_ints=[1, -1])
     vec1_reshaped = op.Reshape(vec1, vec1_shape)