-
Notifications
You must be signed in to change notification settings - Fork 120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rewrite for 1 ** x = 1
#1179
Changes from 1 commit
97bd1c4
d873bff
daf9286
6df788c
836eb45
c4f662a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,8 @@ | |
|
||
import pytensor.scalar.basic as ps | ||
import pytensor.scalar.math as ps_math | ||
from pytensor.graph.basic import Constant, Variable | ||
from pytensor.graph import FunctionGraph | ||
from pytensor.graph.basic import Apply, Constant, Variable | ||
from pytensor.graph.rewriting.basic import ( | ||
NodeRewriter, | ||
PatternNodeRewriter, | ||
|
@@ -1914,6 +1915,33 @@ def local_pow_canonicalize(fgraph, node): | |
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)] | ||
|
||
|
||
@register_canonicalize | ||
@node_rewriter([pt_pow]) | ||
def local_pow_canonicalize_base_1( | ||
fgraph: FunctionGraph, node: Apply | ||
) -> list[TensorVariable] | None: | ||
""" | ||
Replace `1 ** x` with 1, broadcast to the shape of the output. | ||
|
||
Parameters | ||
---------- | ||
fgraph: FunctionGraph | ||
Full function graph being rewritten | ||
node: Apply | ||
Specific node being rewritten | ||
|
||
Returns | ||
------- | ||
rewritten_output: list[TensorVariable] | None | ||
Rewritten output of node, or None if no rewrite is possible | ||
""" | ||
cst = get_underlying_scalar_constant_value( | ||
node.inputs[0], only_process_constants=True, raise_not_constant=False | ||
) | ||
if cst == 1: | ||
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could make an infinite recursion if the You can do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, but does that mean we should also change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines to that rewrite? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. found it yes. Why don't you combine your changes with that rewrite? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check how it looks now? I had to set the dtype to the output as well, not sure if there's a better way |
||
|
||
|
||
@register_specialize | ||
@node_rewriter([mul]) | ||
def local_mul_to_sqr(fgraph, node): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should add docstrings for rewrites, just adds lines to the codebase. Nobody will be calling this function manually
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose I agree w.r.t Parameters and Returns, but there should at least be a small explainer of what the rewrite does.