
Commit

Fixes for quant transfer learn (#76) (#77)
KSGulin authored Jun 29, 2022
1 parent a065b8d commit d5807bf
Showing 3 changed files with 62 additions and 29 deletions.
17 changes: 11 additions & 6 deletions export.py
@@ -432,7 +432,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')

- def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
+ def create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
pickle = not sparseml_wrapper.qat_active(math.inf if epoch <0 else epoch) # qat does not support pickled exports
ckpt_model = deepcopy(model.module if is_parallel(model) else model).float()
yaml = ckpt_model.yaml
@@ -445,7 +445,7 @@ def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
'yaml': yaml,
'hyp': model.hyp,
**ema.state_dict(pickle),
- **sparseml_wrapper.state_dict(),
+ **sparseml_wrapper.state_dict(final_epoch),
**kwargs}

def load_checkpoint(
@@ -469,6 +469,10 @@ def load_checkpoint(
weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)
ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple)
else weights, map_location="cpu") # load checkpoint

+ # temporary fix until SparseML and ZooModels are updated
+ ckpt['checkpoint_recipe'] = ckpt.get('recipe') or ckpt.get('checkpoint_recipe')

pickled = isinstance(ckpt['model'], nn.Module)
train_type = type_ == 'train'
ensemble_type = type_ == 'ensemble'
@@ -500,21 +504,22 @@ def load_checkpoint(
# load sparseml recipe for applying pruning and quantization
checkpoint_recipe = train_recipe = None
if resume:
- train_recipe = ckpt.get('recipe')
- elif recipe or ckpt.get('recipe'):
- train_recipe, checkpoint_recipe = recipe, ckpt.get('recipe')
+ train_recipe, checkpoint_recipe = ckpt.get('train_recipe'), ckpt.get('checkpoint_recipe')
+ elif recipe or ckpt.get('checkpoint_recipe'):
+ train_recipe, checkpoint_recipe = recipe, ckpt.get('checkpoint_recipe')

sparseml_wrapper = SparseMLWrapper(
model.model if val_type else model,
checkpoint_recipe,
train_recipe,
+ train_mode=train_type,
+ epoch=ckpt['epoch'],
one_shot=one_shot,
steps_per_epoch=max_train_steps,
)
exclude_anchors = not ensemble_type and (cfg or hyp.get('anchors')) and not resume
loaded = False

- sparseml_wrapper.apply_checkpoint_structure()
if train_type:
# intialize the recipe for training and restore the weights before if no quantized weights
quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()])
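Taken together, the export.py changes keep older checkpoints working: the legacy 'recipe' entry is mapped onto the new 'checkpoint_recipe' key, and the checkpoint's epoch and the run mode are now passed into the SparseML wrapper when it is built. A minimal sketch of that load path, assuming a hypothetical checkpoint file and transfer recipe name (neither is part of this commit):

import torch
from utils.sparse import SparseMLWrapper

ckpt = torch.load('quantized_base.pt', map_location='cpu')  # hypothetical checkpoint file
# backfill the new key from the legacy 'recipe' entry (the temporary fix above)
ckpt['checkpoint_recipe'] = ckpt.get('recipe') or ckpt.get('checkpoint_recipe')

sparseml_wrapper = SparseMLWrapper(
    model,                       # model being transfer-learned, assumed already constructed
    ckpt['checkpoint_recipe'],   # recipe the checkpoint was originally trained with
    'transfer_recipe.yaml',      # hypothetical new training recipe
    train_mode=True,             # training run: the checkpoint recipe's QAT stage is dropped
    epoch=ckpt['epoch'],
    one_shot=False,
)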
5 changes: 3 additions & 2 deletions train.py
@@ -141,6 +141,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
model,
None,
opt.recipe,
+ train_mode=True,
steps_per_epoch=opt.max_train_steps,
one_shot=opt.one_shot,
)
@@ -314,7 +315,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
"date": datetime.now().isoformat(),
}
ckpt = create_checkpoint(
- -1, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
+ -1, True, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
)
one_shot_checkpoint_name = w / "checkpoint-one-shot.pt"
torch.save(ckpt, one_shot_checkpoint_name)
@@ -486,7 +487,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
'best_fitness': best_fitness,
'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
'date': datetime.now().isoformat()}
- ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)
+ ckpt = create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)

# Save last, best and delete
torch.save(ckpt, last)
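On the train.py side, every save now says whether it is the last one of the run: the one-shot export passes final_epoch=True outright, and the per-epoch save forwards the loop's final_epoch flag so intermediate checkpoints keep their recipes separate. A rough sketch of the call pattern, assuming final_epoch is the usual "is this the last epoch" boolean from the training loop:

# hypothetical per-epoch save, not the literal train.py loop
for epoch in range(start_epoch, epochs):
    # ... train and validate one epoch ...
    final_epoch = (epoch + 1) == epochs
    ckpt = create_checkpoint(epoch, final_epoch, model, optimizer, ema,
                             sparseml_wrapper, **ckpt_extras)
    torch.save(ckpt, last)

# one-shot export: epoch=-1 with final_epoch=True, so the save is treated as final
ckpt = create_checkpoint(-1, True, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)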
69 changes: 48 additions & 21 deletions utils/sparse.py
@@ -9,6 +9,7 @@
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import SparsificationGroupLogger
from sparseml.pytorch.utils import GradSampler
+ from sparseml.pytorch.sparsification.quantization import QuantizationModifier
import torchvision.transforms.functional as F

from utils.torch_utils import is_parallel
@@ -51,7 +52,16 @@ def check_download_sparsezoo_weights(path):


class SparseMLWrapper(object):
- def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, one_shot=False):
+ def __init__(
+ self,
+ model,
+ checkpoint_recipe,
+ train_recipe,
+ train_mode=False,
+ epoch=-1,
+ steps_per_epoch=-1,
+ one_shot=False,
+ ):
self.enabled = bool(train_recipe)
self.model = model.module if is_parallel(model) else model
self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None
@@ -62,21 +72,47 @@ def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, one_shot=False):
self.one_shot = one_shot
self.train_recipe = train_recipe

- if self.one_shot:
- self._apply_one_shot()

- def state_dict(self):
- manager = (ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)
- if self.checkpoint_manager and self.enabled else self.manager)
+ self.apply_checkpoint_structure(train_mode, epoch, one_shot)

- return {
- 'recipe': str(manager) if self.enabled else None,
- }
+ def state_dict(self, final_epoch):
+ if self.enabled or self.checkpoint_manager:
+ compose_recipes = self.checkpoint_manager and self.enabled and final_epoch
+ return {
+ 'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager))
+ if compose_recipes else str(self.checkpoint_manager),
+ 'train_recipe': str(self.manager) if not final_epoch else None
+ }
+ else:
+ return {
+ 'checkpoint_recipe': None,
+ 'train_recipe': None
+ }

- def apply_checkpoint_structure(self):
+ def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False):
if self.checkpoint_manager:
+ # if checkpoint recipe has a QAT modifier and this is a transfer learning
+ # run then remove the QAT modifier from the manager
+ if train_mode:
+ qat_idx = next((
+ idx for idx, mod in enumerate(self.checkpoint_manager.modifiers)
+ if isinstance(mod, QuantizationModifier)), -1
+ )
+ if qat_idx >= 0:
+ _ = self.checkpoint_manager.modifiers.pop(qat_idx)

self.checkpoint_manager.apply_structure(self.model, math.inf)

+ if train_mode and epoch > 0 and self.enabled:
+ self.manager.apply_structure(self.model, epoch)
+ elif one_shot:
+ if self.enabled:
+ self.manager.apply(self.model)
+ _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
+ else:
+ _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
+ f"{self.train_recipe}"
+ )

def initialize(
self,
start_epoch,
@@ -144,9 +180,9 @@ def check_lr_override(self, scheduler, rank):
def check_epoch_override(self, epochs, rank):
# Override num epochs if recipe explicitly modifies epoch range
if self.enabled and self.manager.epoch_modifiers and self.manager.max_epochs:
- epochs = self.manager.max_epochs or epochs # override num_epochs
if rank in [0,-1]:
self.logger.info(f'Overriding number of epochs from SparseML manager to {epochs}')
+ epochs = self.manager.max_epochs + self.start_epoch or epochs # override num_epochs

return epochs
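As a concrete example of the override change: a transfer recipe whose epoch modifiers run for max_epochs=5, applied to a run whose wrapper was initialized with start_epoch=100 (assuming self.start_epoch records where this run begins), now yields 105 total epochs instead of truncating the run to 5.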

@@ -195,15 +231,6 @@ def dataloader():
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
yield [imgs], {}, targets
return dataloader

- def _apply_one_shot(self):
- if self.manager is not None:
- self.manager.apply(self.model)
- _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
- else:
- _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
- f"{self.train_recipe}"
- )

def save_sample_inputs_outputs(
self,
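Net effect on the saved checkpoint: it now carries two recipe entries instead of the old single 'recipe' key. Illustrative values only; the real entries are full recipe YAML strings produced by SparseML:

# intermediate save (final_epoch=False): recipes stay separate so training can resume
intermediate = {'checkpoint_recipe': '<recipe the base checkpoint was trained with>',
                'train_recipe': '<recipe currently being trained>'}

# final save (final_epoch=True) with both managers present: the two are composed via
# ScheduledModifierManager.compose_staged into one staged checkpoint recipe
final = {'checkpoint_recipe': '<base recipe + training recipe, staged>',
         'train_recipe': None}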
