Use a -1 non-broadcastable constraint encoding for static shapes #1280

Draft: wants to merge 1 commit into main.
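For context, the diffs below encode each static dimension as a single integer: an exact size when it is statically known, -1 when the dimension is only known to be non-broadcastable (its size is anything but 1), and a value below -1 when nothing is known. The real mapping lives in `aesara.tensor.type.shape_encode` (its diff is not shown on this page); the sketch below is an assumption reconstructed from the checks in these diffs, in particular the `shp > -2` guard in the C code, not the library's actual implementation.

```python
from typing import Optional


def shape_encode_sketch(s: Optional[int]) -> int:
    """Hypothetical encoding: None -> -2 (unconstrained), -1 -> -1
    (non-broadcastable, exact size unknown), n >= 0 -> n (exact size)."""
    return -2 if s is None else int(s)


def check_dim(actual: int, encoded: int) -> bool:
    """Per-dimension check mirroring the JAX, Numba, and Python
    implementations changed below."""
    if encoded == -1:
        return actual != 1  # must not be broadcastable
    if encoded > -1:
        return actual == encoded  # must match exactly
    return True  # below -1: no constraint


assert check_dim(5, -1) and not check_dim(1, -1)
assert check_dim(3, 3) and not check_dim(4, 3)
assert check_dim(1, -2)  # unconstrained accepts anything
```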
21 changes: 15 additions & 6 deletions aesara/link/jax/dispatch/shape.py
@@ -46,12 +46,21 @@ def shape_i(x):
 def jax_funcify_SpecifyShape(op, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        assert jnp.all(x.shape == tuple(shape)), (
-            "got shape",
-            x.shape,
-            "expected",
-            shape,
-        )
+        for s_x, s in zip(x.shape, shape):
+            if s == -1:
+                assert s_x != 1, (
+                    "got shape",
+                    s_x,
+                    "expected",
+                    s,
+                )
+            elif s > -1:
+                assert s_x == s, (
+                    "got shape",
+                    s_x,
+                    "expected",
+                    s,
+                )
         return x

     return specifyshape
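The practical difference is easiest to see on a broadcastable input. A minimal sketch in plain Python, standing in for the generated closure above: the old `jnp.all(x.shape == tuple(shape))` check could only test exact equality, while the new per-dimension loop lets a -1 entry mean "any size except 1".

```python
import numpy as np


def passes(shape, encoded):
    # Same per-dimension rules as the new `specifyshape` closure.
    for s_x, s in zip(shape, encoded):
        if s == -1 and s_x == 1:
            return False  # non-broadcastable constraint violated
        if s > -1 and s_x != s:
            return False  # exact size constraint violated
    return True


assert not passes(np.ones((1, 3)).shape, (-1, 3))  # dim 0 is broadcastable
assert passes(np.ones((5, 3)).shape, (-1, 3))      # any non-1 size is fine
```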
17 changes: 13 additions & 4 deletions aesara/link/numba/dispatch/basic.py
@@ -2,7 +2,7 @@
 import warnings
 from contextlib import contextmanager
 from functools import singledispatch
-from textwrap import dedent
+from textwrap import dedent, indent
 from typing import Union

 import numba
@@ -673,17 +673,26 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

     func_conditions = [
-        f"assert x.shape[{i}] == {shape_input_names}"
-        for i, (shape_input, shape_input_names) in enumerate(
+        dedent(
+            f"""
+            if {shape_name} == -1:
+                assert x.shape[{i}] != 1
+            if {shape_name} > -1:
+                assert x.shape[{i}] == {shape_name}
+            """
+        )
+        for i, (shape_input, shape_name) in enumerate(
             zip(shape_inputs, shape_input_names)
         )
         if shape_input is not NoneConst
     ]

+    conditions_block = "\n".join(func_conditions)
+
     func = dedent(
         f"""
         def specify_shape(x, {create_arg_string(shape_input_names)}):
-            {"; ".join(func_conditions)}
+            {indent(conditions_block, ' ' * 12)}
             return x
         """
     )
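To make the string templating concrete, here is roughly what the generated source would look like for an op with two non-`None` shape inputs (whitespace is approximate; the exact `dedent`/`indent` output is an assumption):

```python
def specify_shape(x, shape_0, shape_1):
    if shape_0 == -1:
        assert x.shape[0] != 1
    if shape_0 > -1:
        assert x.shape[0] == shape_0

    if shape_1 == -1:
        assert x.shape[1] != 1
    if shape_1 > -1:
        assert x.shape[1] == shape_1
    return x
```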
99 changes: 73 additions & 26 deletions aesara/tensor/shape.py
@@ -1,7 +1,7 @@
 import warnings
 from numbers import Number
 from textwrap import dedent
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Sequence, Tuple, Union

 import numpy as np

@@ -17,11 +17,65 @@
 from aesara.tensor import basic as at
 from aesara.tensor import get_vector_length
 from aesara.tensor.exceptions import NotScalarConstantError
-from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
+from aesara.tensor.type import (
+    DenseTensorType,
+    TensorType,
+    int_dtypes,
+    shape_encode,
+    tensor,
+)
 from aesara.tensor.type_other import NoneConst
 from aesara.tensor.var import TensorConstant, TensorVariable


+def filter_shape_vars(
+    ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True
+) -> Tuple[int, ...]:
+    r"""Compute the most "informative" shape based on a static reference.
+
+    Parameters
+    ----------
+    ref_shape
+        A static shape reference using static shape constraint encoding.
+    shape
+        A symbolic shape.
+    shape_is_encoded
+        If ``True``, `shape` is assumed to be static shape constraint encoded.
+
+    Returns
+    -------
+    The most specific static shape constraint encoding that is compatible
+    with `ref_shape`.
+    """
+    shape_bottom = shape_encode(None)
+    type_shape = ()
+    for i, (xts, s) in enumerate(zip(ref_shape, shape)):
+
+        try:
+            # TODO FIXME: We shouldn't need to do this; let a rewrite
+            # do constant folding and update the `TensorType`s.
+            s_val = at.get_scalar_constant_value(s)
+
+            if isinstance(s_val, np.ndarray):
+                s_val = s_val.item()
+
+            if shape_is_encoded or s_val is not None and s_val > 0:
+                type_s = shape_encode(s_val)
+            else:
+                type_s = shape_bottom
+        except NotScalarConstantError:
+            type_s = shape_bottom
+
+        if not (xts <= -1 or type_s <= -1 or type_s == xts):
+            raise AssertionError(
+                f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}."
+            )
+
+        type_shape += (max(type_s, xts),)
+
+    return type_shape
+
+
 def register_shape_c_code(type, code, version=()):
     """
     Tell Shape Op how to generate C code for an Aesara Type.
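A rough usage sketch of the new `filter_shape_vars` (the encoded values assume the `shape_encode` mapping sketched near the top of this page, and the constants are chosen so that `get_scalar_constant_value` succeeds):

```python
import aesara.tensor as at
from aesara.tensor.shape import filter_shape_vars  # added in this diff

# Reference: dim 0 unconstrained, dim 1 non-broadcastable, dim 2 exactly 4.
ref_shape = (-2, -1, 4)
# A symbolic shape made of constants, so constant folding succeeds.
shape = [at.as_tensor_variable(s) for s in (3, 5, 4)]

type_shape = filter_shape_vars(ref_shape, shape)
# -> (3, 5, 4): each entry is the more specific (larger) of the two
#    encodings, i.e. max(-2, 3), max(-1, 5), max(4, 4).

# Incompatible exact sizes (e.g. a reference of 4 against a constant 5)
# raise an AssertionError instead.
```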
@@ -394,7 +448,6 @@ class SpecifyShape(COp):
     _f16_ok = True

     def make_node(self, x, *shape):
-        from aesara.tensor.basic import get_scalar_constant_value

         x = at.as_tensor_variable(x)

@@ -417,18 +470,7 @@ def make_node(self, x, *shape):
                 f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
             )

-        type_shape = [None] * x.ndim
-        for i, (xts, s) in enumerate(zip(x.type.shape, shape)):
-            if xts is not None:
-                type_shape[i] = xts
-            else:
-                try:
-                    type_s = get_scalar_constant_value(s)
-                    if type_s is not None:
-                        type_shape[i] = int(type_s)
-                except NotScalarConstantError:
-                    pass
-
+        type_shape = filter_shape_vars(x.type.shape_encoded, shape)
         out_var = x.type.clone(shape=type_shape)()

         return Apply(self, [x, *shape], [out_var])
@@ -441,10 +483,10 @@ def perform(self, node, inp, out_):
             raise AssertionError(
                 f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
             )
-        if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
-            raise AssertionError(
-                f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
-            )
+        for xs, s in zip(x.shape, shape):
+            if (s == -1 and xs == 1) or (s is not None and s > -1 and xs != s):
+                raise AssertionError(f"SpecifyShape: Got shape {xs}, expected {s}.")
+
         out[0] = x
@@ -454,11 +496,11 @@ def infer_shape(self, fgraph, node, shapes):
         for dim in range(node.inputs[0].type.ndim):
             s = shape[dim]
             try:
-                s = at.get_scalar_constant_value(s)
-                # We assume that `None` shapes are always retrieved by
+                s = shape_encode(at.get_scalar_constant_value(s))
+                # We assume that negative shapes are always retrieved by
                 # `get_scalar_constant_value`, and only in that case do we default to
                 # the shape of the input variable
-                if s is None:
+                if s < 0:
                     s = xshape[dim]
             except NotScalarConstantError:
                 pass
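Restated: after `shape_encode`, only a non-negative value carries a usable exact extent here; both the unknown encoding and the new -1 constraint fall back to the input's inferred shape. A one-line summary of the per-dimension rule (hypothetical helper, for illustration only):

```python
def inferred_dim(s_encoded: int, x_dim):
    # -1 ("not 1") and values below it give no concrete extent.
    return x_dim if s_encoded < 0 else s_encoded
```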
@@ -502,6 +544,9 @@ def c_code(self, node, name, i_names, o_names, sub):
                 );
                 {fail};
             }}
+
+            npy_intp shp;
+            npy_intp actual_shp;
             """
         )
@@ -510,9 +555,11 @@
                 continue
             code += dedent(
                 f"""
-                if (py_{shp_name} != Py_None){{
-                    dtype_{shp_name} shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
-                    if (PyArray_DIMS({x_name})[{i}] != shp) {{
+                shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
+
+                if (shp > -2) {{
+                    actual_shp = PyArray_DIMS({x_name})[{i}];
+                    if ((shp == -1 && actual_shp == 1) || (shp > -1 && actual_shp != shp)) {{
                         PyErr_Format(PyExc_AssertionError,
                             "SpecifyShape: dim %d of input has shape %d, expected %d.",
                             {i}, PyArray_DIMS({x_name})[{i}], shp
@@ -533,7 +580,7 @@
         return code

     def c_code_cache_version(self):
-        return (2,)
+        return (3,)


 _specify_shape = SpecifyShape()