-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpy2ty.py
135 lines (114 loc) · 5.23 KB
/
py2ty.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def create_tracers(pyval):
    """Return ``pyval`` with its concrete values replaced by tracers.

    Thin entry point: the argument is forwarded unchanged to
    ``replace_concrete_values_with_tracers`` and that function's result is
    returned as-is.
    """
    traced = replace_concrete_values_with_tracers(pyval)
    return traced
import builtins
import numpy
from xdsl.dialects.builtin import i32, i64
from eff_types import i1, i64
from eff_types import si2, si4, si8, si16, si32, si64
from eff_types import ui2, ui4, ui8, ui16, ui32, ui64
from eff_types import f16, f32, f64
from eff_types import complexf32, complexf64
from stablehlo import TensorType
def convert_python_type_to_mlir_type(pytype):
    """Map a Python builtin scalar type to the matching MLIR type object.

    Args:
        pytype: The type object to convert. It is compared with ``==``
            against ``bool``, ``int``, ``float`` and ``complex``, in that
            order.

    Returns:
        ``i1`` for ``bool``, ``i64`` for ``int``, ``f64`` for ``float`` and
        ``complexf64`` for ``complex``.

    Raises:
        ValueError: If ``pytype`` is none of the four supported types.
    """
    # Guard-clause form; each result name is only evaluated on its own branch.
    if pytype == builtins.bool:
        return i1
    if pytype == builtins.int:
        return i64
    if pytype == builtins.float:
        return f64
    if pytype == builtins.complex:
        return complexf64
    raise ValueError(f"Unknown conversion from type {pytype} to MLIR type.")
import unittest
class TestPythonTypeToMLIRType(unittest.TestCase):
    """Checks ``convert_python_type_to_mlir_type`` on the builtin scalar types."""

    def test_python_to_mlir_type(self):
        """Each supported Python type converts to its expected MLIR type."""
        # (python type, expected MLIR type) pairs, one sub-test per pair.
        cases = (
            (bool, i1),
            (int, i64),
            (float, f64),
            (complex, complexf64),
        )
        for arg, exp in cases:
            with self.subTest(arg=arg, exp=exp):
                self.assertEqual(convert_python_type_to_mlir_type(arg), exp)
def convert_numpy_dtype_to_mlir_type(dtp):
    """Map a numpy dtype (or numpy scalar type) to an MLIR element type.

    ``dtp`` is compared with ``==`` against each candidate in turn, so any
    object that compares equal to one of the numpy scalar types below is
    accepted; ``convert_numpy_array_to_mlir_type`` passes an array's
    ``.dtype`` attribute here.

    Args:
        dtp: The dtype-like object to convert.

    Returns:
        The MLIR type object associated with the first matching candidate.

    Raises:
        ValueError: If ``dtp`` matches none of the candidates.
    """
    # Boolean
    if dtp == builtins.bool or dtp == numpy.bool_:
        return i1

    # Integers, narrowest first; signed before unsigned at each width.
    if dtp == numpy.int8 or dtp == numpy.byte:
        return si8
    if dtp == numpy.uint8 or dtp == numpy.ubyte:
        return ui8
    if dtp == numpy.int16 or dtp == numpy.short:
        return si16
    if dtp == numpy.uint16 or dtp == numpy.ushort:
        return ui16
    if dtp == numpy.int32 or dtp == numpy.intc:
        return si32
    if dtp == numpy.uint32 or dtp == numpy.uintc:
        return ui32
    if dtp == numpy.int64:
        # NOTE(review): this is the only signed integer mapped to `i64`
        # rather than an `si*` name; it matches what builtin `int` maps to.
        # Confirm this is intentional.
        return i64
    if dtp == numpy.uint64:
        return ui64

    # Real floating point
    if dtp == numpy.float16 or dtp == numpy.half:
        return f16
    if dtp == numpy.float32 or dtp == numpy.single:
        return f32
    if dtp == numpy.float64 or dtp == numpy.double:
        return f64

    # Complex floating point
    if dtp == numpy.complex64 or dtp == numpy.csingle:
        return complexf32
    if dtp == numpy.complex128 or dtp == numpy.cdouble:
        return complexf64

    raise ValueError(f"Unknown conversion from dtype {dtp} to MLIR type.")
