diff --git a/onnxscript/function_libs/torch_lib/_type_promotion.py b/onnxscript/function_libs/torch_lib/_type_promotion.py
new file mode 100644
index 0000000000..28ab83e64e
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/_type_promotion.py
@@ -0,0 +1,68 @@
+"""Type promotion functions for op implementations."""
+
+from typing import Sequence
+from onnxscript import ir
+
+
+def _get_higher_dtype(a: ir.DataType, b: ir.DataType) -> ir.DataType:
+    """Get the higher dtype of two dtypes."""
+    # Reference: https://github.com/pytorch/pytorch/blob/bdd942efd76e74baa5dd0a262f7c843ddfe2e11b/torch/_prims_common/__init__.py#L1160
+    if a == b:
+        return a
+
+    if a is None:
+        return b
+
+    if b is None:
+        return a
+
+    ordered_datatypes = (
+        (ir.DataType.BOOL,),
+        (ir.DataType.UINT8, ir.DataType.INT8),
+        (ir.DataType.INT16,),
+        (ir.DataType.INT32,),
+        (ir.DataType.INT64,),
+        (ir.DataType.FLOAT16, ir.DataType.BFLOAT16),
+        (ir.DataType.FLOAT,),
+        (ir.DataType.DOUBLE,),
+        (ir.DataType.COMPLEX64,),
+        (ir.DataType.COMPLEX128,),
+    )
+
+    for idx, dtypes in enumerate(ordered_datatypes):
+        if a in dtypes and b in dtypes:
+            return ordered_datatypes[idx + 1][0]
+        if a in dtypes:
+            return b
+        if b in dtypes:
+            return a
+
+    raise ValueError(f"Unexpected data types: {a}, {b}")
+
+
+def promote_types(op, values: Sequence[ir.Value]) -> Sequence[ir.Value]:
+    """Promote the types of the given values."""
+    if not values:
+        return ()
+
+    for value in values:
+        if value.dtype is None:
+            raise ValueError(f"Value {value} does not have dtype information and cannot be promoted.")
+
+    promoted = values[0].dtype
+    assert promoted is not None
+    for value in values[1:]:
+        dtype = value.dtype
+        assert dtype is not None
+        promoted = _get_higher_dtype(promoted, dtype)
+
+    results = []
+    for value in values:
+        if value.dtype != promoted:
+            new_val = op.Cast(value, to=promoted)
+            new_val.dtype = promoted
+            new_val.shape = value.shape
+            results.append(new_val)
+        else:
+            results.append(value)
+
+    return results
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 576aeb17a0..671ada8edd 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -48,6 +48,7 @@
     TTensor2,
     TTensorOrString,
 )
+from onnxscript.function_libs.torch_lib import _type_promotion
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
@@ -159,9 +160,9 @@ def aten_acosh(self: TFloat) -> TFloat:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
+    self, other = _type_promotion.promote_types(op, [self, other])
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
@@ -169,9 +170,10 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
-def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add_complex(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
 
+    self, other = _type_promotion.promote_types(op, [self, other])
     return aten_add(self, other, alpha=alpha)
 
 
@@ -199,33 +201,43 @@ def aten_addbmm(
     return op.Add(scaled_self, op.Mul(reduced_batches, alpha))
 
 
-@torch_op("aten::addcdiv")
-def aten_addcdiv(self: TFloat, tensor1: TFloat, tensor2: TFloat, value: float = 1.0) -> TFloat:
+@torch_op("aten::addcdiv", trace_only=True)
+def aten_addcdiv(
+    self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0
+) -> TensorType:
     """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
     Performs the element-wise division of tensor1 by tensor2,
     multiplies the result by the scalar value and adds it to self.
     """
+    # FIXME(justinchuby): Int to float promotion
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
+    quotient = op.Div(tensor1, tensor2)
+    if value == 1.0:
+        quotient_scaled = quotient
+    else:
+        quotient_scaled = op.Mul(quotient, op.CastLike(value, tensor1))
 
-    return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value))
+    return op.Add(self, quotient_scaled)
 
 
-@torch_op("aten::addcmul")
+@torch_op("aten::addcmul", trace_only=True)
 def aten_addcmul(
-    self: TReal,
-    tensor1: TReal,
-    tensor2: TReal,
-    value: float = 1.0,
-) -> TReal:
+    self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0
+) -> TensorType:
     """addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
     Performs the element-wise multiplication of tensor1 by tensor2,
     multiplies the result by the scalar value and adds it to self.
     """
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
 
     # Follow the order in https://github.com/pytorch/pytorch/blob/29e3fddb082b5a14262a7246bc62381a55199d45/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp#L47
-    # TODO(#811): Understand fp16 accuracy issue
-    return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))
+    if value == 1.0:
+        tensor_1_scaled = tensor1
+    else:
+        tensor_1_scaled = op.Mul(op.CastLike(value, tensor1), tensor1)
+    return op.Add(self, op.Mul(tensor_1_scaled, tensor2))
 
 
 @torch_op("aten::addmm", trace_only=True)
@@ -255,12 +267,13 @@ def aten_addmv(
 
 @torch_op("aten::addr", trace_only=True)
 def aten_addr(
-    self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TensorType, vec1: TensorType, vec2: TensorType, beta: float = 1.0, alpha: float = 1.0
+) -> TensorType:
     """addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
 
     Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
     """
+    self, vec1, vec2 = _type_promotion.promote_types(op, [self, vec1, vec2])
     vec1_shape = op.Constant(value_ints=[-1, 1])
     vec2_shape = op.Constant(value_ints=[1, -1])
     vec1_reshaped = op.Reshape(vec1, vec1_shape)
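
Illustrative sketch (not part of the patch): how the promotion lattice in the new _type_promotion module resolves a few dtype pairs, assuming the module is importable at the path added above and using only enum members it already references. Pairs in the same category step up to the next category; cross-category pairs keep the higher one.

from onnxscript import ir
from onnxscript.function_libs.torch_lib import _type_promotion

pairs = [
    (ir.DataType.INT64, ir.DataType.FLOAT16),     # cross-category -> FLOAT16
    (ir.DataType.FLOAT16, ir.DataType.BFLOAT16),  # same category  -> FLOAT
    (ir.DataType.UINT8, ir.DataType.INT8),        # same category  -> INT16
    (ir.DataType.BOOL, ir.DataType.INT32),        # cross-category -> INT32
]
for a, b in pairs:
    print(a.name, "+", b.name, "->", _type_promotion._get_higher_dtype(a, b).name)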