Compute pushforward via double application of pullback
Also fixes bug in Scan L_op and Max R_op

Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
ricardoV94 and aseyboldt committed Feb 14, 2025
1 parent f7927c6 commit 0b64124
Showing 11 changed files with 294 additions and 159 deletions.
1 change: 0 additions & 1 deletion doc/extending/op.rst
@@ -506,4 +506,3 @@ These are the functions required to work with :func:`pytensor.gradient.grad`.
the outputs) back to their corresponding shapes and return them as the
output of the :meth:`Op.R_op` method.

:ref:`List of op with r op support <R_op_list>`.
76 changes: 0 additions & 76 deletions doc/library/gradient.rst

This file was deleted.

2 changes: 0 additions & 2 deletions doc/library/tensor/basic.rst
@@ -1791,5 +1791,3 @@ Gradient / Differentiation
:members: grad
:noindex:

See the :ref:`gradient <libdoc_gradient>` page for complete documentation
of the gradient module.
21 changes: 16 additions & 5 deletions doc/tutorial/gradients.rst
@@ -86,9 +86,7 @@ of symbolic differentiation).
``i`` of the output list is the gradient of the first argument of
`pt.grad` with respect to the ``i``-th element of the list given as second argument.
The first argument of `pt.grad` has to be a scalar (a tensor
of size 1). For more information on the semantics of the arguments of
`pt.grad` and details about the implementation, see
:ref:`this<libdoc_gradient>` section of the library.
of size 1).
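
For instance, a minimal sketch of the list form (assuming the usual ``import pytensor.tensor as pt``; the variable names are illustrative):

>>> x = pt.scalar("x")
>>> w = pt.scalar("w")
>>> cost = x * w
>>> g_x, g_w = pt.grad(cost, [x, w])  # one gradient per entry of the second argument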

Additional information on the inner workings of differentiation may also be
found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
@@ -204,7 +202,21 @@ you need to do something similar to this:
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])

:ref:`List <R_op_list>` of Op that implement Rop.
By default, the R-operator is implemented as a double application of the L_operator
(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
In most cases this should be as performant as a specialized implementation of the R-operator.
However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
such as Scan and OpFromGraph, that would be more easily avoided in a direct implementation of the R-operator.
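Spelled out by hand, this double application looks like the following sketch, reusing ``W``, ``V``, ``x`` and ``y`` from the example above (the dummy variable ``u`` is introduced here purely for illustration):

>>> u = y.type("u")                              # dummy variable with the type of y
>>> g = pytensor.gradient.Lop(y, W, u)           # first pullback: g = J.T @ u, linear in u
>>> JV_manual = pytensor.gradient.Lop(g, u, V)   # second pullback: (J.T).T @ V = J @ V
>>> f = pytensor.function([W, V, x], JV_manual)
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0, 1])
array([ 2., 2.])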

When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method.


>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
>>> f = pytensor.function([W, V, x], JV)
>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
array([ 2., 2.])


L-operator
----------
@@ -234,7 +246,6 @@ array([[ 0., 0.],
as the input parameter, while the result of the R-operator has a shape similar
to that of the output.

:ref:`List of op with r op support <R_op_list>`.

Hessian times a Vector
======================
13 changes: 12 additions & 1 deletion pytensor/compile/builders.py
@@ -340,6 +340,12 @@ def __init__(
``None``, this will be used as the connection_pattern for this
:class:`Op`.
.. warning::
    ``rop_overrides`` is ignored when `pytensor.gradient.Rop` is called with
    `use_op_rop_implementation=False` (the default). In that case the L_op
    is applied twice to obtain a mathematically equivalent R_op.
strict: bool, default False
If true, it raises when any variables needed to compute the inner graph
are not provided as explicit inputs. This can only happen for graphs with
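
A sketch of the behaviour described in the warning above (assuming ``import pytensor`` and ``import pytensor.tensor as pt``; the names ``rop_ov`` and ``double`` are illustrative, and the callable is assumed to receive ``(inputs, eval_points)`` and return the output eval points, as the code further below suggests):

>>> from pytensor.compile.builders import OpFromGraph
>>> x = pt.vector("x")
>>> def rop_ov(inputs, eval_points):
...     # custom pushforward for y = 2 * x: J @ v = 2 * v
...     return [2.0 * eval_points[0]]
>>> double = OpFromGraph([x], [2.0 * x], rop_overrides=rop_ov)
>>> inp, v = pt.vector("inp"), pt.vector("v")
>>> out = double(inp)
>>> jv_default = pytensor.gradient.Rop(out, inp, v)  # double-Lop path: rop_ov is ignored
>>> jv_custom = pytensor.gradient.Rop(out, inp, v, use_op_rop_implementation=True)  # uses rop_ov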
@@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self):
return rop_overrides

eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
fn_rop = partial(
Rop,
wrt=inner_inputs,
eval_points=eval_points,
use_op_rop_implementation=True,
)

callable_args = (inner_inputs, eval_points)
if rop_overrides is None: