From d5807bf34dae01f8b7ef39ac9cdb577a3c3b8907 Mon Sep 17 00:00:00 2001
From: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
Date: Wed, 29 Jun 2022 15:44:19 +0100
Subject: [PATCH] Fixes for quant transfer learn (#76) (#77)

---
 export.py       | 17 +++++++-----
 train.py        |  5 ++--
 utils/sparse.py | 69 ++++++++++++++++++++++++++++++++++---------------
 3 files changed, 62 insertions(+), 29 deletions(-)

diff --git a/export.py b/export.py
index 82d6a722a9b..5935de410b2 100644
--- a/export.py
+++ b/export.py
@@ -432,7 +432,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
     except Exception as e:
         LOGGER.info(f'\n{prefix} export failure: {e}')

-def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
+def create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
     pickle = not sparseml_wrapper.qat_active(math.inf if epoch <0 else epoch) # qat does not support pickled exports
     ckpt_model = deepcopy(model.module if is_parallel(model) else model).float()
     yaml = ckpt_model.yaml
@@ -445,7 +445,7 @@ def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
             'yaml': yaml,
             'hyp': model.hyp,
             **ema.state_dict(pickle),
-            **sparseml_wrapper.state_dict(),
+            **sparseml_wrapper.state_dict(final_epoch),
             **kwargs}

 def load_checkpoint(
@@ -469,6 +469,10 @@ def load_checkpoint(
     weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)
     ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple)
                       else weights, map_location="cpu")  # load checkpoint
+
+    # temporary fix until SparseML and ZooModels are updated
+    ckpt['checkpoint_recipe'] = ckpt.get('recipe') or ckpt.get('checkpoint_recipe')
+
     pickled = isinstance(ckpt['model'], nn.Module)
     train_type = type_ == 'train'
     ensemble_type = type_ == 'ensemble'
@@ -500,21 +504,22 @@ def load_checkpoint(
     # load sparseml recipe for applying pruning and quantization
     checkpoint_recipe = train_recipe = None
     if resume:
-        train_recipe = ckpt.get('recipe')
-    elif recipe or ckpt.get('recipe'):
-        train_recipe, checkpoint_recipe = recipe, ckpt.get('recipe')
+        train_recipe, checkpoint_recipe = ckpt.get('train_recipe'), ckpt.get('checkpoint_recipe')
+    elif recipe or ckpt.get('checkpoint_recipe'):
+        train_recipe, checkpoint_recipe = recipe, ckpt.get('checkpoint_recipe')

     sparseml_wrapper = SparseMLWrapper(
         model.model if val_type else model,
         checkpoint_recipe,
         train_recipe,
+        train_mode=train_type,
+        epoch=ckpt['epoch'],
         one_shot=one_shot,
         steps_per_epoch=max_train_steps,
     )
     exclude_anchors = not ensemble_type and (cfg or hyp.get('anchors')) and not resume
     loaded = False

-    sparseml_wrapper.apply_checkpoint_structure()
     if train_type:
         # intialize the recipe for training and restore the weights before if no quantized weights
         quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()])
diff --git a/train.py b/train.py
index 0157a45525d..cc815345d22 100644
--- a/train.py
+++ b/train.py
@@ -141,6 +141,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         model,
         None,
         opt.recipe,
+        train_mode=True,
         steps_per_epoch=opt.max_train_steps,
         one_shot=opt.one_shot,
     )
@@ -314,7 +315,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
             "date": datetime.now().isoformat(),
         }
         ckpt = create_checkpoint(
-            -1, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
+            -1, True, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
         )
         one_shot_checkpoint_name = w / "checkpoint-one-shot.pt"
         torch.save(ckpt, one_shot_checkpoint_name)
@@ -486,7 +487,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
                 'best_fitness': best_fitness,
                 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
                 'date': datetime.now().isoformat()}
-            ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)
+            ckpt = create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)

             # Save last, best and delete
             torch.save(ckpt, last)
diff --git a/utils/sparse.py b/utils/sparse.py
index d5f8fb573a6..b03d44af1d4 100644
--- a/utils/sparse.py
+++ b/utils/sparse.py
@@ -9,6 +9,7 @@
 from sparseml.pytorch.optim import ScheduledModifierManager
 from sparseml.pytorch.utils import SparsificationGroupLogger
 from sparseml.pytorch.utils import GradSampler
+from sparseml.pytorch.sparsification.quantization import QuantizationModifier
 import torchvision.transforms.functional as F

 from utils.torch_utils import is_parallel
@@ -51,7 +52,16 @@ def check_download_sparsezoo_weights(path):


 class SparseMLWrapper(object):
-    def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, one_shot=False):
+    def __init__(
+        self,
+        model,
+        checkpoint_recipe,
+        train_recipe,
+        train_mode=False,
+        epoch=-1,
+        steps_per_epoch=-1,
+        one_shot=False,
+    ):
         self.enabled = bool(train_recipe)
         self.model = model.module if is_parallel(model) else model
         self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None
@@ -62,21 +72,47 @@ def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, o
         self.one_shot = one_shot
         self.train_recipe = train_recipe

-        if self.one_shot:
-            self._apply_one_shot()
-
-    def state_dict(self):
-        manager = (ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)
-                   if self.checkpoint_manager and self.enabled else self.manager)
+        self.apply_checkpoint_structure(train_mode, epoch, one_shot)

-        return {
-            'recipe': str(manager) if self.enabled else None,
-        }
+    def state_dict(self, final_epoch):
+        if self.enabled or self.checkpoint_manager:
+            compose_recipes = self.checkpoint_manager and self.enabled and final_epoch
+            return {
+                'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager))
+                if compose_recipes else str(self.checkpoint_manager),
+                'train_recipe': str(self.manager) if not final_epoch else None
+            }
+        else:
+            return {
+                'checkpoint_recipe': None,
+                'train_recipe': None
+            }

-    def apply_checkpoint_structure(self):
+    def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False):
         if self.checkpoint_manager:
+            # if checkpoint recipe has a QAT modifier and this is a transfer learning
+            # run then remove the QAT modifier from the manager
+            if train_mode:
+                qat_idx = next((
+                    idx for idx, mod in enumerate(self.checkpoint_manager.modifiers)
+                    if isinstance(mod, QuantizationModifier)), -1
+                )
+                if qat_idx >= 0:
+                    _ = self.checkpoint_manager.modifiers.pop(qat_idx)
+
             self.checkpoint_manager.apply_structure(self.model, math.inf)
+            if train_mode and epoch > 0 and self.enabled:
+                self.manager.apply_structure(self.model, epoch)
+        elif one_shot:
+            if self.enabled:
+                self.manager.apply(self.model)
+                _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
+            else:
+                _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
+                             f"{self.train_recipe}"
+                )
+
     def initialize(
         self,
         start_epoch,
@@ -144,9 +180,9 @@ def check_lr_override(self, scheduler, rank):
     def check_epoch_override(self, epochs, rank):
         # Override num epochs if recipe explicitly modifies epoch range
         if self.enabled and self.manager.epoch_modifiers and self.manager.max_epochs:
+            epochs = self.manager.max_epochs or epochs # override num_epochs
             if rank in [0,-1]:
                 self.logger.info(f'Overriding number of epochs from SparseML manager to {epochs}')
-            epochs = self.manager.max_epochs + self.start_epoch or epochs # override num_epochs

         return epochs

@@ -195,15 +231,6 @@ def dataloader():
                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
                 yield [imgs], {}, targets
         return dataloader
-
-    def _apply_one_shot(self):
-        if self.manager is not None:
-            self.manager.apply(self.model)
-            _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
-        else:
-            _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
-                         f"{self.train_recipe}"
-            )

     def save_sample_inputs_outputs(
         self,
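
For reference, a short sketch of how the reworked SparseMLWrapper entry points above fit together. This is not part of the patch; the names model, recipe, ckpt, and max_train_steps are assumed to come from the surrounding load_checkpoint code in export.py, and the values are illustrative only.

    from utils.sparse import SparseMLWrapper

    # Build the wrapper the way load_checkpoint now does for a training run.
    # With train_mode=True, apply_checkpoint_structure() (now called from __init__) drops any
    # QuantizationModifier from the checkpoint recipe before applying its structure, and a
    # positive epoch re-applies the train recipe structure when resuming.
    sparseml_wrapper = SparseMLWrapper(
        model,                          # loaded YOLOv5 model (assumed from load_checkpoint)
        ckpt.get('checkpoint_recipe'),  # recipe stored in the checkpoint, may be None
        recipe,                         # recipe for the current training run
        train_mode=True,
        epoch=ckpt['epoch'],
        steps_per_epoch=max_train_steps,
        one_shot=False,
    )

    # state_dict() now takes final_epoch: on the final epoch the checkpoint and train recipes
    # are composed into 'checkpoint_recipe'; otherwise the in-progress recipe is kept
    # separately under 'train_recipe' so a later run can resume from it.
    recipes = sparseml_wrapper.state_dict(final_epoch=False)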