Remove obsolete only_ready_reduce_and_rescale_grads, fix some bugs
Waino committed Apr 22, 2024
1 parent 72be768 commit e68e67a
Showing 2 changed files with 3 additions and 81 deletions.
4 changes: 2 additions & 2 deletions mammoth/distributed/__init__.py
@@ -4,7 +4,7 @@
batch_producer,
consumer,
broadcast_tensors,
- only_ready_reduce_and_rescale_grads,
+ managed_reduce_and_rescale_grads,
ErrorHandler,
)
from mammoth.distributed.contexts import DeviceContext, WorldContext, DeviceContextEnum
@@ -20,7 +20,7 @@
"batch_producer",
"broadcast_tensors",
"consumer",
"only_ready_reduce_and_rescale_grads",
"managed_reduce_and_rescale_grads",
"ErrorHandler",
"DeviceContext",
"WorldContext",
80 changes: 1 addition & 79 deletions mammoth/distributed/communication.py
@@ -71,7 +71,7 @@ def managed_reduce_and_rescale_grads(
if p.grad is None or not has_local_gradient:
p.grad = torch.zeros_like(p)

- grads = [p.grad.data for p in require_grad]
+ grads = [p.grad.data for name, p in require_grad]

# All devices communicate either a real gradient or a dummy zeros of the same size
all_reduce_and_rescale_tensors(grads, rescale_denom=gradient_norm, group=group)
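
The one-line fix in this hunk unpacks the (name, parameter) tuples stored in require_grad; with the old expression, p would be bound to the whole tuple, so p.grad would raise an AttributeError. A minimal standalone sketch of the fixed expression, assuming require_grad is built from named_parameters in the same way as in the removed function below (the names used here are hypothetical):

import torch

# Hypothetical minimal reproduction of the fixed list comprehension.
named_parameters = [("embedding.weight", torch.nn.Parameter(torch.randn(4, 8)))]
require_grad = [(name, p) for (name, p) in named_parameters if p.requires_grad]

for name, p in require_grad:
    if p.grad is None:
        p.grad = torch.zeros_like(p)

# Old: grads = [p.grad.data for p in require_grad]
# fails, because p is bound to the ("embedding.weight", Parameter) tuple, not the Parameter.
grads = [p.grad.data for name, p in require_grad]  # fixed: unpack the tuple
print([tuple(g.shape) for g in grads])  # [(4, 8)]
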
@@ -80,84 +80,6 @@
# the optimizer can not use it to prevent the untrained components from being stepped


def only_ready_reduce_and_rescale_grads(named_parameters, group=None):
"""
(obsolete)
Gradient synch tolerant to missing grads.
Missing grads occur when some parameters are not trained between two
gradient synchs, e.g. the embeddings of a low-resource language with low
sampling weight.
The algorithm first uses the 'has_grad' attribute set by the forward hook
'has_grad_hook'. This hook ensures that all parameters of the modules
selected for use during the current training computation have 'has_grad'
set to True. This gives the list of parameters that have been trained on
this device ("ready").
A bit mask covering the parameters that are ready on this device is
communicated to the other devices in the group. The bit masks are reduced
using summation. The sum gives the number of real gradients for that
parameter, and can be used for normalization.
If a parameter is ready on any device, all devices communicate a value.
Devices on which the parameter is ready communicate the actual gradient,
while devices on which it is not ready communicate a dummy zero tensor
instead. The sum computed previously is used for normalization.
Args:
named_parameters: tuples of (str, Parameter) defining the parameters to consider
group: torch.distributed communication group
"""
# Set missing gradients to zero, keeping track of true gradients
require_grad = [(name, p) for (name, p) in named_parameters if p.requires_grad]
if not require_grad:
# Exit early if the component has no parameters that require a gradient
return
device = require_grad[0][1].device
ready_list = []
for name, p in require_grad:
if hasattr(p, 'has_grad') and p.has_grad:
ready_list.append(1.0)
else:
ready_list.append(0.0)
if p.grad is None:
p.grad = torch.zeros_like(p)

# Communicate the ready bits, and reduce them using summation.
# This gives the number of non-dummy gradients participating, for normalization
ready_t = torch.tensor(ready_list).to(device)
if group is None:
torch.distributed.all_reduce(ready_t)
else:
torch.distributed.all_reduce(ready_t, group=group)
rescale_denoms = ready_t # after reduction

# Omit if all nodes sent a zero ready bit
denoms_mask = (rescale_denoms > 0).cpu()
params_with_grad = [p for ((name, p), m) in zip(require_grad, denoms_mask) if m]
grads = [p.grad.data for p in params_with_grad]
rescale_denoms = [denom for (denom, m) in zip(rescale_denoms, denoms_mask) if m]
assert len(grads) == len(rescale_denoms)
if len(grads) == 0:
return

# If not, then set has_grad also on devices that did not train the parameter themselves.
# They now have a grad that they received from the other devices.
for name, p in require_grad:
p.has_grad = True

# All devices communicate either a real gradient or a dummy zeros of the same size
# Can not use rescale_denom, as each grad may have its own denominator
all_reduce_and_rescale_tensors(grads, rescale_denom=1, group=group)

# Normalize using the previously computed values
for grad, denom in zip(grads, rescale_denoms):
if denom > 1:
grad.div_(denom)
# Note: p.has_grad is reused in the optimizer to prevent the untrained components from being stepped


def all_reduce_and_rescale_tensors(tensors, rescale_denom, group=None, buffer_size=10485760):
"""
All-reduce and rescale tensors in chunks of the specified size.
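
Only the docstring of all_reduce_and_rescale_tensors is visible in the context above; the call site in the first hunk passes the gradient list together with gradient_norm as rescale_denom. For orientation, a rough sketch of what a chunked all-reduce with rescaling can look like (a hypothetical illustration rather than the mammoth implementation; it assumes an already-initialized torch.distributed process group):

import torch
import torch.distributed as dist

def chunked_all_reduce_and_rescale(tensors, rescale_denom, group=None, buffer_size=10485760):
    # Hypothetical sketch: all-reduce tensors in flat buffers of at most buffer_size bytes,
    # then divide every reduced tensor by rescale_denom.
    def flush(batch):
        flat = torch.cat([t.reshape(-1) for t in batch])  # one contiguous buffer
        if group is None:
            dist.all_reduce(flat)
        else:
            dist.all_reduce(flat, group=group)
        flat.div_(rescale_denom)
        offset = 0
        for t in batch:  # copy the reduced values back into the original tensors
            numel = t.numel()
            t.copy_(flat[offset:offset + numel].view_as(t))
            offset += numel

    batch, batch_bytes = [], 0
    for t in tensors:
        t_bytes = t.numel() * t.element_size()
        if batch and batch_bytes + t_bytes > buffer_size:
            flush(batch)
            batch, batch_bytes = [], 0
        batch.append(t)
        batch_bytes += t_bytes
    if batch:
        flush(batch)
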