def convert_numpy_array_to_mlir_type(pyval):
    """Build the ``TensorType`` describing a numpy array.

    Args:
        pyval: The array to describe; must be a ``numpy.ndarray`` instance
            (checked with an ``assert``).

    Returns:
        ``TensorType(element_type, shape)`` where ``element_type`` is the
        conversion of ``pyval.dtype`` and ``shape`` is ``pyval.shape``.

    Raises:
        AssertionError: If ``pyval`` is not a ``numpy.ndarray``.
        ValueError: Propagated from ``convert_numpy_dtype_to_mlir_type`` for
            an unsupported dtype.
    """
    assert isinstance(pyval, numpy.ndarray)
    dims = pyval.shape
    elem_ty = convert_numpy_dtype_to_mlir_type(pyval.dtype)
    return TensorType(elem_ty, dims)
class TestConvertNumpyArray(unittest.TestCase):
    """Smoke test for ``convert_numpy_array_to_mlir_type``."""

    def test_convert_numpy_array(self):
        """Converting ``numpy.array(0)`` completes without raising."""
        # No assertion on the result: the test only checks the call succeeds.
        scalar_array = numpy.array(0)
        convert_numpy_array_to_mlir_type(scalar_array)
def get_mlir_type_from_python_value(pyval):
    """Return the MLIR type that describes a single Python value.

    numpy arrays are described by ``convert_numpy_array_to_mlir_type``; every
    other value is described by converting its exact Python type with
    ``convert_python_type_to_mlir_type``.

    Args:
        pyval: The concrete value to describe.

    Returns:
        The MLIR type produced by the selected converter.

    Raises:
        ValueError: Propagated from the converters when the value's type (or
            the array's dtype) has no known MLIR equivalent.
    """
    # Use isinstance rather than an exact-type comparison so ndarray
    # subclasses take the array path too; the array converter itself checks
    # its argument with isinstance, so the two now agree.
    if isinstance(pyval, numpy.ndarray):
        return convert_numpy_array_to_mlir_type(pyval)
    return convert_python_type_to_mlir_type(type(pyval))
class TestGetMLIRTypeFromPythonValue(unittest.TestCase):
    """Checks ``get_mlir_type_from_python_value`` on Python scalar values."""

    def test_python_val_to_mlir_type(self):
        """Each scalar value is described by the expected MLIR type."""
        # (concrete value, expected MLIR type) pairs, one sub-test per pair.
        cases = (
            (False, i1),
            (0, i64),
            (0., f64),
            (0j, complexf64),
        )
        for arg, exp in cases:
            with self.subTest(arg=arg, exp=exp):
                self.assertEqual(get_mlir_type_from_python_value(arg), exp)
def get_mlir_types_from_python_values(flat_pyvals):
    """Return the MLIR type of every value in a flat iterable.

    Args:
        flat_pyvals: Iterable of concrete Python values.

    Returns:
        A list holding ``get_mlir_type_from_python_value(v)`` for each ``v``
        in ``flat_pyvals``, in iteration order.

    Raises:
        ValueError: Propagated from ``get_mlir_type_from_python_value`` for a
            value that has no known MLIR type.
    """
    # Materialise the result instead of returning a lazy one-shot `map`:
    # a list can be traversed more than once and conversion errors surface
    # here, at the call site, rather than wherever the iterator is consumed.
    return [get_mlir_type_from_python_value(pyval) for pyval in flat_pyvals]
from pennylane import pytrees
def replace_concrete_values_with_mlir_types(pyval):
    """Return a copy of ``pyval``'s structure with leaves replaced by MLIR types.

    ``pyval`` is flattened with ``pytrees.flatten``, each flattened value is
    converted with ``get_mlir_types_from_python_values``, and the converted
    values are reassembled with ``pytrees.unflatten`` using the structure
    returned by the flatten step.

    Args:
        pyval: The (possibly nested) value to convert.

    Returns:
        The result of ``pytrees.unflatten`` on the converted leaves.
    """
    leaves, treedef = pytrees.flatten(pyval)
    leaf_types = get_mlir_types_from_python_values(leaves)
    rebuilt = pytrees.unflatten(leaf_types, treedef)
    return rebuilt
