Add documentation to core and change outputs of training function
Co-authored-by: Damiano Sgarbossa <damiano.sgarbossa@epfl.ch>
Co-authored-by: Umberto Lupo <umberto.lupo@epfl.ch>
damiano-sg and ulupo committed Aug 17, 2023
1 parent d22f215 commit 60a2083
Showing 3 changed files with 44 additions and 14 deletions.
21 changes: 16 additions & 5 deletions diffpalm/core.py
@@ -395,6 +395,14 @@ def train(
`output_dir`: if not None save the plots in this directory
`save_all_figs`: if True save all the plots at each batch_size
`only_loss_plot`: if True save only the loss plot at each batch_size
Outputs:
`losses`: list of loss values for each iteration (`batch_size`*`epochs`)
`list_lr`: list of the learning rate used at each epoch
`list_idx`: list of the indexes of the predicted pairs at each iteration (`batch_size`*`epochs`)
`mats`: list of the permutation matrices at each epoch (hard permutation)
`mats_gs`: list of the soft-permutation matrices at each epoch
`list_log_alpha`: list of the log_alpha matrices at each epoch
"""
self._validator(input_left, input_right, fixed_pairings=fixed_pairings)
if not sum(self._effective_depth_not_fixed):
@@ -469,7 +477,7 @@ def _rand_perm():
mats, mats_gs = [], []
list_idx = []
list_log_alpha = []
list_scheduler = []
list_lr = []
gs_matching_mat = None
target_idx = torch.arange(
self._depth_total, dtype=torch.float, device=self.device
@@ -656,13 +664,14 @@ def _rand_perm():
else:
scheduler.step()

list_scheduler.append(optimizer.param_groups[0]["lr"])
list_lr.append(optimizer.param_groups[0]["lr"])

return (
losses,
list_scheduler,
[target_idx_np, list_idx],
[mats, mats_gs],
list_lr,
list_idx,
mats,
mats_gs,
list_log_alpha,
)

@@ -681,6 +690,8 @@ def target_loss(
`positive_examples`: if not None it's a concatenation of correct pairs to use
as context (not masked)
`batch_size`: batch size for the target loss (number of different masks to use at each epoch)
Output: list of target loss values for each masking iteration (`batch_size`)
"""
self._validator(input_left, input_right, fixed_pairings=fixed_pairings)
pbar = tqdm(range(batch_size))
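This commit flattens the return value of `train` from nested lists into a plain 6-tuple, in the order documented in the docstring above. A minimal usage sketch of the new signature, assuming an instantiated model object `dpalm` and input MSAs `left_msa` and `right_msa` as in the example notebook further down (all optional keyword arguments left at their defaults):

    # Sketch only: new flat return order of `train`, per the updated docstring.
    # `dpalm`, `left_msa` and `right_msa` are placeholders from the example notebook.
    (
        losses,          # loss value for each iteration (batch_size * epochs entries)
        list_lr,         # learning rate used at each epoch (formerly `list_scheduler`)
        list_idx,        # indexes of the predicted pairs at each iteration
        mats,            # hard permutation matrices, one per epoch
        mats_gs,         # soft-permutation matrices, one per epoch
        list_log_alpha,  # log_alpha matrices, one per epoch
    ) = dpalm.train(left_msa, right_msa)

    print(f"{len(losses)} gradient steps, final loss: {losses[-1]}")

Code that previously indexed into the old `[target_idx_np, list_idx]` and `[mats, mats_gs]` sub-lists now needs to unpack these values individually; note that `target_idx_np` is no longer returned.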
21 changes: 16 additions & 5 deletions nbs/00_core.ipynb
@@ -451,6 +451,14 @@
" `output_dir`: if not None save the plots in this directory\n",
" `save_all_figs`: if True save all the plots at each batch_size\n",
" `only_loss_plot`: if True save only the loss plot at each batch_size\n",
"\n",
" Outputs:\n",
" `losses`: list of loss values for each iteration (`batch_size`*`epochs`)\n",
" `list_lr`: list of the learning rate used at each epoch\n",
" `list_idx`: list of the indexes of the predicted pairs at each iteration (`batch_size`*`epochs`)\n",
" `mats`: list of the permutation matrices at each epoch (hard permutation)\n",
" `mats_gs`: list of the soft-permutation matrices at each epoch\n",
" `list_log_alpha`: list of the log_alpha matrices at each epoch\n",
" \"\"\"\n",
" self._validator(input_left, input_right, fixed_pairings=fixed_pairings)\n",
" if not sum(self._effective_depth_not_fixed):\n",
@@ -525,7 +533,7 @@
" mats, mats_gs = [], []\n",
" list_idx = []\n",
" list_log_alpha = []\n",
" list_scheduler = []\n",
" list_lr = []\n",
" gs_matching_mat = None\n",
" target_idx = torch.arange(\n",
" self._depth_total, dtype=torch.float, device=self.device\n",
@@ -712,13 +720,14 @@
" else:\n",
" scheduler.step()\n",
"\n",
" list_scheduler.append(optimizer.param_groups[0][\"lr\"])\n",
" list_lr.append(optimizer.param_groups[0][\"lr\"])\n",
"\n",
" return (\n",
" losses,\n",
" list_scheduler,\n",
" [target_idx_np, list_idx],\n",
" [mats, mats_gs],\n",
" list_lr,\n",
" list_idx,\n",
" mats,\n",
" mats_gs,\n",
" list_log_alpha,\n",
" )\n",
"\n",
@@ -737,6 +746,8 @@
" `positive_examples`: if not None it's a concatenation of correct pairs to use\n",
" as context (not masked)\n",
" `batch_size`: batch size for the target loss (number of different masks to use at each epoch)\n",
"\n",
" Output: list of target loss values for each masking iteration (`batch_size`)\n",
" \"\"\"\n",
" self._validator(input_left, input_right, fixed_pairings=fixed_pairings)\n",
" pbar = tqdm(range(batch_size))\n",
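The same docstring additions appear in the notebook source above (`nbs/00_core.ipynb` mirrors `diffpalm/core.py`), including the new `Output` note for `target_loss`. A minimal sketch of how that return value might be consumed, again assuming the `dpalm`, `left_msa` and `right_msa` placeholders from the example notebook and default values for `fixed_pairings` and `positive_examples`:

    # Sketch only: `target_loss` returns one value per masking iteration,
    # i.e. a list of length `batch_size`.
    tar_loss = dpalm.target_loss(left_msa, right_msa, batch_size=50)

    assert len(tar_loss) == 50
    mean_target_loss = float(sum(tar_loss)) / len(tar_loss)
    print("mean target loss over masks:", mean_target_loss)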
16 changes: 12 additions & 4 deletions nbs/_example_prokaryotic.ipynb
@@ -204,6 +204,13 @@
"When `save_all_figs=True`, a figure is saved and shown after each gradient step, illustrating the current state of the optimization. This slows the overall optimization down and may create memory leakage issues. Set `save_all_figs=False` to only have the figure saved and shown after the last gradient step."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plotting function can show the number of correctly predicted pairs because the ground-truth pairs are known. The model assumes that the input pairs are already correctly matched (i.e. the correct matching matrix is the identity matrix) because in the HK-RR and MALG-MALK datasets the sequences are already ordered, with the correct matches at the same positions in the two MSAs."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -219,8 +226,9 @@
"\n",
"(losses,\n",
" list_scheduler,\n",
" indexes,\n",
" [mat_perm, mat_gs],\n",
" shuffled_indexes,\n",
" mat_perm,\n",
" mat_gs,\n",
" list_log_alpha) = dpalm.train(\n",
" left_msa,\n",
" right_msa,\n",
@@ -232,9 +240,9 @@
")\n",
"\n",
"results = {\n",
" \"trainng_results\": (losses, list_scheduler, indexes, [mat_perm, mat_gs], list_log_alpha),\n",
" \"trainng_results\": (losses, list_scheduler, shuffled_indexes, [mat_perm, mat_gs], list_log_alpha),\n",
" \"target_loss\": tar_loss,\n",
" \"depth\": depth\n",
" \"species_sizes\": species_sizes\n",
"}"
]
}
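The new markdown cell above notes that correct pairs can be counted because the ground-truth matching is known and, for these datasets, the correct partner sits at the same MSA position. Under that assumption the check reduces to comparing predicted indices against the identity permutation; a small illustrative sketch (the helper name and the example values are hypothetical, not taken from the notebook):

    import torch

    # Illustrative sketch: with pre-aligned MSAs the ground-truth matching is the
    # identity permutation, so predicted index i is correct when it equals i.
    def count_correct_pairs(predicted_idx: torch.Tensor) -> int:
        """Number of predicted pair indices that equal the identity target."""
        target_idx = torch.arange(len(predicted_idx))
        return int((predicted_idx == target_idx).sum())

    # Example: 3 of these 5 predictions fall on the diagonal of the matching matrix.
    print(count_correct_pairs(torch.tensor([0, 2, 1, 3, 4])))  # prints 3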
