From ecbb6d0bd88551b8a16890c96ee1cd928cba7d62 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 14 Jan 2025 16:05:51 -0800
Subject: [PATCH 1/7] [torchlib] Implement type promotion

---
 .../torch_lib/_type_promotion.py              | 65 +++++++++++++++++++
 1 file changed, 65 insertions(+)
 create mode 100644 onnxscript/function_libs/torch_lib/_type_promotion.py

diff --git a/onnxscript/function_libs/torch_lib/_type_promotion.py b/onnxscript/function_libs/torch_lib/_type_promotion.py
new file mode 100644
index 000000000..239617c13
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/_type_promotion.py
@@ -0,0 +1,65 @@
+"""Type promotion functions for op implementations."""
+
+from typing import Sequence
+from onnxscript import ir
+
+def _get_higher_dtype(a: ir.DataType, b: ir.DataType) -> ir.DataType:
+    """Get the higher dtype of two dtypes."""
+    # Reference: https://github.com/pytorch/pytorch/blob/bdd942efd76e74baa5dd0a262f7c843ddfe2e11b/torch/_prims_common/__init__.py#L1160
+    if a == b:
+        return a
+
+    if a is None:
+        return b
+
+    if b is None:
+        return a
+
+    ordered_datatypes = (
+        (ir.DataType.BOOL,),
+        (ir.DataType.UINT8, ir.DataType.INT8),
+        (ir.DataType.INT16,),
+        (ir.DataType.INT32,),
+        (ir.DataType.INT64,),
+        (ir.DataType.FLOAT16, ir.DataType.BFLOAT16),
+        (ir.DataType.FLOAT,),
+        (ir.DataType.DOUBLE,),
+        (ir.DataType.COMPLEX64,),
+        (ir.DataType.COMPLEX128,),
+    )
+
+    for idx, dtypes in enumerate(ordered_datatypes):
+        if a in dtypes and b in dtypes:
+            return ordered_datatypes[idx + 1][0]
+        if a in dtypes:
+            return b
+        if b in dtypes:
+            return a
+
+    raise ValueError(f"Unexpected data types: {a}, {b}")
+
+
+def promote_types(op, values: Sequence[ir.Value]) -> Sequence[ir.Value]:
+    """Promote the types of the given values."""
+    if not values:
+        return ()
+
+    for value in values:
+        if value.dtype is None:
+            raise ValueError(f"Value {value} does not have dtype information and cannot be promoted.")
+
+    promoted = values[0].dtype
+    assert promoted is not None
+    for value in values[1:]:
+        dtype = value.dtype
+        assert dtype is not None
+        promoted = _get_higher_dtype(promoted, dtype)
+
+    results = []
+    for value in values:
+        if value.dtype != promoted:
+            results.append(op.Cast(value, to=promoted))
+        else:
+            results.append(value)
+
+    return results
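A minimal, runnable sketch (not part of the patch) of how the promotion ladder in
_get_higher_dtype behaves, assuming only that the module added above is importable.
Two dtypes in the same tier promote to the first dtype of the next tier, mirroring
the PyTorch reference linked in the code:

    from onnxscript import ir
    from onnxscript.function_libs.torch_lib import _type_promotion

    get = _type_promotion._get_higher_dtype
    # Mixed tiers: the higher tier wins.
    assert get(ir.DataType.BOOL, ir.DataType.INT32) == ir.DataType.INT32
    assert get(ir.DataType.INT64, ir.DataType.FLOAT) == ir.DataType.FLOAT
    # Same tier, different dtypes: promote to the next tier's first entry,
    # so uint8 + int8 -> int16 and float16 + bfloat16 -> float32.
    assert get(ir.DataType.UINT8, ir.DataType.INT8) == ir.DataType.INT16
    assert get(ir.DataType.FLOAT16, ir.DataType.BFLOAT16) == ir.DataType.FLOAT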
From b507bf52a0f8eb127b9eb986e7b1fc8696c52698 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 14 Jan 2025 16:26:27 -0800
Subject: [PATCH 2/7] wip

---
 .../function_libs/torch_lib/_type_promotion.py |  5 ++++-
 onnxscript/function_libs/torch_lib/ops/core.py | 16 +++++++++++-----
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/_type_promotion.py b/onnxscript/function_libs/torch_lib/_type_promotion.py
index 239617c13..28ab83e64 100644
--- a/onnxscript/function_libs/torch_lib/_type_promotion.py
+++ b/onnxscript/function_libs/torch_lib/_type_promotion.py
@@ -58,7 +58,10 @@ def promote_types(op, values: Sequence[ir.Value]) -> Sequence[ir.Value]:
     results = []
     for value in values:
         if value.dtype != promoted:
-            results.append(op.Cast(value, to=promoted))
+            new_val = op.Cast(value, to=promoted)
+            new_val.dtype = promoted
+            new_val.shape = value.shape
+            results.append(new_val)
         else:
             results.append(value)
 
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index a1793858e..97cfd315b 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -48,6 +48,7 @@
     TTensor2,
     TTensorOrString,
 )
+from onnxscript.function_libs.torch_lib import _type_promotion
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
@@ -160,9 +161,9 @@ def aten_acosh(self: TFloat) -> TFloat:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add(self: TTensor, other: TTensor2, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
+    self, other = _type_promotion.promote_types(op, [self, other])
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
@@ -175,6 +176,7 @@ def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
 
+    # TODO(justinchuby): Type promotion for complex numbers
     return aten_add(self, other, alpha=alpha)
 
 
@@ -203,12 +205,14 @@ def aten_addbmm(
     return op.Add(scaled_self, op.Mul(reduced_batches, alpha))
 
 
 @torch_op("aten::addcdiv")
-def aten_addcdiv(self: TFloat, tensor1: TFloat, tensor2: TFloat, value: float = 1.0) -> TFloat:
+def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0) -> TensorType:
     """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
     Performs the element-wise division of tensor1 by tensor2, multiplies the result
     by the scalar value and adds it to self.
     """
+    # FIXME: Int to float
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
     return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value))
 
 
@@ -225,6 +229,7 @@ def aten_addcmul(
     Performs the element-wise multiplication of tensor1 by tensor2, multiplies the
     result by the scalar value and adds it to self.
     """
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
 
     # Follow the order in https://github.com/pytorch/pytorch/blob/29e3fddb082b5a14262a7246bc62381a55199d45/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp#L47
     # TODO(#811): Understand fp16 accuracy issue
     return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))
@@ -258,12 +263,13 @@ def aten_addmv(
 
 
 @torch_op("aten::addr", traceable=True)
 def aten_addr(
-    self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TensorType, vec1: TensorType, vec2: TensorType, beta: float = 1.0, alpha: float = 1.0
+) -> TensorType:
     """addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
 
     Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
     """
+    self, vec1, vec2 = _type_promotion.promote_types(op, [self, vec1, vec2])
     vec1_shape = op.Constant(value_ints=[-1, 1])
     vec2_shape = op.Constant(value_ints=[1, -1])
     vec1_reshaped = op.Reshape(vec1, vec1_shape)
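A runnable sketch (not part of the patch) of what the promote_types call now inside
aten_add computes: the function folds _get_higher_dtype over all input dtypes, then
casts any value whose dtype differs from the result. For add(self, other) with INT64
and FLOAT inputs the common dtype is FLOAT, so only the INT64 side receives a Cast
node and the Add runs in FLOAT:

    import functools
    from onnxscript import ir
    from onnxscript.function_libs.torch_lib import _type_promotion

    dtypes = [ir.DataType.INT64, ir.DataType.FLOAT]
    common = functools.reduce(_type_promotion._get_higher_dtype, dtypes)
    assert common == ir.DataType.FLOAT  # only the INT64 input would be cast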
""" + self, vec1, vec2 = _type_promotion.promote_types(op, [self, vec1, vec2]) vec1_shape = op.Constant(value_ints=[-1, 1]) vec2_shape = op.Constant(value_ints=[1, -1]) vec1_reshaped = op.Reshape(vec1, vec1_shape) From 1033ac50c1ce9a22991bd76dd840da34cd0ee310 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 15 Jan 2025 12:48:14 -0800 Subject: [PATCH 3/7] , trace_only=True --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 97cfd315b..afebde5ad 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -204,7 +204,7 @@ def aten_addbmm( return op.Add(scaled_self, op.Mul(reduced_batches, alpha)) -@torch_op("aten::addcdiv") +@torch_op("aten::addcdiv", trace_only=True) def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0) -> TensorType: """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor @@ -217,7 +217,7 @@ def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, val return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value)) -@torch_op("aten::addcmul") +@torch_op("aten::addcmul", trace_only=True) def aten_addcmul( self: TReal, tensor1: TReal, @@ -261,7 +261,7 @@ def aten_addmv( return op.Add(op.Mul(self, beta), op.Mul(op.MatMul(mat, vec), alpha)) -@torch_op("aten::addr", traceable=True) +@torch_op("aten::addr", trace_only=True) def aten_addr( self: TensorType, vec1: TensorType, vec2: TensorType, beta: float = 1.0, alpha: float = 1.0 ) -> TensorType: From eeeae08698ff72e501253df737293e3d1ee0fe53 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 15 Jan 2025 12:48:37 -0800 Subject: [PATCH 4/7] format --- onnxscript/function_libs/torch_lib/ops/core.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index afebde5ad..f3130e82c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -205,7 +205,9 @@ def aten_addbmm( @torch_op("aten::addcdiv", trace_only=True) -def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0) -> TensorType: +def aten_addcdiv( + self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0 +) -> TensorType: """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor Performs the element-wise division of tensor1 by tensor2, multiplies the result @@ -218,12 +220,7 @@ def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, val @torch_op("aten::addcmul", trace_only=True) -def aten_addcmul( - self: TReal, - tensor1: TReal, - tensor2: TReal, - value: float = 1.0, -) -> TReal: +def aten_addcmul(self: TReal, tensor1: TReal, tensor2: TReal, value: float = 1.0) -> TReal: """addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor Performs the element-wise multiplication of tensor1 by tensor2, multiplies the From c6e300c00f943b71cae5f64e0878941bb9fecff9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 15 Jan 2025 12:48:58 -0800 Subject: [PATCH 5/7] TensorType --- onnxscript/function_libs/torch_lib/ops/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py 
From fd915c1495600fb68f6a0f840b03b8a08ace03be Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 15 Jan 2025 12:53:23 -0800
Subject: [PATCH 6/7] aten_addcdiv

---
 onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 14aa2dbd8..90f2d719d 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -213,10 +213,15 @@ def aten_addcdiv(
     Performs the element-wise division of tensor1 by tensor2, multiplies the result
     by the scalar value and adds it to self.
     """
-    # FIXME: Int to float
+    # FIXME(justinchuby): Int to float promotion
     self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
+    quotient = op.Div(tensor1, tensor2)
+    if value == 1.0:
+        quotient_scaled = quotient
+    else:
+        quotient_scaled = op.Mul(quotient, op.CastLike(value, tensor1))
 
-    return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value))
+    return op.Add(self, quotient_scaled)
 
 
 @torch_op("aten::addcmul", trace_only=True)
@@ -231,8 +236,11 @@ def aten_addcmul(
     self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
 
     # Follow the order in https://github.com/pytorch/pytorch/blob/29e3fddb082b5a14262a7246bc62381a55199d45/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp#L47
-    # TODO(#811): Understand fp16 accuracy issue
-    return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))
+    if value == 1.0:
+        tensor_1_scaled = tensor1
+    else:
+        tensor_1_scaled = op.Mul(op.CastLike(value, tensor1), tensor1)
+    return op.Add(self, op.Mul(tensor_1_scaled, tensor2))
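A runnable sketch (not part of the patch) of the scalar fast path this rework
introduces. The helper below only mirrors the branch structure of aten_addcmul;
the strings stand in for the ONNX nodes that op.CastLike/op.Mul/op.Add would emit:

    def addcmul_plan(value: float) -> str:
        # When value == 1.0 the scalar Mul (and its CastLike) is skipped entirely.
        if value == 1.0:
            return "Add(self, Mul(tensor1, tensor2))"
        # Otherwise the Python float is cast to tensor1's dtype before scaling,
        # so the scalar multiply runs in the promoted dtype.
        return "Add(self, Mul(Mul(CastLike(value, tensor1), tensor1), tensor2))"

    assert addcmul_plan(1.0) == "Add(self, Mul(tensor1, tensor2))"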
From 5c45f70fa826fdc80006685907a332cceb1d283c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 28 Jan 2025 16:50:12 -0800
Subject: [PATCH 7/7] TensorType

---
 onnxscript/function_libs/torch_lib/ops/core.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 5d67b6736..55731600b 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -161,7 +161,7 @@ def aten_acosh(self: TFloat) -> TFloat:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TTensor, other: TTensor2, alpha: float = 1.0) -> TensorType:
+def aten_add(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     self, other = _type_promotion.promote_types(op, [self, other])
     if alpha != 1.0:
@@ -171,10 +171,10 @@ def aten_add(self: TensorType, other: TensorType, alpha: float = 1.0) -> Tensor
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
-def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add_complex(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
 
-    # TODO(justinchuby): Type promotion for complex numbers
+    self, other = _type_promotion.promote_types(op, [self, other])
     return aten_add(self, other, alpha=alpha)
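A minimal sketch (not part of the series) of the pattern these patches establish
for trace-only binary ops. aten_sub_sketch is a hypothetical example, not a
function the series adds; only add, add_complex, addcdiv, addcmul, and addr are
converted here:

    from onnxscript.function_libs.torch_lib import _type_promotion

    def aten_sub_sketch(op, self, other):
        # Promote both inputs to their common dtype (inserting Cast nodes as
        # needed), then build the arithmetic in the promoted dtype.
        self, other = _type_promotion.promote_types(op, [self, other])
        return op.Sub(self, other)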