From b1129336ca9a8c2b661c442449771ec67586239c Mon Sep 17 00:00:00 2001
From: Ivan Ukhov <ivan.ukhov@gmail.com>
Date: Tue, 3 Dec 2024 14:48:39 +0100
Subject: [PATCH] Update gradient accumulation to TensorFlow 2.17

---
 _posts/2024-01-31-gradient-accumulation.md | 73 ++++++++++++----------
 1 file changed, 41 insertions(+), 32 deletions(-)

diff --git a/_posts/2024-01-31-gradient-accumulation.md b/_posts/2024-01-31-gradient-accumulation.md
index 8f10143..8952e53 100644
--- a/_posts/2024-01-31-gradient-accumulation.md
+++ b/_posts/2024-01-31-gradient-accumulation.md
@@ -23,7 +23,7 @@ how it can be implemented in practice.
 
 # Solution
 
-Long story short:
+Long story short, assuming TensorFlow 2.17:
 
 ```python
 # Inherit from any optimizer of choice, such as Adam.
@@ -44,9 +44,13 @@ class Optimizer(tf.keras.optimizers.Adam):
         """
         super().__init__(**options)
         self.accumulation = accumulation
-        self._accumulation = None
         self._gradients = None
 
+    @property
+    def iterations(self) -> int:
+        """Return the number of iterations."""
+        return tf.keras.ops.floor_divide(self._iterations, self.accumulation)
+
     def apply_gradients(
         self, gradients_variables: list[tuple[tf.Tensor, tf.Tensor]]
     ) -> tf.Tensor:
@@ -54,48 +58,53 @@ class Optimizer(tf.keras.optimizers.Adam):
         # Split off the gradients from the trainable variables.
         gradients, variables = zip(*list(gradients_variables))
         # Perform the initialization if needed.
-        with tf.init_scope():
-            self.build(variables)
-        first = self._accumulation % self.accumulation == 0
-        last = (self._accumulation + 1) % self.accumulation == 0
+        if not self.built:
+            with tf.init_scope():
+                self.build(variables)
+        first = self._iterations % self.accumulation == 0
+        last = (self._iterations + 1) % self.accumulation == 0
         # Add the new gradients to the old ones with resetting if needed.
-        for gradient, increment in zip(self._gradients, gradients):
-            gradient.assign(tf.cast(~first, tf.float32) * gradient + increment)
+        for sum, delta in zip(self._gradients, gradients):
+            if delta is not None:
+                sum.assign(tf.cast(~first, tf.float32) * sum + delta)
         # Apply the average accumulated gradients to the trainable variables.
         gradients = [gradient / self.accumulation for gradient in self._gradients]
-        super().apply_gradients(zip(gradients, variables))
-        # Decrement the base counter incremented by the application if needed.
-        self.iterations.assign_sub(tf.cast(~last, tf.int64))
-        # Increment the accumulation counter.
-        self._accumulation.assign_add(1)
-        return self.iterations
-
-    def update_step(self, gradient: tf.Tensor, variable: tf.Tensor) -> None:
+        return super().apply_gradients(zip(gradients, variables))
+
+    def update_step(
+        self,
+        gradient: tf.Tensor,
+        variable: tf.Tensor,
+        learning_rate: any,
+    ) -> None:
         """Update the trainable variable with the gradient."""
         update_step = super().update_step
-        last = (self._accumulation + 1) % self.accumulation == 0
+        last = (self._iterations + 1) % self.accumulation == 0
         # Allow the update to happen only at the end of each cycle.
-        tf.cond(last, lambda: update_step(gradient, variable), lambda: None)
+        true = lambda: update_step(gradient, variable, learning_rate)
+        tf.cond(last, true, lambda: None)
 
     def build(self, variables: list[tf.Tensor]) -> None:
         """Initialize the internal state."""
         super().build(variables)
-        if self._gradients is None:
-            # Create a counter for tracking accumulation.
-            self._accumulation = self.add_variable(shape=(), dtype=tf.int64)
-            # Allocate memory for accumulation.
-            self._gradients = [
-                self.add_variable_from_reference(
-                    model_variable=variable,
-                    variable_name="gradient",
-                )
-                for variable in variables
-            ]
+        # Allocate memory for accumulation.
+        self._gradients = [
+            self.add_variable_from_reference(
+                reference_variable=variable,
+                name="gradient",
+            )
+            for variable in variables
+        ]
 ```
 
-It is important to note that the learning rate is _not_ held constant during
-accumulation. However, since it is not expected to change much from one
-iteration to another, it is an adequate simplification.
+It is important to note that the learning rate keeps on changing (if variable)
+and the weights keep on decaying (if enabled) during accumulation. Therefore,
+one should account for this when configuring the optimizer at hand.
+
+One should also note that Keras does support gradient accumulation, which is
+controlled via the `gradient_accumulation_steps` option of optimizers. However,
+it does not play well with distributed training strategies, which will hopefully
+be rectified in the future.
 
 # Acknowledgments