Compute pushforward (Rop) via double application of pullback (Lop) and fix Scan and Max gradient bugs #1207

Merged · 5 commits · Feb 17, 2025
1 change: 0 additions & 1 deletion doc/extending/op.rst
@@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`.
the outputs) back to their corresponding shapes and return them as the
output of the :meth:`Op.R_op` method.

:ref:`List of op with r op support <R_op_list>`.
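As a hedged illustration of the shape bookkeeping described above, here is a minimal sketch of an `Op.R_op` method for a hypothetical elementwise-square `Op` (the class and names are illustrative, not part of this PR):

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class Square(Op):
    """Hypothetical elementwise f(x) = x**2, used only to illustrate R_op."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = x**2

    def R_op(self, inputs, eval_points):
        # One eval point per input; None means that input is not perturbed.
        (x,) = inputs
        (v,) = eval_points
        if v is None:
            return [None]
        # The Jacobian of x**2 is diag(2 * x), so the Jacobian-vector product
        # is 2 * x * v, which already has the shape of the output.
        return [2 * x * v]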
76 changes: 0 additions & 76 deletions doc/library/gradient.rst

This file was deleted.

2 changes: 0 additions & 2 deletions doc/library/tensor/basic.rst
@@ -1791,5 +1791,3 @@ Gradient / Differentiation
:members: grad
:noindex:

See the :ref:`gradient <libdoc_gradient>` page for complete documentation
of the gradient module.
21 changes: 16 additions & 5 deletions doc/tutorial/gradients.rst
@@ -86,9 +86,7 @@ of symbolic differentiation).
``i`` of the output list is the gradient of the first argument of
`pt.grad` with respect to the ``i``-th element of the list given as second argument.
The first argument of `pt.grad` has to be a scalar (a tensor
of size 1). For more information on the semantics of the arguments of
`pt.grad` and details about the implementation, see
:ref:`this<libdoc_gradient>` section of the library.
of size 1).

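A minimal hedged example of this list form (`x`, `z`, and `c` are illustrative names, not part of this diff):

>>> x = pt.dscalar('x')
>>> z = pt.dscalar('z')
>>> c = x ** 2 + 3 * z
>>> gx, gz = pt.grad(c, [x, z])  # gx is the gradient w.r.t. x, gz w.r.t. z
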
Additional information on the inner workings of differentiation may also be
found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
@@ -204,7 +202,21 @@ you need to do something similar to this:
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])

:ref:`List <R_op_list>` of Op that implement Rop.
By default, the R-operator is implemented as a double application of the L-operator
(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
In most cases this should be as performant as a specialized implementation of the R-operator.
However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
such as Scan and OpFromGraph, which a direct implementation of the R-operator could more easily avoid.

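As a hedged sketch of that double application, assuming the same `W`, `V`, `x`, and `y` as in the R-operator example above (the dummy vector `u` is purely illustrative):

>>> import numpy as np
>>> u = pt.dvector('u')                         # dummy with the shape of y
>>> g = pytensor.gradient.Lop(y, W, u)          # g = u^T J, linear in u, shaped like W
>>> JV2 = pytensor.gradient.Lop(g, u, V)        # differentiating the linear map in u recovers J V
>>> f2 = pytensor.function([W, V, x], JV2, on_unused_input='ignore')
>>> bool(np.allclose(f2([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0, 1]), [2.0, 2.0]))
True
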
When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.


>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
>>> f = pytensor.function([W, V, x], JV)
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])


L-operator
----------
@@ -234,7 +246,6 @@ array([[ 0., 0.],
as the input parameter, while the result of the R-operator has a shape similar
to that of the output.

:ref:`List of op with r op support <R_op_list>`.
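
A hedged check of this shape relationship, assuming the same `W`, `v`, `V`, `x`, and `y` from the two examples above:

>>> VJ = pytensor.gradient.Lop(y, W, v)
>>> VJ.ndim == W.ndim    # the L-operator result has the shape of the input W
True
>>> JV = pytensor.gradient.Rop(y, W, V)
>>> JV.ndim == y.ndim    # the R-operator result has the shape of the output y
True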

Hessian times a Vector
======================
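The body of this section is not shown in the diff; as a hedged reminder of the standard recipe (all names illustrative), a Hessian-vector product can be formed without building the full Hessian by differentiating the gradient-vector product:

>>> import numpy as np
>>> x = pt.dvector('x')
>>> v = pt.dvector('v')
>>> cost = pt.sum(x ** 2)
>>> gx = pt.grad(cost, x)                # gradient, same shape as x
>>> Hv = pt.grad(pt.sum(gx * v), x)      # H v, again the same shape as x
>>> f = pytensor.function([x, v], Hv)
>>> bool(np.allclose(f([1.0, 2.0], [1.0, 1.0]), [2.0, 2.0]))
True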
13 changes: 12 additions & 1 deletion pytensor/compile/builders.py
@@ -340,6 +340,12 @@ def __init__(
``None``, this will be used as the connection_pattern for this
:class:`Op`.

.. warning::

`rop_overrides` is ignored when `pytensor.gradient.Rop` is called with
`use_op_rop_implementation=False` (the default). In that case the L-operator
is applied twice to obtain a mathematically equivalent R-operator.

strict: bool, default False
If True, an error is raised when any variables needed to compute the inner graph
are not provided as explicit inputs. This can only happen for graphs with
@@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self):
return rop_overrides

eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
fn_rop = partial(
Rop,
wrt=inner_inputs,
eval_points=eval_points,
use_op_rop_implementation=True,
)

callable_args = (inner_inputs, eval_points)
if rop_overrides is None:
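
A hedged usage sketch of the behaviour described in the warning above; the graph, the names, and the `rop_overrides` callable signature `(inputs, eval_points)` are assumptions for illustration, not part of this diff:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
out = pt.exp(x)


def custom_rop(inputs, eval_points):
    # The Jacobian of exp(x) is diag(exp(x)), so J @ v = exp(x) * v.
    (x_,) = inputs
    (v,) = eval_points
    return [pt.exp(x_) * v]


op = OpFromGraph([x], [out], rop_overrides=custom_rop)

z = pt.vector("z")
v = pt.vector("v")

# Default: the override is ignored and the R-operator is built from two L-operators.
jv_default = pytensor.gradient.Rop(op(z), z, v)

# Forcing the specialized path honours custom_rop (and would fail if some Op lacked R_op).
jv_custom = pytensor.gradient.Rop(op(z), z, v, use_op_rop_implementation=True)

f = pytensor.function([z, v], [jv_default, jv_custom])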