diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py index 687da4b9..2767fff1 100644 --- a/mammoth/distributed/communication.py +++ b/mammoth/distributed/communication.py @@ -246,15 +246,15 @@ def batch_producer(generator_to_serve, queue, semaphore, opts, device_id): logger.info("BATCH PRODUCER") logger.info(generator_to_serve) - for batch, metadata, communication_batch_id in generator_to_serve: + for batch, metadata, communication_batch_id, data_states in generator_to_serve: semaphore.acquire() # Move batch to correspond device_id when consumer iterate # hack to dodge unpicklable `dict_keys` # batch.fields = list(batch.fields) - queue.put((batch, metadata, communication_batch_id)) + queue.put((batch, metadata, communication_batch_id, data_states)) -def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager): +def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager, checkpoint): """Run `process_fn` on `device_id` with data from `batch_queue`.""" try: logger.info( @@ -271,6 +271,7 @@ def consumer(process_fn, opts, device_context, error_queue, batch_queue, semapho batch_queue=batch_queue, semaphore=semaphore, task_queue_manager=task_queue_manager, + checkpoint=checkpoint, ) except KeyboardInterrupt: diff --git a/mammoth/trainer.py b/mammoth/trainer.py index 3d7a4541..667acdcf 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -459,7 +459,7 @@ def _gather_data_state(self, device_context): new_data_state = self.model_saver.data_state.copy() for taskname, idx_n_buckets in self.model_saver.data_state.items(): # gather indices - tmplist = mammoth.utils.distributed.all_gather_list(idx_n_buckets['indices']) + tmplist = mammoth.distributed.all_gather_list(idx_n_buckets['indices']) tmplist = [x for x in tmplist if isinstance(x, int)] if device_context.is_master(): new_data_state[taskname]['indices'] = max(tmplist)