Skip to content

Commit

Permalink
Created tybalt playground notebook. Minor TybaltVAE fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
norberz committed May 16, 2023
1 parent 81c52b5 commit c20681d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
72 changes: 72 additions & 0 deletions TybaltPlayground.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train Tybalt VAE"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from models.Tybalt.TybaltVAE import TybaltVAE\n",
"\n",
"batchsize = 128\n",
"input_size = 5000\n",
"output_size = 5000\n",
"\n",
"model = TybaltVAE(input_size=input_size, output_size=output_size)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from models.Tybalt.TybaltData import getTybaltDatasets\n",
"from torch.utils.data import DataLoader\n",
"\n",
"data_path = './tybaltdata/pancan_scaled_zeroone_rnaseq.tsv.gz'\n",
"dataset_train, dataset_val = getTybaltDatasets()\n",
"\n",
"dataloader_train = DataLoader(dataset_train,\n",
" batch_size = batchsize,\n",
" shuffle = True)\n",
"\n",
"dataloader_val = DataLoader(dataset_val,\n",
" batch_size = batchsize,\n",
" shuffle = False)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion models/Tybalt/TybaltVAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, output_size, zsize):
self.decode = nn.Sequential(
nn.Linear(zsize, 1000),
nn.Sigmoid(),
nn.Linear(output_size),
nn.Linear(1000, output_size),
nn.Sigmoid()
)

Expand Down

0 comments on commit c20681d

Please sign in to comment.