From 9a1deec8222c1838c5aafb65830fc41372c48677 Mon Sep 17 00:00:00 2001 From: FermiNet Contributor Date: Mon, 18 Nov 2024 13:20:42 +0000 Subject: [PATCH] Add graph scan support for arbitrary number of input and output dimensions for dense layers PiperOrigin-RevId: 697590463 Change-Id: Id419b8b3a7288651a142c889d6bd2964c5da97b7 --- ferminet/curvature_tags_and_blocks.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 46881b7..057bd55 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -38,7 +38,13 @@ def register_qmc(y, x, w, **kwargs): return kfac_jax.register_dense(y, x, w, variant="qmc", **kwargs) -_dense = kfac_jax.tag_graph_matcher._dense # pylint: disable=protected-access +_dense = functools.partial( + kfac_jax.tag_graph_matcher._dense, # pylint: disable=protected-access + axes=1, + with_reshape=False, +) + + _repeated_dense_parameter_extractor = functools.partial( kfac_jax.tag_graph_matcher._dense_parameter_extractor, # pylint: disable=protected-access variant="repeated_dense",