Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce algebraic transformations for sum-reduction operations #711

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/ref_transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ Manipulating Instructions

.. autofunction:: add_barrier

Manipulating Reductions
-----------------------

.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction

.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
Comment on lines +86 to +88
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
.. automodule:: loopy.transform.reduction


Registering Library Routines
----------------------------

Expand Down
6 changes: 6 additions & 0 deletions loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@
from loopy.transform.parameter import assume, fix_parameters
from loopy.transform.save import save_and_reload_temporaries
from loopy.transform.add_barrier import add_barrier
from loopy.transform.reduction import (
hoist_invariant_multiplicative_terms_in_sum_reduction,
extract_multiplicative_terms_in_sum_reduction_as_subst)
from loopy.transform.callable import (register_callable,
merge, inline_callable_kernel, rename_callable)
from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call
Expand Down Expand Up @@ -247,6 +250,9 @@

"add_barrier",

"hoist_invariant_multiplicative_terms_in_sum_reduction",
"extract_multiplicative_terms_in_sum_reduction_as_subst",

"register_callable",
"merge",

Expand Down
292 changes: 292 additions & 0 deletions loopy/transform/reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
"""
.. currentmodule:: loopy

.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction

.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
"""

__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import pymbolic.primitives as p

from typing import (FrozenSet, TypeVar, Callable, List, Tuple, Iterable, Union, Any,
Optional, Sequence)
from loopy.symbolic import IdentityMapper, Reduction, CombineMapper
from loopy.kernel import LoopKernel
from loopy.kernel.data import SubstitutionRule
from loopy.diagnostic import LoopyError


# {{{ partition (copied from more-itertools)

Tpart = TypeVar("Tpart")


def partition(pred: Callable[[Tpart], bool],
iterable: Iterable[Tpart]) -> Tuple[List[Tpart],
List[Tpart]]:
"""
Use a predicate to partition entries into false entries and true
entries
"""
# Inspired from https://docs.python.org/3/library/itertools.html
# partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
from itertools import tee, filterfalse
t1, t2 = tee(iterable)
return list(filterfalse(pred, t1)), list(filter(pred, t2))

# }}}


# {{{ hoist_reduction_invariant_terms

class EinsumTermsHoister(IdentityMapper):
"""
Mapper to hoist products out of a sum-reduction.

.. attribute:: reduction_inames

Inames of the reduction expressions to perform the hoisting.
"""
def __init__(self, reduction_inames: FrozenSet[str]):
super().__init__()
self.reduction_inames = reduction_inames

# type-ignore-reason: super-class.map_reduction returns 'Any'
def map_reduction(self, expr: Reduction # type: ignore[override]
) -> p.Expression:
if frozenset(expr.inames) != self.reduction_inames:
return super().map_reduction(expr)

from loopy.library.reduction import SumReductionOperation
from loopy.symbolic import get_dependencies
if isinstance(expr.operation, SumReductionOperation):
if isinstance(expr.expr, p.Product):
from pymbolic.primitives import flattened_product
multiplicative_terms = (flattened_product(self.rec(expr.expr)
.children)
.children)
else:
multiplicative_terms = (expr.expr,)

invariants, variants = partition(lambda x: (get_dependencies(x)
& self.reduction_inames),
multiplicative_terms)
if not variants:
# -> everything is invariant
return self.rec(expr.expr) * Reduction(
expr.operation,
inames=expr.inames,
expr=1, # FIXME: invalid dtype (not sure how?)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe try inferring the dtype?

allow_simultaneous=expr.allow_simultaneous)
if not invariants:
# -> nothing to hoist
return Reduction(
expr.operation,
inames=expr.inames,
expr=self.rec(expr.expr),
allow_simultaneous=expr.allow_simultaneous)

return p.Product(tuple(invariants)) * Reduction(
expr.operation,
inames=expr.inames,
expr=p.Product(tuple(variants)),
allow_simultaneous=expr.allow_simultaneous)
else:
return super().map_reduction(expr)


