diff --git a/diffpalm/core.py b/diffpalm/core.py
index f4b29ff..190e620 100644
--- a/diffpalm/core.py
+++ b/diffpalm/core.py
@@ -395,6 +395,14 @@ def train(
         `output_dir`: if not None save the plots in this directory
         `save_all_figs`: if True save all the plots at each batch_size
         `only_loss_plot`: if True save only the loss plot at each batch_size
+
+        Outputs:
+        `losses`: list of loss values for each iteration (`batch_size`*`epochs`)
+        `list_lr`: list of the learning rates used at each epoch
+        `list_idx`: list of the indexes of the predicted pairs at each iteration (`batch_size`*`epochs`)
+        `mats`: list of the hard permutation matrices at each epoch
+        `mats_gs`: list of the soft permutation matrices at each epoch
+        `list_log_alpha`: list of the log_alpha matrices at each epoch
         """
         self._validator(input_left, input_right, fixed_pairings=fixed_pairings)
         if not sum(self._effective_depth_not_fixed):
@@ -469,7 +477,7 @@ def _rand_perm():
         mats, mats_gs = [], []
         list_idx = []
         list_log_alpha = []
-        list_scheduler = []
+        list_lr = []
         gs_matching_mat = None
         target_idx = torch.arange(
             self._depth_total, dtype=torch.float, device=self.device
@@ -656,13 +664,14 @@ def _rand_perm():
             else:
                 scheduler.step()
 
-            list_scheduler.append(optimizer.param_groups[0]["lr"])
+            list_lr.append(optimizer.param_groups[0]["lr"])
 
         return (
             losses,
-            list_scheduler,
-            [target_idx_np, list_idx],
-            [mats, mats_gs],
+            list_lr,
+            list_idx,
+            mats,
+            mats_gs,
             list_log_alpha,
         )
 
@@ -681,6 +690,8 @@ def target_loss(
         `positive_examples`: if not None it's a concatenation of correct pairs to use
             as context (not masked)
         `batch_size`: batch size for the target loss (number of different masks to use at each epoch)
+
+        Output: list of target loss values, one for each masking iteration (`batch_size`)
         """
         self._validator(input_left, input_right, fixed_pairings=fixed_pairings)
         pbar = tqdm(range(batch_size))
diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb
index 88e9492..3104e3b 100644
--- a/nbs/00_core.ipynb
+++ b/nbs/00_core.ipynb
@@ -451,6 +451,14 @@
     "        `output_dir`: if not None save the plots in this directory\n",
     "        `save_all_figs`: if True save all the plots at each batch_size\n",
     "        `only_loss_plot`: if True save only the loss plot at each batch_size\n",
+    "\n",
+    "        Outputs:\n",
+    "        `losses`: list of loss values for each iteration (`batch_size`*`epochs`)\n",
+    "        `list_lr`: list of the learning rates used at each epoch\n",
+    "        `list_idx`: list of the indexes of the predicted pairs at each iteration (`batch_size`*`epochs`)\n",
+    "        `mats`: list of the hard permutation matrices at each epoch\n",
+    "        `mats_gs`: list of the soft permutation matrices at each epoch\n",
+    "        `list_log_alpha`: list of the log_alpha matrices at each epoch\n",
     "        \"\"\"\n",
     "        self._validator(input_left, input_right, fixed_pairings=fixed_pairings)\n",
     "        if not sum(self._effective_depth_not_fixed):\n",
@@ -525,7 +533,7 @@
     "        mats, mats_gs = [], []\n",
     "        list_idx = []\n",
     "        list_log_alpha = []\n",
-    "        list_scheduler = []\n",
+    "        list_lr = []\n",
     "        gs_matching_mat = None\n",
     "        target_idx = torch.arange(\n",
     "            self._depth_total, dtype=torch.float, device=self.device\n",
@@ -712,13 +720,14 @@
     "            else:\n",
     "                scheduler.step()\n",
     "\n",
-    "            list_scheduler.append(optimizer.param_groups[0][\"lr\"])\n",
+    "            list_lr.append(optimizer.param_groups[0][\"lr\"])\n",
     "\n",
     "        return (\n",
     "            losses,\n",
-    "            list_scheduler,\n",
-    "            [target_idx_np, list_idx],\n",
-    "            [mats, mats_gs],\n",
+    "            list_lr,\n",
+    "            list_idx,\n",
+    "            mats,\n",
+    "            mats_gs,\n",
     "            list_log_alpha,\n",
" )\n", "\n", @@ -737,6 +746,8 @@ " `positive_examples`: if not None it's a concatenation of correct pairs to use\n", " as context (not masked)\n", " `batch_size`: batch size for the target loss (number of different masks to use at each epoch)\n", + "\n", + " Output: list of target loss values for each masking iteration (`batch_size`)\n", " \"\"\"\n", " self._validator(input_left, input_right, fixed_pairings=fixed_pairings)\n", " pbar = tqdm(range(batch_size))\n", diff --git a/nbs/_example_prokaryotic.ipynb b/nbs/_example_prokaryotic.ipynb index 6741ac2..2b68342 100644 --- a/nbs/_example_prokaryotic.ipynb +++ b/nbs/_example_prokaryotic.ipynb @@ -204,6 +204,13 @@ "When `save_all_figs=True`, a figure is saved and shown after each gradient step, illustrating the current state of the optimization. This slows the overall optimization down and may create memory leakage issues. Set `save_all_figs=False` to only have the figure saved and shown after the last gradient step." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The plotting function is able to show the number of correctly predicted pairs because the ground truth pairs are known. The model assumes that the input pairs are already correctly matched (i.e. the correct matching matrix is a diagonal matrix) because in the HK-RR and MALG-MALK datasets the sequences are are already ordered with the correct matches in the same position of the MSA." + ] + }, { "cell_type": "code", "execution_count": null, @@ -219,8 +226,9 @@ "\n", "(losses,\n", " list_scheduler,\n", - " indexes,\n", - " [mat_perm, mat_gs],\n", + " shuffled_indexes,\n", + " mat_perm,\n", + " mat_gs,\n", " list_log_alpha) = dpalm.train(\n", " left_msa,\n", " right_msa,\n", @@ -232,9 +240,9 @@ ")\n", "\n", "results = {\n", - " \"trainng_results\": (losses, list_scheduler, indexes, [mat_perm, mat_gs], list_log_alpha),\n", + " \"trainng_results\": (losses, list_scheduler, shuffled_indexes, [mat_perm, mat_gs], list_log_alpha),\n", " \"target_loss\": tar_loss,\n", - " \"depth\": depth\n", + " \"species_sizes\": species_sizes\n", "}" ] }