From 6306f1bd82b6bc51f5fee44d03fee1bfb675b8bc Mon Sep 17 00:00:00 2001 From: Wouter Besse Date: Tue, 16 May 2023 15:55:50 +0200 Subject: [PATCH] Finished first iteration of Tybalt + training --- TybaltPlayground.ipynb | 286 +++++++++++++++++++++++++++++++++++- models/Tybalt/TybaltData.py | 4 +- models/Tybalt/TybaltVAE.py | 4 +- models/Tybalt/train.py | 203 +++++++++++++++++++++++++ 4 files changed, 485 insertions(+), 12 deletions(-) create mode 100644 models/Tybalt/train.py diff --git a/TybaltPlayground.ipynb b/TybaltPlayground.ipynb index 562d609..e7c8495 100644 --- a/TybaltPlayground.ipynb +++ b/TybaltPlayground.ipynb @@ -18,17 +18,68 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\woute\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python38\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "TybaltVAE(\n", + " (encoder): Encoder(\n", + " (linear_1): Sequential(\n", + " (0): Linear(in_features=5000, out_features=1000, bias=True)\n", + " (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU()\n", + " )\n", + " (linear_mu): Sequential(\n", + " (0): Linear(in_features=1000, out_features=32, bias=True)\n", + " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU()\n", + " )\n", + " (linear_var): Sequential(\n", + " (0): Linear(in_features=1000, out_features=32, bias=True)\n", + " (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU()\n", + " )\n", + " )\n", + " (decoder): Decoder(\n", + " (decode): Sequential(\n", + " (0): Linear(in_features=32, out_features=1000, bias=True)\n", + " (1): Sigmoid()\n", + " (2): Linear(in_features=1000, out_features=5000, bias=True)\n", + " (3): Sigmoid()\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from models.Tybalt.TybaltVAE import TybaltVAE\n", "\n", - "batchsize = 128\n", + "batchsize = 512\n", "input_size = 5000\n", "output_size = 5000\n", + "export_path = './exports/Tybalt/'\n", + "learning_rate = 0.00001\n", + "epochs = 100\n", + "device = 'cuda:0'\n", "\n", - "model = TybaltVAE(input_size=input_size, output_size=output_size)" + "\n", + "model = TybaltVAE(input_size=input_size, output_size=output_size)\n", + "model.to(device)" ] }, { @@ -41,15 +92,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded data of size: torch.Size([1046, 5000])\n", + "Loaded data of size: torch.Size([9413, 5000])\n" + ] + } + ], "source": [ "from models.Tybalt.TybaltData import getTybaltDatasets\n", "from torch.utils.data import DataLoader\n", "\n", "data_path = './tybaltdata/pancan_scaled_zeroone_rnaseq.tsv.gz'\n", - "dataset_train, dataset_val = getTybaltDatasets()\n", + "dataset_train, dataset_val = getTybaltDatasets(data_path)\n", "\n", "dataloader_train = DataLoader(dataset_train,\n", " batch_size = batchsize,\n", @@ -59,11 +119,221 @@ " batch_size = 
batchsize,\n", " shuffle = False)" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Validating. Rec loss: 0.08.: 100%|██████████| 9/9 [00:00<00:00, 14.00it/s] 12/74 [00:08<00:13, 4.67it/s]\n", + "Validating. Rec loss: 0.07.: 100%|██████████| 9/9 [00:00<00:00, 65.48it/s]m| 24/74 [00:10<00:06, 8.19it/s]\n", + "Validating. Rec loss: 0.07.: 100%|██████████| 9/9 [00:00<00:00, 65.48it/s]m| 39/74 [00:11<00:02, 16.35it/s]\n", + "Validating. Rec loss: 0.07.: 100%|██████████| 9/9 [00:00<00:00, 63.99it/s]m| 53/74 [00:12<00:01, 18.58it/s]\n", + "Validating. Rec loss: 0.06.: 100%|██████████| 9/9 [00:00<00:00, 59.35it/s]m| 66/74 [00:13<00:00, 18.53it/s]\n", + "Training. Rec/real loss for step 74: 0.06/233.94.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:13<00:00, 5.42it/s]\n", + "Validating. Rec loss: 0.06.: 100%|██████████| 9/9 [00:00<00:00, 67.56it/s]m| 11/74 [00:00<00:03, 18.96it/s]\n", + "Validating. Rec loss: 0.06.: 100%|██████████| 9/9 [00:00<00:00, 45.69it/s]m| 25/74 [00:01<00:02, 18.88it/s]\n", + "Validating. Rec loss: 0.05.: 100%|██████████| 9/9 [00:00<00:00, 54.44it/s]m| 40/74 [00:02<00:01, 18.57it/s]\n", + "Validating. Rec loss: 0.05.: 100%|██████████| 9/9 [00:00<00:00, 66.95it/s]m| 53/74 [00:03<00:01, 18.46it/s]\n", + "Validating. Rec loss: 0.05.: 100%|██████████| 9/9 [00:00<00:00, 70.44it/s]m| 67/74 [00:04<00:00, 19.51it/s]\n", + "Training. Rec/real loss for step 74: 0.05/173.18.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.71it/s]\n", + "Validating. Rec loss: 0.05.: 100%|██████████| 9/9 [00:00<00:00, 64.80it/s]m| 12/74 [00:00<00:03, 20.15it/s]\n", + "Validating. Rec loss: 0.05.: 100%|██████████| 9/9 [00:00<00:00, 63.85it/s]m| 26/74 [00:01<00:02, 19.18it/s]\n", + "Validating. Rec loss: 0.05.: 100%|██████████| 9/9 [00:00<00:00, 54.21it/s]m| 40/74 [00:02<00:01, 18.75it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 61.38it/s]m| 54/74 [00:03<00:01, 18.22it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 56.74it/s]m| 67/74 [00:04<00:00, 18.57it/s]\n", + "Training. Rec/real loss for step 74: 0.04/117.73.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.46it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 39.45it/s]m| 12/74 [00:00<00:02, 21.41it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 61.76it/s]m| 25/74 [00:01<00:02, 17.75it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 71.84it/s]m| 38/74 [00:02<00:01, 18.59it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 69.21it/s]m| 52/74 [00:03<00:01, 19.24it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 68.60it/s]m| 67/74 [00:03<00:00, 19.50it/s]\n", + "Training. Rec/real loss for step 74: 0.04/260.62.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 17.07it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.96it/s]m| 12/74 [00:00<00:02, 20.73it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 68.01it/s]m| 26/74 [00:01<00:02, 17.96it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 60.53it/s]m| 40/74 [00:02<00:01, 18.99it/s]\n", + "Validating. 
Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 43.13it/s]m| 53/74 [00:03<00:01, 17.45it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 57.01it/s]m| 67/74 [00:04<00:00, 17.70it/s]\n", + "Training. Rec/real loss for step 74: 0.04/104.7.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 15.87it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 55.77it/s]m| 12/74 [00:00<00:03, 17.90it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 57.37it/s]m| 26/74 [00:01<00:02, 18.01it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 44.09it/s]m| 39/74 [00:02<00:01, 17.58it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 50.70it/s]m| 53/74 [00:03<00:01, 17.68it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 63.30it/s]m| 68/74 [00:04<00:00, 18.46it/s]\n", + "Training. Rec/real loss for step 74: 0.04/144.31.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 15.43it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 66.08it/s]m| 12/74 [00:00<00:02, 21.90it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 56.83it/s]m| 26/74 [00:01<00:02, 18.90it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 52.31it/s]m| 40/74 [00:02<00:02, 16.35it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.18it/s]m| 52/74 [00:03<00:01, 17.98it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 58.63it/s]m| 67/74 [00:04<00:00, 18.79it/s]\n", + "Training. Rec/real loss for step 74: 0.04/136.59.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.49it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 73.79it/s]m| 12/74 [00:00<00:02, 22.13it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.65it/s]m| 26/74 [00:01<00:02, 19.74it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 66.44it/s]m| 40/74 [00:02<00:01, 18.48it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.67it/s]m| 54/74 [00:03<00:01, 19.42it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 68.41it/s]m| 66/74 [00:03<00:00, 19.13it/s]\n", + "Training. Rec/real loss for step 74: 0.04/121.95.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 17.69it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 75.15it/s]m| 12/74 [00:00<00:02, 22.12it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 69.53it/s]m| 24/74 [00:01<00:02, 20.19it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.27it/s]m| 40/74 [00:02<00:01, 19.29it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 61.43it/s]m| 54/74 [00:02<00:01, 19.63it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 65.68it/s]m| 66/74 [00:03<00:00, 18.78it/s]\n", + "Training. Rec/real loss for step 74: 0.03/114.05.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 17.67it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.58it/s]m| 12/74 [00:00<00:02, 21.38it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 37.62it/s]m| 24/74 [00:01<00:02, 19.25it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 51.65it/s]m| 39/74 [00:02<00:01, 17.71it/s]\n", + "Validating. 
Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 54.94it/s]| 54/74 [00:03<00:01, 18.72it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 58.13it/s]m| 66/74 [00:04<00:00, 18.14it/s]\n", + "Training. Rec/real loss for step 74: 0.04/116.81.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.21it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 51.66it/s]m| 12/74 [00:00<00:02, 20.86it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 52.33it/s]m| 25/74 [00:01<00:02, 17.55it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 62.30it/s]m| 40/74 [00:02<00:01, 18.50it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 56.49it/s]m| 54/74 [00:03<00:01, 18.62it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 50.29it/s]m| 67/74 [00:04<00:00, 18.42it/s]\n", + "Training. Rec/real loss for step 74: 0.03/140.11.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 15.87it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.67it/s]m| 12/74 [00:00<00:02, 21.53it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 54.20it/s]m| 24/74 [00:01<00:02, 19.03it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 51.41it/s]m| 40/74 [00:02<00:01, 18.90it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 47.19it/s]m| 54/74 [00:03<00:01, 18.44it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 55.94it/s]m| 66/74 [00:04<00:00, 17.41it/s]\n", + "Training. Rec/real loss for step 74: 0.04/130.66.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.21it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 66.84it/s]m| 12/74 [00:00<00:02, 22.15it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 68.27it/s]m| 24/74 [00:01<00:02, 19.30it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 67.42it/s]m| 38/74 [00:02<00:01, 19.19it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 60.82it/s]m| 54/74 [00:03<00:01, 18.81it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 68.74it/s]m| 68/74 [00:03<00:00, 19.34it/s]\n", + "Training. Rec/real loss for step 74: 0.04/100.16.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 17.36it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 58.74it/s]m| 12/74 [00:00<00:02, 21.93it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 42.38it/s]| 26/74 [00:01<00:02, 19.07it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 56.92it/s]| 38/74 [00:02<00:02, 17.91it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 62.75it/s]m| 53/74 [00:03<00:01, 18.27it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 59.63it/s]m| 66/74 [00:04<00:00, 17.43it/s]\n", + "Training. Rec/real loss for step 74: 0.03/121.83.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.41it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 55.50it/s]m| 12/74 [00:00<00:02, 22.07it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 63.74it/s]| 24/74 [00:01<00:02, 18.95it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 68.44it/s]m| 38/74 [00:02<00:01, 19.00it/s]\n", + "Validating. 
Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 57.29it/s]m| 53/74 [00:03<00:01, 19.72it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 60.83it/s]m| 68/74 [00:03<00:00, 18.28it/s]\n", + "Training. Rec/real loss for step 74: 0.03/110.23.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.93it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 68.07it/s]m| 10/74 [00:00<00:03, 20.61it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 65.03it/s]m| 25/74 [00:01<00:02, 19.90it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 56.49it/s]m| 40/74 [00:02<00:01, 19.69it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 36.55it/s]| 53/74 [00:03<00:01, 17.19it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 54.70it/s]| 67/74 [00:04<00:00, 12.09it/s] \n", + "Training. Rec/real loss for step 74: 0.04/135.07.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 14.93it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 49.66it/s]m| 11/74 [00:00<00:03, 20.32it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 50.85it/s]m| 25/74 [00:01<00:02, 18.68it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 40.13it/s]m| 39/74 [00:02<00:02, 16.77it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 52.23it/s]m| 53/74 [00:03<00:01, 16.74it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 50.91it/s]m| 67/74 [00:04<00:00, 17.36it/s]\n", + "Training. Rec/real loss for step 74: 0.03/148.31.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:05<00:00, 14.80it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 43.92it/s]m| 11/74 [00:00<00:03, 19.87it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 48.65it/s]m| 25/74 [00:01<00:02, 17.60it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 40.21it/s]m| 39/74 [00:02<00:02, 16.80it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 46.70it/s]m| 53/74 [00:03<00:01, 14.97it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 50.57it/s]m| 67/74 [00:04<00:00, 16.79it/s]\n", + "Training. Rec/real loss for step 74: 0.04/152.17.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:05<00:00, 14.23it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 40.90it/s]m| 12/74 [00:00<00:03, 19.05it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 55.61it/s]m| 26/74 [00:01<00:02, 16.17it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 35.97it/s]m| 40/74 [00:02<00:02, 16.96it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 48.85it/s]m| 54/74 [00:03<00:01, 16.68it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 44.52it/s]m| 68/74 [00:04<00:00, 16.81it/s]\n", + "Training. Rec/real loss for step 74: 0.03/106.07.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:05<00:00, 14.02it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 52.86it/s]m| 12/74 [00:00<00:03, 19.38it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 43.84it/s]m| 26/74 [00:01<00:02, 16.72it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 49.32it/s]m| 40/74 [00:02<00:02, 16.22it/s]\n", + "Validating. 
Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 44.57it/s]m| 54/74 [00:03<00:01, 17.47it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 44.86it/s]m| 68/74 [00:04<00:00, 16.74it/s]\n", + "Training. Rec/real loss for step 74: 0.04/158.95.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:05<00:00, 14.43it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 48.15it/s]m| 12/74 [00:00<00:03, 18.28it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 60.23it/s]m| 25/74 [00:01<00:02, 18.11it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 59.92it/s]m| 39/74 [00:02<00:01, 18.14it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 60.45it/s]m| 53/74 [00:03<00:01, 18.82it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 57.02it/s]m| 68/74 [00:04<00:00, 17.81it/s]\n", + "Training. Rec/real loss for step 74: 0.04/167.11.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 15.90it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 63.29it/s]m| 11/74 [00:00<00:03, 20.34it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 56.55it/s]m| 26/74 [00:01<00:02, 18.51it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 41.60it/s]m| 40/74 [00:02<00:01, 18.17it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 60.16it/s]| 54/74 [00:03<00:01, 17.68it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 58.76it/s]m| 68/74 [00:04<00:00, 18.70it/s]\n", + "Training. Rec/real loss for step 74: 0.04/130.18.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 15.90it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 55.70it/s]m| 11/74 [00:00<00:03, 18.59it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 60.66it/s]| 26/74 [00:01<00:02, 18.66it/s] \n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 56.85it/s]m| 40/74 [00:02<00:01, 18.33it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 47.20it/s]m| 53/74 [00:03<00:01, 18.82it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 64.09it/s]m| 66/74 [00:04<00:00, 18.39it/s]\n", + "Training. Rec/real loss for step 74: 0.03/114.12.: 100%|\u001b[35m██████████\u001b[0m| 74/74 [00:04<00:00, 16.30it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 48.93it/s]m| 11/74 [00:00<00:03, 18.51it/s]\n", + "Validating. Rec loss: 0.04.: 100%|██████████| 9/9 [00:00<00:00, 69.13it/s]m| 25/74 [00:01<00:02, 17.70it/s]\n", + "Training. 
Rec/real loss for step 39: 0.04/240.61.: 53%|\u001b[35m█████▎ \u001b[0m| 39/74 [00:02<00:02, 15.85it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32mm:\\Projects\\2022-2023\\Sonified-Latent-Data\\TybaltPlayground.ipynb Cell 7\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 5\u001b[0m warnings\u001b[39m.\u001b[39mfilterwarnings(\u001b[39m\"\u001b[39m\u001b[39mignore\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 6\u001b[0m writer \u001b[39m=\u001b[39m SummaryWriter()\n\u001b[1;32m----> 8\u001b[0m train(model, dataloader_train, dataloader_val, \n\u001b[0;32m 9\u001b[0m writer\u001b[39m=\u001b[39;49mwriter, \n\u001b[0;32m 10\u001b[0m export_path\u001b[39m=\u001b[39;49mexport_path,\n\u001b[0;32m 11\u001b[0m learning_rate\u001b[39m=\u001b[39;49mlearning_rate,\n\u001b[0;32m 12\u001b[0m epoch_amount\u001b[39m=\u001b[39;49mepochs,\n\u001b[0;32m 13\u001b[0m device\u001b[39m=\u001b[39;49mdevice)\n", + "File \u001b[1;32mm:\\Projects\\2022-2023\\Sonified-Latent-Data\\models\\Tybalt\\train.py:90\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(model, dataloader_train, dataloader_val, writer, export_path, learning_rate, epoch_amount, logs_per_epoch, kl_anneal, max_kl, device, verbose)\u001b[0m\n\u001b[0;32m 87\u001b[0m divstep \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[0;32m 89\u001b[0m \u001b[39mwith\u001b[39;00m tqdm(\u001b[39menumerate\u001b[39m(dataloader_train), total\u001b[39m=\u001b[39m\u001b[39mlen\u001b[39m(dataloader_train), desc\u001b[39m=\u001b[39m\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mTraining. Epoch: \u001b[39m\u001b[39m{\u001b[39;00mepoch\u001b[39m}\u001b[39;00m\u001b[39m. 
Loss for step \u001b[39m\u001b[39m{\u001b[39;00mstep\u001b[39m}\u001b[39;00m\u001b[39m: n.v.t.\u001b[39m\u001b[39m\"\u001b[39m, colour\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mmagenta\u001b[39m\u001b[39m'\u001b[39m) \u001b[39mas\u001b[39;00m t:\n\u001b[1;32m---> 90\u001b[0m \u001b[39mfor\u001b[39;00m batch_idx, (x) \u001b[39min\u001b[39;00m t:\n\u001b[0;32m 91\u001b[0m model\u001b[39m.\u001b[39mtrain(\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m 92\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad(set_to_none\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python38\\site-packages\\tqdm\\std.py:1178\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1175\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[0;32m 1177\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m-> 1178\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m iterable:\n\u001b[0;32m 1179\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[0;32m 1180\u001b[0m \u001b[39m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[0;32m 1181\u001b[0m \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python38\\site-packages\\torch\\utils\\data\\dataloader.py:652\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 649\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 650\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 651\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 652\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[0;32m 653\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[0;32m 654\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 655\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 656\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python38\\site-packages\\torch\\utils\\data\\dataloader.py:692\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 690\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m 691\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 692\u001b[0m data \u001b[39m=\u001b[39m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m 693\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[0;32m 694\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python38\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:49\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 47\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[0;32m 48\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[1;32m---> 49\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m 50\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m 51\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python38\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:49\u001b[0m, in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 47\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[0;32m 48\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[1;32m---> 49\u001b[0m data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m 50\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m 51\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[1;32mm:\\Projects\\2022-2023\\Sonified-Latent-Data\\models\\Tybalt\\TybaltData.py:28\u001b[0m, in \u001b[0;36mTybaltDataset.__getitem__\u001b[1;34m(self, idx)\u001b[0m\n\u001b[0;32m 27\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, idx):\n\u001b[1;32m---> 28\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdata[idx]\u001b[39m.\u001b[39;49mtype(torch\u001b[39m.\u001b[39;49mFloatTensor)\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "from models.Tybalt.train import train\n", + "\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "writer = SummaryWriter()\n", + "\n", + "train(model, dataloader_train, dataloader_val, \n", + " writer=writer, \n", + " export_path=export_path,\n", + " learning_rate=learning_rate,\n", + " epoch_amount=epochs,\n", + " logs_per_epoch=1,\n", + " device=device)" + ] } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" }, "orig_nbformat": 4 }, diff --git a/models/Tybalt/TybaltData.py b/models/Tybalt/TybaltData.py index 8b1cca1..b39465c 100644 --- a/models/Tybalt/TybaltData.py +++ b/models/Tybalt/TybaltData.py @@ -22,9 +22,9 @@ def __init__(self, dataframe): print("Loaded data of size: ", self.data.size()) def __len__(self): - return len(self.files) + return self.data.size()[0] def __getitem__(self, idx): - return self.data[idx] + return self.data[idx].type(torch.FloatTensor) diff --git a/models/Tybalt/TybaltVAE.py b/models/Tybalt/TybaltVAE.py index d7c26be..e763fb6 100644 --- a/models/Tybalt/TybaltVAE.py +++ b/models/Tybalt/TybaltVAE.py @@ -72,8 +72,8 @@ def forward(self, x, verbose = False): mean (Tensor): Mean of latent space, shape (B x zsize) var (Tensor): Variance of latent space, shape (B x zsize) """ - mean, log_var = self.encoder(x, verbose) + mean, log_var = self.encoder(x) z = self.sample(mean, log_var) - x_hat = self.decoder(z, verbose) + x_hat = self.decoder(z) return x_hat, mean, log_var \ No newline at end of file diff --git a/models/Tybalt/train.py b/models/Tybalt/train.py new file mode 100644 index 0000000..f35f923 --- /dev/null +++ b/models/Tybalt/train.py @@ -0,0 +1,203 @@ +from tqdm.auto import tqdm +import torch +import os +import argparse + +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from models.Tybalt.TybaltVAE import TybaltVAE +from models.Tybalt.TybaltData import getTybaltDatasets +import warnings +from datetime import datetime + +def calculate_loss(x_hat, x, mu, logvar, kl_term, loss_fn): + reconstruction_loss = loss_fn(x_hat, x) + kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + kl_loss = torch.mean(kl_loss, dim=0) + + return reconstruction_loss + kl_loss * kl_term, reconstruction_loss, kl_loss + +def clear_screen(): + # Clearing the Screen + # posix is os name for Linux or mac + if (os.name == 'posix'): + os.system('clear') + # else screen will be cleared for windows + else: + os.system('cls') + +def anneal_kl(kl_term, kl_annealing, kl_max): + kl_term += kl_annealing + + return max(kl_term, kl_max) + +def export_model(model, path, epoch, name = None): + isExist = os.path.exists(path) + if not isExist: + os.makedirs(path) + modelname = '' + if name is None: + date = datetime.today().strftime('%Y-%m-%d') + modelname = date + str(epoch) + else: + modelname = str(epoch) + path = os.path.join(path, modelname) + torch.save(model.state_dict(), path) + +def validate(model, dataloader, kl_mult, loss_fn, device='cuda', verbose = False): + model.eval() + total_eval_loss = [0, 0, 0] + eval_step = 1 + + with torch.no_grad(): + + with tqdm(enumerate(dataloader), total=len(dataloader), desc="Validating", colour='orange') as t: + for batch_idx, (x) in t: + x = x.to(device) + + x_hat, mean, variance = model(x) + real_loss, rec_loss, kl_loss = calculate_loss( + x_hat, x, mean, variance, kl_mult, loss_fn) + total_eval_loss = [ + total_eval_loss[0] + real_loss.item(), + total_eval_loss[1] + rec_loss.item(), + total_eval_loss[2] + kl_loss.item() + ] + + t.set_description( + f"Validating. 
+
+def train(model, dataloader_train, dataloader_val, writer, export_path, learning_rate=0.00001, epoch_amount=100, logs_per_epoch=5, kl_anneal=0.01, max_kl=0.5, device='cuda', verbose = False):
+    torch.cuda.empty_cache()
+    loss_fn = torch.nn.MSELoss()
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    logstep = 0
+    kl_mult = 0.0
+    total_step = 0
+
+
+    for epoch in range(epoch_amount):
+        model.train(True)
+        total_epoch_loss = [0, 0, 0]
+        step = 1
+        divstep = 0
+
+        with tqdm(enumerate(dataloader_train), total=len(dataloader_train), desc=f"Training. Epoch: {epoch}. Loss for step {step}: n/a.", colour='magenta') as t:
+            for batch_idx, (x) in t:
+                model.train(True)
+                optimizer.zero_grad(set_to_none=True)
+
+                x = x.to(device)
+
+                x_hat, mean, variance = model(x)
+
+                real_loss, rec_loss, kl_loss = calculate_loss(
+                    x_hat, x, mean, variance, kl_mult, loss_fn)
+
+                real_loss.backward()
+                optimizer.step()
+
+                # Track the total, reconstruction and KL losses separately so each part of the objective can be inspected on its own
+                total_epoch_loss = [
+                    total_epoch_loss[0] + real_loss.item(),
+                    total_epoch_loss[1] + rec_loss.item(),
+                    total_epoch_loss[2] + kl_loss.item()
+                ]
+
+                t.set_description(
+                    f"Training. Rec/real loss for step {step}: {round(rec_loss.item(), 2)}/{round(real_loss.item(), 2)}.")
+                writer.add_scalar('Train step loss',
+                                  real_loss.item(), total_step)
+                step += 1
+                total_step += 1
+                divstep += 1
+
+                if step % (len(dataloader_train) // logs_per_epoch) == 0 or step - 1 == 0:
+
+                    eval_loss_real, eval_loss_rec, eval_loss_kl = validate(
+                        model, dataloader_val, kl_mult, loss_fn, device, verbose)
+
+                    writer.add_scalars('Validation loss', {
+                        'Real loss': eval_loss_real,
+                        'Reconstruction loss': eval_loss_rec,
+                        'KL loss': eval_loss_kl
+                    }, logstep)
+
+                    writer.add_scalars('Train loss', {
+                        'Real loss': total_epoch_loss[0] / divstep,
+                        'Reconstruction loss': total_epoch_loss[1] / divstep,
+                        'KL loss': total_epoch_loss[2] / divstep
+                    }, logstep)
+
+                    logstep += 1
+                    total_epoch_loss = [0, 0, 0]
+                    divstep = 0
+                    kl_mult = anneal_kl(kl_mult, kl_anneal, max_kl)
+
+        # if epoch % 5 == 0:
+        export_model(model, export_path, epoch=epoch)
+
+    export_model(model, export_path, epoch='final')
+
+if __name__ == '__main__':
+    warnings.filterwarnings("ignore")
+    clear_screen()
+
+    parser = argparse.ArgumentParser(description = 'Train the model')
+    parser.add_argument('-dp', '--data_path', help='enter the path to the training data', type=str)
+    parser.add_argument('-ep', '--epochs', help='enter the amount of epochs', type=int)
+    parser.add_argument('-ex', '--export_path', help='location to export models', type=str, default = './exports/tybalt/')
+    parser.add_argument('-bs', '--batch_size', help='enter the batch size', type=int, default = 2)
+    parser.add_argument('-lr', '--learning_rate', help='enter the learning rate', type=float, default = 0.00001)
+    parser.add_argument('-kla', '--kl_anneal', help='enter the kl anneal', type=float, default = 0.01)
+    parser.add_argument('-mkl', '--max_kl', help='enter the max kl', type=float, default = 0.5)
+    parser.add_argument('-lpe', '--logs_per_epoch', help='enter the logs per epoch', type=int, default = 6)
+    parser.add_argument('-d', '--device', help='enter the device', type=str, default = 'cuda:2')
+    parser.add_argument('-mf', '--max_files', help='enter the max files', type=int, default = 800)
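+
+    # Example invocation (illustrative only; the data path and device below are the
+    # ones used in the accompanying notebook and may need adjusting to your setup):
+    #   python -m models.Tybalt.train -dp ./tybaltdata/pancan_scaled_zeroone_rnaseq.tsv.gz -ep 100 -bs 512 -d cuda:0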
+
+    args = parser.parse_args()
+
+    batchsize = args.batch_size
+    device = args.device
+
+    isExist = os.path.exists(args.export_path)
+    if not isExist:
+        os.makedirs(args.export_path)
+        print("Export path didn't exist, created directory")
+    else:
+        print("Export path exists")
+
+    input_size = 5000
+    output_size = 5000
+
+    model = TybaltVAE(input_size=input_size, output_size=output_size)
+
+    model.to(device)
+    export_model(model, args.export_path, 0)
+
+    # argparse stores '-dp'/'--data_path' under the long option name, so it is args.data_path
+    dataset_train, dataset_val = getTybaltDatasets(args.data_path)
+
+    dataloader_train = DataLoader(dataset_train,
+                                  batch_size = batchsize,
+                                  shuffle = True)
+
+    dataloader_val = DataLoader(dataset_val,
+                                batch_size = batchsize,
+                                shuffle = False)
+
+    writer = SummaryWriter()
+
+    train(model, dataloader_train, dataloader_val,
+          writer=writer,
+          export_path=args.export_path,
+          learning_rate=args.learning_rate,
+          epoch_amount=args.epochs,
+          logs_per_epoch=args.logs_per_epoch,
+          kl_anneal=args.kl_anneal,
+          max_kl=args.max_kl,
+          device=device)
\ No newline at end of file