def hoist_invariant_multiplicative_terms_in_sum_reduction(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name?

  • inverse of "distribute"
  • "out of"

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this work on TranslationUnit?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about other operations that satisfy a distributive property? (Specify classes maybe? We don't need to enumerate things that obey the distributive law. We can't, not comprehensively, anyway.)

kernel: LoopKernel,
reduction_inames: Union[str, FrozenSet[str]],
within: Any = None
) -> LoopKernel:
"""
Hoists loop-invariant multiplicative terms in a sum-reduction expression.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specify how this interacts with reductions with multiple inames.


:arg reduction_inames: The inames over which reduction is performed that defines
the reduction expression that is to be transformed.
:arg within: A match expression understood by :func:`loopy.match.parse_match`
that specifies the instructions over which the transformation is to be
performed.
"""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add example?

out[j] = sum(i, x[i]*y[j])
after hoisting wrt i:
out[j] = y[j]*sum(i, x[i])

from loopy.transform.instruction import map_instructions
if isinstance(reduction_inames, str):
reduction_inames = frozenset([reduction_inames])

if not (reduction_inames <= kernel.all_inames()):
raise ValueError(f"Some inames in '{reduction_inames}' not a part of"
" the kernel.")

term_hoister = EinsumTermsHoister(reduction_inames)

return map_instructions(kernel,
insn_match=within,
f=lambda x: x.with_transformed_expressions(term_hoister)
)

# }}}


# {{{ extract_multiplicative_terms_in_sum_reduction_as_subst

class ContainsSumReduction(CombineMapper):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be done by simply attempting the transformation (without checking first) and aborting if there's ambiguity? (That would make this guy redundant, possibly.)

"""
Returns *True* only if the mapper maps over an expression containing a
SumReduction operation.
"""
def combine(self, values: Iterable[bool]) -> bool:
return any(values)

# type-ignore-reason: super-class.map_reduction returns 'Any'
def map_reduction(self, expr: Reduction) -> bool: # type: ignore[override]
from loopy.library.reduction import SumReductionOperation
return (isinstance(expr.operation, SumReductionOperation)
or self.rec(expr.expr))

def map_variable(self, expr: p.Variable) -> bool:
return False

def map_algebraic_leaf(self, expr: Any) -> bool:
return False


class MultiplicativeTermReplacer(IdentityMapper):
"""
Primary mapper of
:func:`extract_multiplicative_terms_in_sum_reduction_as_subst`.
"""
def __init__(self,
*,
terms_filter: Callable[[p.Expression], bool],
subst_name: str,
subst_arguments: Tuple[str, ...]) -> None:
self.subst_name = subst_name
self.subst_arguments = subst_arguments
self.terms_filter = terms_filter
super().__init__()

# mutable state to record the expression collected by the terms_filter
self.collected_subst_rule: Optional[SubstitutionRule] = None

# type-ignore-reason: super-class.map_reduction returns 'Any'
def map_reduction(self, expr: Reduction) -> Reduction: # type: ignore[override]
from loopy.library.reduction import SumReductionOperation
from loopy.symbolic import SubstitutionMapper
if isinstance(expr.operation, SumReductionOperation):
if self.collected_subst_rule is not None:
# => there was already a sum-reduction operation -> raise
raise ValueError("Multiple sum reduction expressions found -> not"
" allowed.")

if isinstance(expr.expr, p.Product):
from pymbolic.primitives import flattened_product
terms = flattened_product(expr.expr.children).children
else:
terms = (expr.expr,)

unfiltered_terms, filtered_terms = partition(self.terms_filter, terms)
submap = SubstitutionMapper({
argument_expr: p.Variable(f"arg{i}")
for i, argument_expr in enumerate(self.subst_arguments)}.get)
self.collected_subst_rule = SubstitutionRule(
name=self.subst_name,
arguments=tuple(f"arg{i}" for i in range(len(self.subst_arguments))),
expression=submap(p.Product(tuple(filtered_terms))
if filtered_terms
else 1)
)
return Reduction(
expr.operation,
expr.inames,
p.Product((p.Variable(self.subst_name)(*self.subst_arguments),
*unfiltered_terms)),
expr.allow_simultaneous)
else:
return super().map_reduction(expr)


