Skip to content

Commit

Permalink
onmt/utils/distributed.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Attieh authored and Joseph Attieh committed Feb 6, 2024
1 parent b9d50ed commit 956aea9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions mammoth/distributed/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,15 @@ def batch_producer(generator_to_serve, queue, semaphore, opts, device_id):
logger.info("BATCH PRODUCER")
logger.info(generator_to_serve)

for batch, metadata, communication_batch_id in generator_to_serve:
for batch, metadata, communication_batch_id, data_states in generator_to_serve:
semaphore.acquire()
# Move the batch to the corresponding device_id when the consumer iterates
# hack to dodge unpicklable `dict_keys`
# batch.fields = list(batch.fields)
queue.put((batch, metadata, communication_batch_id))
queue.put((batch, metadata, communication_batch_id, data_states))


def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager):
def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager, checkpoint):
"""Run `process_fn` on `device_id` with data from `batch_queue`."""
try:
logger.info(
Expand All @@ -271,6 +271,7 @@ def consumer(process_fn, opts, device_context, error_queue, batch_queue, semapho
batch_queue=batch_queue,
semaphore=semaphore,
task_queue_manager=task_queue_manager,
checkpoint=checkpoint,
)

except KeyboardInterrupt:
Expand Down
2 changes: 1 addition & 1 deletion mammoth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def _gather_data_state(self, device_context):
new_data_state = self.model_saver.data_state.copy()
for taskname, idx_n_buckets in self.model_saver.data_state.items():
# gather indices
tmplist = mammoth.utils.distributed.all_gather_list(idx_n_buckets['indices'])
tmplist = mammoth.distributed.all_gather_list(idx_n_buckets['indices'])
tmplist = [x for x in tmplist if isinstance(x, int)]
if device_context.is_master():
new_data_state[taskname]['indices'] = max(tmplist)
Expand Down

0 comments on commit 956aea9

Please sign in to comment.