class TestReplaceConcreteValuesWithMLIRTypes(unittest.TestCase):
    """Checks ``replace_concrete_values_with_mlir_types`` on small containers."""

    def test_python_unflattened_val_to_mlir_type(self):
        """A list, a tuple, a dict and a bare scalar keep their structure."""
        # (input value, expected structure of MLIR types) pairs.
        cases = (
            ([False], [i1]),
            ((0,), (i64,)),
            ({"a":0.}, {"a":f64}),
            (0j, complexf64),
        )
        for arg, exp in cases:
            with self.subTest(arg=arg, exp=exp):
                self.assertEqual(replace_concrete_values_with_mlir_types(arg), exp)
from xdsl.ir import Block
from pennylane import pytrees
def replace_concrete_values_with_block_arguments(pyval):
    """Return ``pyval``'s structure with leaves replaced by block arguments.

    The value is first converted to a structure of MLIR types, that structure
    is flattened, and a new ``Block`` with no operations is created whose
    argument types are the flattened MLIR types. The block's ``args`` are then
    reassembled into the original structure with ``pytrees.unflatten``.

    Args:
        pyval: The (possibly nested) value to convert.

    Returns:
        The result of ``pytrees.unflatten`` on the new block's arguments.
    """
    typed_tree = replace_concrete_values_with_mlir_types(pyval)
    leaf_types, treedef = pytrees.flatten(typed_tree)
    # One block argument is created per flattened MLIR type.
    holder = Block([], arg_types=leaf_types)
    return pytrees.unflatten(holder.args, treedef)
class TestPythonValueToSSAValue(unittest.TestCase):
    """Checks ``replace_concrete_values_with_block_arguments`` on scalars."""

    def test_python_to_mlir_type(self):
        """A bare scalar is replaced by a single ``BlockArgument``."""
        from xdsl.ir.core import BlockArgument

        values = [False, 0, 0., 0j]
        for arg in values:
            with self.subTest(arg=arg):
                obs = replace_concrete_values_with_block_arguments(arg)
                # Use the unittest assertion rather than a bare `assert`, so
                # the check still runs under `python -O` and reports the
                # offending object on failure.
                self.assertIsInstance(obs, BlockArgument)
from tracer import StableHLOTracer, Tracer
def get_tracer_from_ssavalue(ssavalue):
    """Wrap ``ssavalue`` in a new ``StableHLOTracer`` and return it.

    Args:
        ssavalue: The value handed to the ``StableHLOTracer`` constructor as
            its only argument.
    """
    tracer = StableHLOTracer(ssavalue)
    return tracer
class TestPythonValueToTracer(unittest.TestCase):
    """Checks ``get_tracer_from_ssavalue`` on the arguments of a fresh block."""

    def test_python_value_to_tracer(self):
        """Every block argument is wrapped in a ``Tracer`` instance."""
        block = Block([], arg_types=[i1, i64, f64, complexf64])
        for arg in block.args:
            with self.subTest(arg=arg):
                obs = get_tracer_from_ssavalue(arg)
                # Use the unittest assertion rather than a bare `assert`, so
                # the check still runs under `python -O` and reports the
                # offending object on failure.
                self.assertIsInstance(obs, Tracer)
def replace_concrete_values_with_tracers(pyval):
    """Return ``pyval``'s structure with leaves replaced by tracers.

    The value is first converted to a structure of block arguments, that
    structure is flattened, each flattened value is wrapped with
    ``get_tracer_from_ssavalue``, and the wrapped values are reassembled with
    ``pytrees.unflatten``.

    Args:
        pyval: The (possibly nested) value to convert.

    Returns:
        The result of ``pytrees.unflatten`` on the wrapped values.
    """
    arg_tree = replace_concrete_values_with_block_arguments(pyval)
    leaves, treedef = pytrees.flatten(arg_tree)
    wrapped = [get_tracer_from_ssavalue(leaf) for leaf in leaves]
    return pytrees.unflatten(wrapped, treedef)
# Run the unittest suites defined in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()