def extract_multiplicative_terms_in_sum_reduction_as_subst(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the outer reduction matter?

kernel: LoopKernel,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apply to TranslationUnit?

within: Any,
subst_name: str,
arguments: Sequence[p.Expression],
terms_filter: Callable[[p.Expression], bool],
Comment on lines +234 to +235
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
arguments: Sequence[p.Expression],
terms_filter: Callable[[p.Expression], bool],
arguments: Sequence[ExpressionT],
terms_filter: Callable[[ExpressionT], bool],

) -> LoopKernel:
"""
Returns a copy of *kernel* with a new substitution named *subst_name* and
*arguments* as arguments for the aggregated multiplicative terms in a
sum-reduction expression.

:arg within: A match expression understood by :func:`loopy.match.parse_match`
to specify the instructions over which the transformation is to be
performed.
:arg terms_filter: A callable to filter which terms of the sum-reduction
comprise the body of substitution rule.
:arg arguments: The sub-expressions of the product of the filtered terms that
form the arguments of the extract substitution rule in the same order.

.. note::

A ``LoopyError`` is raised if none or more than 1 sum-reduction expression
appear in *within*.
"""
from loopy.match import parse_match
within = parse_match(within)

matched_insns = [
insn
for insn in kernel.instructions
if within(kernel, insn) and ContainsSumReduction()((insn.expression,
tuple(insn.predicates)))
]

if len(matched_insns) == 0:
raise LoopyError(f"No instructions found matching '{within}'"
" with sum-reductions found.")
if len(matched_insns) > 1:
raise LoopyError(f"More than one instruction found matching '{within}'"
" with sum-reductions found -> not allowed.")

insn, = matched_insns
replacer = MultiplicativeTermReplacer(subst_name=subst_name,
subst_arguments=tuple(arguments),
terms_filter=terms_filter)
new_insn = insn.with_transformed_expressions(replacer)
new_rule = replacer.collected_subst_rule
new_substitutions = dict(kernel.substitutions).copy()
if subst_name in new_substitutions:
raise LoopyError(f"Kernel '{kernel.name}' already contains a substitution"
" rule named '{subst_name}'.")
assert new_rule is not None
new_substitutions[subst_name] = new_rule

return kernel.copy(instructions=[new_insn if insn.id == new_insn.id else insn
for insn in kernel.instructions],
substitutions=new_substitutions)

# }}}


# vim: foldmethod=marker
42 changes: 42 additions & 0 deletions test/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,48 @@ def test_prefetch_to_same_temp_var(ctx_factory):
lp.auto_test_vs_ref(ref_tunit, ctx, t_unit)


def test_sum_redn_algebraic_transforms(ctx_factory):
from pymbolic import variables
from loopy.symbolic import Reduction

t_unit = lp.make_kernel(
"{[e,i,j,x,r]: 0<=e<N_e and 0<=i,j<35 and 0<=x,r<3}",
"""
y[i] = sum([r,j], J[x, r, e]*D[r,i,j]*u[e,j])
""",
[lp.GlobalArg("J,D,u", dtype=np.float64, shape=lp.auto),
...],
)
knl = t_unit.default_entrypoint

knl = lp.split_reduction_inward(knl, "j")
knl = lp.hoist_invariant_multiplicative_terms_in_sum_reduction(
knl,
reduction_inames="j"
)
knl = lp.extract_multiplicative_terms_in_sum_reduction_as_subst(
knl,
within=None,
subst_name="grad_without_jacobi_subst",
arguments=variables("r i e"),
terms_filter=lambda x: isinstance(x, Reduction)
)

transformed_t_unit = t_unit.with_kernel(knl)
transformed_t_unit = lp.precompute(
transformed_t_unit,
"grad_without_jacobi_subst",
sweep_inames=["r", "i"],
precompute_outer_inames=frozenset({"e"}),
temporary_address_space=lp.AddressSpace.PRIVATE)

x1 = lp.get_op_map(t_unit, subgroup_size=1).eval_and_sum({"N_e": 1})
x2 = lp.get_op_map(transformed_t_unit, subgroup_size=1).eval_and_sum({"N_e": 1})

assert x1 == 33075
assert x2 == 7980 # i.e. this demonstrates a 4.14x reduction in flops


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down