Skip to content

Commit

Permalink
Add graph scan support for arbitrary number of input and output dimen…
Browse files Browse the repository at this point in the history
…sions for dense layers

PiperOrigin-RevId: 697590463
Change-Id: Id419b8b3a7288651a142c889d6bd2964c5da97b7
  • Loading branch information
FermiNet Contributor authored and jsspencer committed Dec 8, 2024
1 parent 38df5df commit 9a1deec
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion ferminet/curvature_tags_and_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ def register_qmc(y, x, w, **kwargs):
return kfac_jax.register_dense(y, x, w, variant="qmc", **kwargs)


_dense = kfac_jax.tag_graph_matcher._dense # pylint: disable=protected-access
_dense = functools.partial(
kfac_jax.tag_graph_matcher._dense, # pylint: disable=protected-access
axes=1,
with_reshape=False,
)


_repeated_dense_parameter_extractor = functools.partial(
kfac_jax.tag_graph_matcher._dense_parameter_extractor, # pylint: disable=protected-access
variant="repeated_dense",
Expand Down

0 comments on commit 9a1deec

Please sign in to comment.