From b1129336ca9a8c2b661c442449771ec67586239c Mon Sep 17 00:00:00 2001
From: Ivan Ukhov <ivan.ukhov@gmail.com>
Date: Tue, 3 Dec 2024 14:48:39 +0100
Subject: [PATCH] Update gradient accumulation to TensorFlow 2.17

---
 _posts/2024-01-31-gradient-accumulation.md | 73 ++++++++++++----------
 1 file changed, 41 insertions(+), 32 deletions(-)

diff --git a/_posts/2024-01-31-gradient-accumulation.md b/_posts/2024-01-31-gradient-accumulation.md
index 8f10143..8952e53 100644
--- a/_posts/2024-01-31-gradient-accumulation.md
+++ b/_posts/2024-01-31-gradient-accumulation.md
@@ -23,7 +23,7 @@ how it can be implemented in practice.
 
 # Solution
 
-Long story short:
+Long story short, assuming TensorFlow 2.17:
 
 ```python
 # Inherit from any optimizer of choice, such as Adam.
@@ -44,9 +44,13 @@ class Optimizer(tf.keras.optimizers.Adam):
         """
         super().__init__(**options)
         self.accumulation = accumulation
-        self._accumulation = None
         self._gradients = None
 
+    @property
+    def iterations(self) -> int:
+        """Return the number of iterations."""
+        return tf.keras.ops.floor_divide(self._iterations, self.accumulation)
+
     def apply_gradients(
         self, gradients_variables: list[tuple[tf.Tensor, tf.Tensor]]
     ) -> tf.Tensor:
@@ -54,48 +58,53 @@ class Optimizer(tf.keras.optimizers.Adam):
         # Split off the gradients from the trainable variables.
         gradients, variables = zip(*list(gradients_variables))
         # Perform the initialization if needed.
-        with tf.init_scope():
-            self.build(variables)
-        first = self._accumulation % self.accumulation == 0
-        last = (self._accumulation + 1) % self.accumulation == 0
+        if not self.built:
+            with tf.init_scope():
+                self.build(variables)
+        first = self._iterations % self.accumulation == 0
+        last = (self._iterations + 1) % self.accumulation == 0
         # Add the new gradients to the old ones with resetting if needed.
-        for gradient, increment in zip(self._gradients, gradients):
-            gradient.assign(tf.cast(~first, tf.float32) * gradient + increment)
+        for sum, delta in zip(self._gradients, gradients):
+            if delta is not None:
+                sum.assign(tf.cast(~first, tf.float32) * sum + delta)
         # Apply the average accumulated gradients to the trainable variables.
         gradients = [gradient / self.accumulation for gradient in self._gradients]
-        super().apply_gradients(zip(gradients, variables))
-        # Decrement the base counter incremented by the application if needed.
-        self.iterations.assign_sub(tf.cast(~last, tf.int64))
-        # Increment the accumulation counter.
-        self._accumulation.assign_add(1)
-        return self.iterations
-
-    def update_step(self, gradient: tf.Tensor, variable: tf.Tensor) -> None:
+        return super().apply_gradients(zip(gradients, variables))
+
+    def update_step(
+        self,
+        gradient: tf.Tensor,
+        variable: tf.Tensor,
+        learning_rate: any,
+    ) -> None:
         """Update the trainable variable with the gradient."""
         update_step = super().update_step
-        last = (self._accumulation + 1) % self.accumulation == 0
+        last = (self._iterations + 1) % self.accumulation == 0
         # Allow the update to happen only at the end of each cycle.
-        tf.cond(last, lambda: update_step(gradient, variable), lambda: None)
+        true = lambda: update_step(gradient, variable, learning_rate)
+        tf.cond(last, true, lambda: None)
 
     def build(self, variables: list[tf.Tensor]) -> None:
         """Initialize the internal state."""
         super().build(variables)
-        if self._gradients is None:
-            # Create a counter for tracking accumulation.
-            self._accumulation = self.add_variable(shape=(), dtype=tf.int64)
-            # Allocate memory for accumulation.
-            self._gradients = [
-                self.add_variable_from_reference(
-                    model_variable=variable,
-                    variable_name="gradient",
-                )
-                for variable in variables
-            ]
+        # Allocate memory for accumulation.
+        self._gradients = [
+            self.add_variable_from_reference(
+                reference_variable=variable,
+                name="gradient",
+            )
+            for variable in variables
+        ]
 ```
 
-It is important to note that the learning rate is _not_ held constant during
-accumulation. However, since it is not expected to change much from one
-iteration to another, it is an adequate simplification.
+It is important to note that the learning rate keeps on changing (if variable)
+and the weights keep on decaying (if enabled) during accumulation. Therefore,
+one should account for this when configuring the optimizer at hand.
+
+One should also note that Keras does support gradient accumulation, which is
+controlled via the `gradient_accumulation_steps` option of optimizers. However,
+it does not play well with distributed training strategies, which will hopefully
+be rectified in the future.
 
 # Acknowledgments
 
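For reference, here is a minimal usage sketch of the `Optimizer` subclass introduced by the patch above, assuming the class is defined in scope exactly as in the updated post; the model, data, and hyperparameters below are made up purely for illustration:

```python
import tensorflow as tf

# A toy model, purely for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
# Accumulate gradients over 4 batches: the variables are updated every 4th
# step using the average of the accumulated gradients, while the remaining
# options are forwarded to Adam.
model.compile(
    optimizer=Optimizer(accumulation=4, learning_rate=1e-3),
    loss="mse",
)

# Random data, also for illustration; with a batch size of 16 and an
# accumulation of 4, the effective batch size is 64.
x = tf.random.normal((256, 8))
y = tf.random.normal((256, 1))
model.fit(x, y, batch_size=16, epochs=1)
```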
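As for the built-in alternative mentioned in the closing paragraph of the updated post, a sketch of what enabling it might look like in TensorFlow 2.17 is given below, under the assumption, per the caveat above, that no distributed training strategy is involved:

```python
import tensorflow as tf

# Keras's own accumulation: gradients are accumulated over 4 steps, and the
# variables are updated once per cycle.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-3,
    gradient_accumulation_steps=4,
)
```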