From 4da49c35f843e28c6d97fc311fa85efe55c4a14d Mon Sep 17 00:00:00 2001 From: Etienne Tremblay Date: Wed, 17 Feb 2021 15:01:56 -0500 Subject: [PATCH] Implemented MMD vae. --- TabularAE.ipynb | 96 +- TabularVAE.ipynb => TabularVAE-KL.ipynb | 0 TabularVAE-MMD.ipynb | 1938 +++++++++++++++++++++++ 3 files changed, 1980 insertions(+), 54 deletions(-) rename TabularVAE.ipynb => TabularVAE-KL.ipynb (100%) create mode 100644 TabularVAE-MMD.ipynb diff --git a/TabularAE.ipynb b/TabularAE.ipynb index 4b1f509..4287840 100644 --- a/TabularAE.ipynb +++ b/TabularAE.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -17,26 +17,26 @@ "output_type": "stream", "text": [ "Name: fastai\n", - "Version: 2.0.15\n", + "Version: 2.0.8\n", "Summary: fastai simplifies training fast and accurate neural nets using modern best practices\n", - "Home-page: https://github.com/fastai/fastai/tree/master/\n", + "Home-page: https://github.com/fastai/fastai\n", "Author: Jeremy Howard, Sylvain Gugger, and contributors\n", "Author-email: info@fast.ai\n", "License: Apache Software License 2.0\n", - "Location: /usr/local/lib/python3.6/dist-packages\n", - "Requires: torchvision, scikit-learn, packaging, pip, pyyaml, pillow, torch, fastprogress, requests, spacy, scipy, pandas, fastcore, matplotlib\n", + "Location: /mnt/c/work/ml/fastai\n", + "Requires: pip, packaging, fastcore, torchvision, matplotlib, pandas, requests, pyyaml, fastprogress, pillow, scikit-learn, scipy, spacy, pandas, torch\n", "Required-by: \n", "---\n", "Name: fastcore\n", - "Version: 1.0.16\n", + "Version: 1.0.1\n", "Summary: Python supercharged for fastai development\n", - "Home-page: https://github.com/fastai/fastcore/tree/master/\n", + "Home-page: https://github.com/fastai/fastcore\n", "Author: Jeremy Howard and Sylvain Gugger\n", "Author-email: infos@fast.ai\n", "License: Apache Software License 2.0\n", - "Location: /usr/local/lib/python3.6/dist-packages\n", - "Requires: packaging, pip\n", - "Required-by: fastai\n" + "Location: /home/etienne/miniconda3/envs/fastai/lib/python3.8/site-packages\n", + "Requires: packaging, wheel, pip, numpy, dataclasses\n", + "Required-by: nbdev, fastai2, fastai\n" ] } ], @@ -507,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 18, "metadata": { "id": "0moqukV_EdOs" }, @@ -563,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 19, "metadata": { "id": "24aYC4fk4Qev" }, @@ -606,7 +606,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 20, "metadata": { "id": "zbsDGXPX_VUV" }, @@ -617,14 +617,14 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ - "b = dls.one_batch()\n", + "# b = dls.one_batch()\n", "\n", - "out = learn.model(*b[:2])\n", - "loss = loss_func(out, *b[-2:])" + "# out = learn.model(*b[:2])\n", + "# loss = loss_func(out, *b[-2:])" ] }, { @@ -638,7 +638,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 22, "metadata": { "id": "ThV0ch4x7R58" }, @@ -666,7 +666,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 23, "metadata": { "id": "i7xmhROaEdOw" }, @@ -687,7 +687,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -712,57 +712,45 @@ " \n", " \n", " 0\n", - " 8.551754\n", - " 3.598067\n", - 
" 00:00\n", + " 7.849534\n", + " 3.430095\n", + " 00:03\n", " \n", " \n", " 1\n", - " 4.080617\n", - " 1.116776\n", - " 00:00\n", + " 3.788902\n", + " 1.078850\n", + " 00:03\n", " \n", " \n", " 2\n", - " 2.433498\n", - " 0.573103\n", - " 00:00\n", + " 2.278958\n", + " 0.546157\n", + " 00:03\n", " \n", " \n", " 3\n", - " 1.632088\n", - " 0.242973\n", - " 00:00\n", + " 1.537396\n", + " 0.235419\n", + " 00:03\n", " \n", " \n", " 4\n", - " 1.184522\n", - " 0.177599\n", - " 00:00\n", + " 1.130815\n", + " 0.188008\n", + " 00:03\n", " \n", " \n", " 5\n", - " 0.918635\n", - " 0.137939\n", - " 00:00\n", + " 0.889883\n", + " 0.135979\n", + " 00:03\n", " \n", " \n", " 6\n", - " 0.758113\n", - " 0.131483\n", - " 00:00\n", - " \n", - " \n", - " 7\n", - " 0.649586\n", - " 0.102631\n", - " 00:01\n", - " \n", - " \n", - " 8\n", - " 0.585502\n", - " 0.111842\n", - " 00:01\n", + " 0.736168\n", + " 0.142726\n", + " 00:03\n", " \n", " \n", "" @@ -778,7 +766,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "No improvement since epoch 7: early stopping\n" + "No improvement since epoch 5: early stopping\n" ] } ], diff --git a/TabularVAE.ipynb b/TabularVAE-KL.ipynb similarity index 100% rename from TabularVAE.ipynb rename to TabularVAE-KL.ipynb diff --git a/TabularVAE-MMD.ipynb b/TabularVAE-MMD.ipynb new file mode 100644 index 0000000..76fdd17 --- /dev/null +++ b/TabularVAE-MMD.ipynb @@ -0,0 +1,1938 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 390 + }, + "id": "I4PczoTMNXi3", + "outputId": "c9842093-2082-44ee-fac6-c191946c1e77" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Name: fastai\n", + "Version: 2.0.13\n", + "Summary: fastai simplifies training fast and accurate neural nets using modern best practices\n", + "Home-page: https://github.com/fastai/fastai/tree/master/\n", + "Author: Jeremy Howard, Sylvain Gugger, and contributors\n", + "Author-email: info@fast.ai\n", + "License: Apache Software License 2.0\n", + "Location: c:\\users\\etienne-pc\\miniconda3\\envs\\fastai\\lib\\site-packages\n", + "Requires: torch, pillow, packaging, pyyaml, requests, scipy, fastprogress, torchvision, matplotlib, fastcore, scikit-learn, pandas, pip, spacy\n", + "Required-by: \n", + "---\n", + "Name: fastcore\n", + "Version: 1.0.13\n", + "Summary: Python supercharged for fastai development\n", + "Home-page: https://github.com/fastai/fastcore/tree/master/\n", + "Author: Jeremy Howard and Sylvain Gugger\n", + "Author-email: infos@fast.ai\n", + "License: Apache Software License 2.0\n", + "Location: c:\\users\\etienne-pc\\miniconda3\\envs\\fastai\\lib\\site-packages\n", + "Requires: pip, packaging\n", + "Required-by: tsai, fastai2, fastai\n" + ] + } + ], + "source": [ + "!pip show fastai fastcore" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "pS3C4jJLEdOV" + }, + "outputs": [], + "source": [ + "from matplotlib import cm\n", + "from fastai.tabular.all import *\n", + "\n", + "pd.set_option('display.float_format', lambda x: '%.3f' % x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pu3gjQ7CNMSv" + }, + "source": [ + "We'll use the `Adult Sample` dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "UmzTzqPTEttQ" + }, + "outputs": [], + "source": [ + "path = untar_data(URLs.ADULT_SAMPLE)\n", + "df = pd.read_csv(path/'adult.csv')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"id": "8bCOcb3DNSVd" + }, + "source": [ + "And declare the relevent information:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "2YqxE0AVEdOX" + }, + "outputs": [], + "source": [ + "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", + "cont_names = ['age', 'fnlwgt', 'education-num']\n", + "procs = [Categorify, FillMissing, Normalize]\n", + "y_names = 'salary'\n", + "y_block = CategoryBlock()\n", + "splits = RandomSplitter()(range_of(df))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Baseline" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "M1MWujHOKwBf" + }, + "outputs": [], + "source": [ + "# to = TabularPandas(df, procs = [Categorify, FillMissing, Normalize], cont_names=cont_names, cat_names=cat_names, splits=splits, y_names=['salary'], reduce_memory=False, \n", + "# y_block=CategoryBlock())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "G4AKuKfbLgwQ" + }, + "outputs": [], + "source": [ + "# dls = to.dataloaders(bs=1024)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "NoYN1wi3NaYE" + }, + "outputs": [], + "source": [ + "# def accuracy(inp, targ, axis=-1):\n", + "# \"Compute accuracy with `targ` when `pred` is bs * n_classes\"\n", + "# pred,targ = flatten_check(inp.argmax(dim=axis), targ)\n", + "# return (pred == targ).float().mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "lhOMRBYCLAZy" + }, + "outputs": [], + "source": [ + "# learn = tabular_learner(dls, layers=[200,100], config={'ps':.1}, metrics=[accuracy])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 195 + }, + "id": "UuvlsOh4MlN_", + "outputId": "64b00442-2801-4bfe-a9ca-d823c3140590" + }, + "outputs": [], + "source": [ + "# learn.fit(10, 1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# VAE AutoEncoder" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fUIkMhCENnLd" + }, + "source": [ + "Next we need our own version of `ReadTabBatch` that will return our inputs\n", + "\n", + "> The continous variables are still normalized if we used `Normalize`. 
Couldn't figure out an easy way to de-norm it, but it's okay that we do not" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "08quVRZuEdOc" + }, + "outputs": [], + "source": [ + "class ReadTabBatchIdentity(ItemTransform):\n", + " \"Read a batch of data and return the inputs as both `x` and `y`\"\n", + " def __init__(self, to): store_attr()\n", + "\n", + " def encodes(self, to):\n", + " if not to.with_cont: res = (tensor(to.cats).long(),) + (tensor(to.cats).long(),)\n", + " else: res = (tensor(to.cats).long(),tensor(to.conts).float()) + (tensor(to.cats).long(), tensor(to.conts).float())\n", + " if to.device is not None: res = to_device(res, to.device)\n", + " return res\n", + " \n", + "class TabularPandasIdentity(TabularPandas): pass" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "itUJ_YaoN6uT" + }, + "source": [ + "Next we need to make a new `TabDataLoader` that uses our `RadTabBatchIdentity`:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "C0YUhcd3EdOe" + }, + "outputs": [], + "source": [ + "@delegates()\n", + "class TabDataLoaderIdentity(TabDataLoader):\n", + " \"A transformed `DataLoader` for AutoEncoder problems with Tabular data\"\n", + " do_item = noops\n", + " def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):\n", + " if after_batch is None: after_batch = L(TransformBlock().batch_tfms)+ReadTabBatchIdentity(dataset)\n", + " super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch, num_workers=num_workers, **kwargs)\n", + "\n", + " def create_batch(self, b): return self.dataset.iloc[b]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PX9Vv9OPOAVl" + }, + "source": [ + "And make `TabularPandasIdentity`'s `dl_type` to `TabDataLoaderIdentity`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "q5EEsDr-1S5_" + }, + "outputs": [], + "source": [ + "TabularPandasIdentity._dl_type = TabDataLoaderIdentity" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tXmzs8Hx1WQd" + }, + "source": [ + "To start we'll make a very basic `to` object using our new `TabularPandasIdentity`:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "4sWPDJNy1aK1" + }, + "outputs": [], + "source": [ + "bs=1024\n", + "\n", + "to = TabularPandasIdentity(df, [Categorify, FillMissing, Normalize], cat_names, cont_names, splits=RandomSplitter(seed=32)(df))\n", + "dls = to.dataloaders(bs=1024)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-kcUFTYoOGO0" + }, + "source": [ + "Set the `n_inp` to 2:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "uQDN7OpH1b9B" + }, + "outputs": [], + "source": [ + "dls.n_inp = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ojWg3NDI1iXx" + }, + "source": [ + "And then we'll calculate the embedding sizes:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "NERxSEdK1jyv" + }, + "outputs": [], + "source": [ + "emb_szs = get_emb_sz(to.train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L5neKs1q1kuF" + }, + "source": [ + "For each categorical variable we need to know the total possible values it can have:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 134 + }, + "id": "Uv7-oVlPEdOg", + "outputId": 
"bdfd5add-2f30-423a-aedc-42d564110157" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'workclass': 10,\n", + " 'education': 17,\n", + " 'marital-status': 8,\n", + " 'occupation': 16,\n", + " 'relationship': 7,\n", + " 'race': 6,\n", + " 'education-num_na': 3}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_cats = {k:len(v) for k,v in to.classes.items()}\n", + "total_cats" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tKNiQtro13Dk" + }, + "source": [ + "We will need this dictionary in our loss function to figure out where to apply our `CrossEntropyLossFlat` for each categorical variables" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mpyLIJsO2DwA" + }, + "source": [ + "Next we need to know the total number ouf outputs possible for our categorical variables" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "id": "u56gAbBYEdOi", + "outputId": "e8a87efe-6dc2-48c8-87b8-15a2c088922d" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "67" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum([v for k,v in total_cats.items()])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2BjAP1OVOYsQ" + }, + "source": [ + "And let's keep a batch of our data for later" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "RAqcnZI2EdOl" + }, + "outputs": [], + "source": [ + "batch = dls.one_batch()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t5c7Vi-f2aLc" + }, + "source": [ + "Next we need to know the means and standard deviations:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 67 + }, + "id": "rvbLzGcV2cDP", + "outputId": "f17df585-8136-4869-ec7c-d263e3052d26" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'age': 38.5793696495067,\n", + " 'fnlwgt': 190006.02011593536,\n", + " 'education-num': 10.079158782958984}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "to.means" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QhHUG2Sk2p42" + }, + "source": [ + "We can store them in a `DataFrame` for easy adjustments:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "ttuLY8iWEdOn" + }, + "outputs": [], + "source": [ + "means = pd.DataFrame.from_dict({k:[v] for k,v in to.means.items()})\n", + "stds = pd.DataFrame.from_dict({k:[v] for k,v in to.stds.items()})" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k5rHaP4N2yIl" + }, + "source": [ + "We'll also use a SigmoidRange based on the un-normalized data to reduce the range our values can be:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "id": "lqoao9U5EdOo" + }, + "outputs": [], + "source": [ + "low = (df[cont_names].min().to_frame().T.values - means.values) / stds.values\n", + "high = (df[cont_names].max().to_frame().T.values - means.values) / stds.values" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 50 + }, + "id": "2DcGJCze3Kqa", + "outputId": "77bf0d20-aef1-426b-eeda-7c969c73cac7" + }, + "outputs": [ + { + "data": { + "text/plain": [ + 
"(array([[-1.57952443, -1.67843578, -3.55622464]]),\n", + " array([[ 3.76378659, 12.22741736, 2.31914013]]))" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "low, high" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hu8pZ9oSEdOq" + }, + "source": [ + "## Batch Swap Noise\n", + "Used in the winning solution for the Kaggle competition [Puerto Seguro Safe Driver Prediction](https://www.kaggle.com/c/porto-seguro-safe-driver-prediction/discussion/44629#250927)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "NZb6cetaEdOr" + }, + "outputs": [], + "source": [ + "class BatchSwapNoise(Module):\n", + " \"Swap Noise Module\"\n", + " def __init__(self, p): store_attr()\n", + "\n", + " def forward(self, x):\n", + " if self.training:\n", + " mask = torch.rand(x.size()) > (1 - self.p)\n", + " l1 = torch.floor(torch.rand(x.size()) * x.size(0)).type(torch.LongTensor)\n", + " l2 = (mask.type(torch.LongTensor) * x.size(1))\n", + " res = (l1 * l2).view(-1)\n", + " idx = torch.arange(x.nelement()) + res\n", + " idx[idx>=x.nelement()] = idx[idx>=x.nelement()]-x.nelement()\n", + " return x.flatten()[idx].view(x.size())\n", + " else:\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QjduDB3JOpfi" + }, + "source": [ + "We'll make a custom `TabularVAE` model (Denoising Variational AutoEncoder) for us to use." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "0moqukV_EdOs" + }, + "outputs": [], + "source": [ + "class TabularVAE(TabularModel):\n", + " def __init__(self, emb_szs, n_cont, hidden_size, cats, low, high, ps=0.2, embed_p=0.01, bswap=None, act_cls=Swish()):\n", + " super().__init__(emb_szs, n_cont, layers=[hidden_size*8, hidden_size*4, hidden_size*2], out_sz=hidden_size, embed_p=embed_p, act_cls=act_cls)\n", + " \n", + " self.bswap = bswap\n", + " self.cats = cats\n", + " self.activation_cats = sum([v for k,v in cats.items()])\n", + " \n", + " self.layers = nn.Sequential(*L(self.layers.children())[:-1] + nn.Sequential(LinBnDrop(hidden_size*2, hidden_size, p=ps, act=act_cls)))\n", + " \n", + " if self.bswap != None: self.noise = BatchSwapNoise(self.bswap)\n", + " self.decoder = nn.Sequential(\n", + " LinBnDrop(hidden_size, hidden_size*2, p=ps, act=act_cls),\n", + " LinBnDrop(hidden_size*2, hidden_size*4, p=ps, act=act_cls),\n", + " LinBnDrop(hidden_size*4, hidden_size*8, p=ps, act=act_cls)\n", + " )\n", + " \n", + " self.decoder_cont = nn.Sequential(\n", + " LinBnDrop(hidden_size*8, n_cont, p=ps, bn=False, act=None),\n", + " SigmoidRange(low=low, high=high)\n", + " )\n", + " \n", + " self.decoder_cat = LinBnDrop(hidden_size*8, self.activation_cats, p=ps, bn=False, act=None)\n", + " \n", + " def forward(self, x_cat, x_cont=None, encode=False):\n", + " if(self.bswap != None):\n", + " x_cat = self.noise(x_cat)\n", + " x_cont = self.noise(x_cont)\n", + " \n", + " z = super().forward(x_cat, x_cont)\n", + " if(encode): return z\n", + " \n", + " decoded_trunk = self.decoder(z)\n", + " \n", + " decoded_cats = self.decoder_cat(decoded_trunk)\n", + " \n", + " decoded_conts = self.decoder_cont(decoded_trunk)\n", + " \n", + " return decoded_cats, decoded_conts, z" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vO2gBJNoO3hP" + }, + "source": [ + "We'll also need a loss function that can grade how well our features represent the original dataset. 
\n",
+    "\n",
+    "The categorical features will be graded with `CrossEntropyLossFlat` and the continuous ones with `MSELossFlat`.\n",
+    "\n",
+    "Since this is a variational autoencoder, we also need a term that keeps the latent codes close to the prior. Instead of the analytic KL divergence, this notebook uses the Maximum Mean Discrepancy (MMD) between the encodings `z` and samples drawn from a unit Gaussian, as in [InfoVAE](https://arxiv.org/abs/1706.02262). The loss below still unpacks an optional `kl_weight` (a parameter an annealing callback can drive from 0 to 1, a trick suggested in [Ladder Variational AutoEncoder](https://arxiv.org/abs/1602.02282) and also used in the [NVAE](https://arxiv.org/abs/2007.03898) paper), but the MMD version does not use it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_kernel(x, y):\n",
+    "    # Gaussian kernel between every pair of rows in x and y\n",
+    "    x_size = x.shape[0]\n",
+    "    y_size = y.shape[0]\n",
+    "    dim = x.shape[1]\n",
+    "\n",
+    "    tiled_x = x.view(x_size,1,dim).repeat(1, y_size,1)\n",
+    "    tiled_y = y.view(1,y_size,dim).repeat(x_size, 1,1)\n",
+    "\n",
+    "    return torch.exp(-torch.mean((tiled_x - tiled_y)**2,dim=2)/dim*1.0)\n",
+    "\n",
+    "\n",
+    "def compute_mmd(x, y):\n",
+    "    # MMD estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]\n",
+    "    x_kernel = compute_kernel(x, x)\n",
+    "    y_kernel = compute_kernel(y, y)\n",
+    "    xy_kernel = compute_kernel(x, y)\n",
+    "    return torch.mean(x_kernel) + torch.mean(y_kernel) - 2*torch.mean(xy_kernel)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = {\n",
+    "    'hidden_size': 128,\n",
+    "    'dropout': 0.0,\n",
+    "    'embed_p': 0.0,\n",
+    "    'wd': 0.01,\n",
+    "    'bswap': 0.2,\n",
+    "    'lr': 1e-3,\n",
+    "    'epochs': 50\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "id": "24aYC4fk4Qev"
+   },
+   "outputs": [],
+   "source": [
+    "class VAERecreatedLoss(Module):\n",
+    "    \"Measures how well we have recreated the original tabular inputs, plus the MMD between the encodings and a unit normal distribution\"\n",
+    "    def __init__(self, cat_dict, dataset_size, bs, hidden_size, mmd_weight = 1000, reduction='mean'):\n",
+    "        ce = CrossEntropyLossFlat(reduction='none')\n",
+    "        mse = MSELossFlat(reduction='none')\n",
+    "        store_attr('cat_dict,ce,mse,dataset_size,bs,hidden_size,mmd_weight,reduction')\n",
+    "    \n",
+    "    def forward(self, preds, cat_targs, cont_targs):\n",
+    "        if(len(preds) == 4):\n",
+    "            cats,conts, z, kl_weight = preds\n",
+    "        else:\n",
+    "            cats,conts, z = preds\n",
+    "            kl_weight = 1\n",
+    "        \n",
+    "        true_samples = torch.randn((cats.shape[0],self.hidden_size))\n",
+    "        true_samples = nn.Parameter(true_samples).cuda()\n",
+    "\n",
+    "        tot_ce, pos = [], 0\n",
+    "        for i, (k,v) in enumerate(self.cat_dict.items()):\n",
+    "            tot_ce += [self.ce(cats[:, pos:pos+v], cat_targs[:,i])]\n",
+    "            pos += v\n",
+    "\n",
+    "        tot_ce = torch.stack(tot_ce, dim=1).mean(dim=1)\n",
+    "        cont_loss = self.mse(conts, cont_targs).view(conts.shape).mean(dim=1)\n",
+    "        recons_loss = (tot_ce + cont_loss)\n",
+    "        \n",
+    "        mmd_loss = compute_mmd(true_samples, z).repeat(cats.shape[0])\n",
+    "        \n",
+    "        total_loss = recons_loss + (mmd_loss * self.mmd_weight)\n",
+    "        \n",
+    "        if self.reduction == 'mean':\n",
+    "            return total_loss.mean()\n",
+    "        elif self.reduction == 'sum':\n",
+    "            return total_loss.sum()\n",
+    "\n",
+    "        return total_loss"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "RV7j8WoePBJO"
+   },
+   "source": [
+    "All we need to do is pass in our `total_cats` dictionary:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {
+    "id": "zbsDGXPX_VUV"
+   },
+   "outputs": [],
+   "source": [
+    "loss_func = VAERecreatedLoss(total_cats, df.shape[0], bs, config['hidden_size'], 1000)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's create some metrics for the quantities we care about while fitting the model. We have reconstruction metrics (MSE and cross-entropy), but we also have to keep an eye on the MMD term.\n",
+    "These metrics will show us whether the loss is dominated by the MMD penalty or by the reconstruction loss from MSE and cross-entropy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MSEMetric(Metric):\n",
+    "    def __init__(self): self.preds = []\n",
+    "    def accumulate(self, learn):\n",
+    "        cats, conts, z = learn.pred\n",
+    "        cat_targs, cont_targs = learn.y\n",
+    "        norm_conts = conts.new([conts.size(1)])\n",
+    "        self.preds.append(to_detach(F.mse_loss(conts, cont_targs, reduction='sum') / norm_conts))\n",
+    "    @property\n",
+    "    def value(self):\n",
+    "        return np.array(self.preds).mean()\n",
+    "\n",
+    "class CEMetric(Metric):\n",
+    "    def __init__(self): self.preds = []\n",
+    "    def accumulate(self, learn):\n",
+    "        cats, conts, z = learn.pred\n",
+    "        cat_targs, cont_targs = learn.y\n",
+    "        CE = cats.new([0])\n",
+    "        pos=0\n",
+    "        for i, (k,v) in enumerate(total_cats.items()):\n",
+    "            CE += F.cross_entropy(cats[:, pos:pos+v], cat_targs[:, i], reduction='sum')\n",
+    "            pos += v\n",
+    "\n",
+    "        norm = cats.new([len(total_cats.keys())])\n",
+    "        self.preds.append(to_detach(CE/norm))\n",
+    "    @property\n",
+    "    def value(self):\n",
+    "        return np.array(self.preds).mean()\n",
+    "\n",
+    "class MMDMetric(Metric):\n",
+    "    def __init__(self): self.preds = []\n",
+    "    def accumulate(self, learn):\n",
+    "        cats, conts, z = learn.pred\n",
+    "        true_samples = torch.randn((bs,config['hidden_size']))\n",
+    "        true_samples = nn.Parameter(true_samples).cuda()\n",
+    "        MMD = compute_mmd(true_samples, z)\n",
+    "        self.preds.append(to_detach(MMD))\n",
+    "    @property\n",
+    "    def value(self):\n",
+    "        return np.array(self.preds).mean()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "WrhU2fLP7N-U"
+   },
+   "source": [
+    "We made a `config` dictionary above as a single place for all the hyperparameters. Also, I would recommend against using early stopping when annealing the divergence weight, because an annealed loss callback will make the loss get worse once that weight becomes larger than 0."
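+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before wiring everything up, a quick sanity check of `compute_mmd` (just a sketch): two batches drawn from the same unit Gaussian should give an MMD close to zero, while a shifted batch should give a clearly larger value."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: MMD between samples of the same distribution should be ~0;\n",
+    "# against a shifted distribution it should be clearly positive.\n",
+    "x = torch.randn(256, config['hidden_size'])\n",
+    "y = torch.randn(256, config['hidden_size'])\n",
+    "shifted = torch.randn(256, config['hidden_size']) + 3\n",
+    "compute_mmd(x, y), compute_mmd(x, shifted)"
+   ]
+  },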
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {
+    "id": "ThV0ch4x7R58"
+   },
+   "outputs": [],
+   "source": [
+    "cbs = []\n",
+    "cbs += [EarlyStoppingCallback(patience=5)]\n",
+    "metrics = []\n",
+    "metrics += [MSEMetric(), CEMetric(), MMDMetric()]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "PYgESCPk7TeP"
+   },
+   "source": [
+    "And make our model & `Learner`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {
+    "id": "i7xmhROaEdOw"
+   },
+   "outputs": [],
+   "source": [
+    "model = TabularVAE(emb_szs, len(cont_names), config['hidden_size'], ps=config['dropout'], cats=total_cats, embed_p=config['embed_p'], bswap=config['bswap'], low=tensor(low).cuda(), high=tensor(high).cuda())\n",
+    "learn = Learner(dls, model, lr=config['lr'], loss_func=loss_func, wd=config['wd'], opt_func=ranger, cbs=cbs, metrics=metrics).to_fp16()"
+   ]
+  },
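+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optionally, we can sanity-check the model on the batch we kept earlier (a minimal sketch): with 67 total categorical levels, 3 continuous columns and `hidden_size=128`, we expect outputs of shape `bs x 67`, `bs x 3` and a latent of `bs x 128`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: forward the stored batch and confirm the three output shapes.\n",
+    "# eval() also disables the BatchSwapNoise so the check is deterministic.\n",
+    "learn.model.eval()\n",
+    "learn.model.cuda()\n",
+    "with torch.no_grad():\n",
+    "    cats_out, conts_out, z_out = learn.model(*batch[:2])\n",
+    "cats_out.shape, conts_out.shape, z_out.shape"
+   ]
+  },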
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "J76BTkMP7Ws2"
+   },
+   "source": [
+    "Finally we'll fit for a few epochs:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 271
+    },
+    "id": "dcFlPHLAGBZ4",
+    "outputId": "07d78f9a-b513-4b89-fc0c-cbfcc02db820"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "epoch  train_loss  valid_loss  mse        ce          mmd             time\n",
+       "0      10.168085   9.563256    6804.4585  1954.3057   0.00014841557   00:07\n",
+       "1      5.327938    1.869316    3445.039   1688.6045   0.0001964484    00:07\n",
+       "2      3.342218    1.108065    2326.7554  1385.125    0.0001870337    00:07\n",
+       "3      2.324410    0.570403    1761.3877  1132.9485   0.00016434278   00:07\n",
+       "4      1.752037    0.440975    1424.0205  956.9856    0.0001496315    00:07\n",
+       "5      1.410149    0.371801    1197.2649  831.2296    0.00013893843   00:08\n",
+       "6      1.197507    0.338297    1037.1759  735.0688    0.00013073366   00:07\n",
+       "7      1.055543    0.297053    915.67053  660.0755    0.00012448643   00:07\n",
+       "8      0.966456    0.282033    821.8686   599.54987   0.000119693694  00:08\n",
+       "9      0.906478    0.287728    748.1394   550.2091    0.00011590719   00:08\n",
+       "10     0.868901    0.275358    687.9604   508.73108   0.00011265588   00:08\n",
+       "11     0.834989    0.233405    635.36395  473.71042   0.00010971086   00:08\n",
+       "12     0.811906    0.228882    591.0481   443.6053    0.00010709186   00:08\n",
+       "13     0.796748    0.225550    552.8387   417.78253   0.0001048616    00:08\n",
+       "14     0.786637    0.215709    519.25964  395.3662    0.000102911676  00:08\n",
+       "15     0.778344    0.238432    491.4604   375.39908   0.00010123104   00:08\n",
+       "16     0.767084    0.232167    466.42215  358.00217   9.9715064e-05   00:07\n",
+       "17     0.761432    0.211073    443.5655   342.14377   9.828711e-05    00:07\n",
+       "18     0.757635    0.209919    422.91708  327.93478   9.702023e-05    00:08\n",
+       "19     0.755053    0.220945    404.8146   315.23907   9.5819574e-05   00:07\n",
+       "20     0.749676    0.222952    388.4365   303.88947   9.477868e-05    00:08\n",
+       "21     0.746787    0.223583    373.76233  293.40634   9.374572e-05    00:08\n",
+       "22     0.743745    0.228290    360.6602   283.7516    9.291586e-05    00:08\n",
+       "23     0.740826    0.197392    347.68954  274.7137    9.199551e-05    00:08\n",
+       "24     0.736706    0.212763    336.3346   266.3691    9.1233254e-05   00:08\n",
+       "25     0.733426    0.226913    326.3569   258.67038   9.0506706e-05   00:08\n",
+       "26     0.731497    0.202380    316.22864  251.58751   8.981569e-05    00:08\n",
+       "27     0.725566    0.214096    307.1354   245.08586   8.916429e-05    00:08\n",
+       "28     0.722705    0.218085    298.9792   238.93744   8.846856e-05    00:08\n
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No improvement since epoch 23: early stopping\n" + ] + } + ], + "source": [ + "learn.fit_flat_cos(config['epochs'], lr=0.0014)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k6NIV2T7EdO2" + }, + "source": [ + "# Getting the compressed representations" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9E9LxZUPPJWX" + }, + "source": [ + "Next we're going to grade our compressed representations and then attempt to train on them." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "id": "bnnQC-jbEdO2" + }, + "outputs": [], + "source": [ + "dl = learn.dls.test_dl(df)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l5vf0sGyPTlS" + }, + "source": [ + "Let's predict over all the data manually using PyTorch:" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "id": "LG7l2jCVPV5k" + }, + "outputs": [], + "source": [ + "outs = []\n", + "for batch in dl:\n", + " with torch.no_grad():\n", + " learn.model.eval()\n", + " learn.model.cuda()\n", + " out = learn.model(*batch[:2], True).cpu().numpy()\n", + " outs.append(out)\n", + "outs = np.concatenate(outs)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "id": "yk4-bKfAPYxR", + "outputId": "8b4db2ca-feee-44ee-88be-be8d3914c458" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(32561, 128)" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outs.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YIjJ0MOjP8ge" + }, + "source": [ + "As well as get the actual preds and targs:" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "_sLsBpFfQADo", + "outputId": "5a65abdb-6edc-4a94-f4b8-71ae3a9c3de7" + }, + "outputs": [ + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "(cat_preds, cont_preds, z), (cat_targs, cont_targs) = learn.get_preds(dl=dl, reorder=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ykgzRNmcEdO-" + }, + "source": [ + "# Measuring accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SsEY3yF_EdO-" + }, + "source": [ + "## Continuous" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 195 + }, + "id": "fejwMQRbEdO_", + "outputId": "ce6d5ba6-17a6-4028-9fff-4bbb2f1911a5" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
(HTML rendering omitted; identical to the text/plain table below)\n
" + ], + "text/plain": [ + " GroupBy age fnlwgt education-num\n", + "0 Min 0.000 0.461 0.000\n", + "0 Max 51.168 463768.594 3.768\n", + "0 Mean 2.555 27842.540 0.254\n", + "0 Median 1.815 21517.459 0.196\n", + "0 R2 0.929 0.870 0.982" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics import r2_score\n", + "\n", + "cont_preds = pd.DataFrame(cont_preds, columns=cont_names)\n", + "cont_targs = pd.DataFrame(cont_targs, columns=cont_names)\n", + "\n", + "preds = pd.DataFrame((cont_preds.values * stds.values) + means.values, columns=cont_preds.columns)\n", + "targets = pd.DataFrame((cont_targs.values * stds.values) + means.values, columns=cont_targs.columns)\n", + "\n", + "mi = (np.abs(targets-preds)).min().to_frame().T\n", + "ma = (np.abs(targets-preds)).max().to_frame().T\n", + "mean = (np.abs(targets-preds)).mean().to_frame().T\n", + "median = (np.abs(targets-preds)).median().to_frame().T\n", + "r2 = pd.DataFrame.from_dict({c:[r2_score(targets[c], preds[c])] for c in preds.columns})\n", + "\n", + "\n", + "for d,name in zip([mi,ma,mean,median,r2], ['Min', 'Max', 'Mean', 'Median', 'R2']):\n", + " d = d.insert(0, 'GroupBy', name)\n", + " \n", + "data = pd.concat([mi,ma,mean,median,r2])\n", + "data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xgg2KiaTQIre" + }, + "source": [ + "We can also grab the R2:" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 50 + }, + "id": "nUVqCx3YEdPB", + "outputId": "25dcd1b9-7142-4b09-988f-16bba3ea2e6f" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0 0.927\n", + "dtype: float64" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r2.mean(axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fXB3fiBrEdPC" + }, + "source": [ + "## Categorical" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "id": "nOEzxfpREdPF" + }, + "outputs": [], + "source": [ + "cat_reduced = torch.zeros_like(cat_targs)\n", + "pos=0\n", + "for i, (k,v) in enumerate(total_cats.items()):\n", + " cat_reduced[:,i] = cat_preds[:,pos:pos+v].argmax(dim=1)\n", + " pos += v" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "id": "zlH1IcACEdPH" + }, + "outputs": [], + "source": [ + "cat_preds = pd.DataFrame(cat_reduced, columns=cat_names)\n", + "cat_targs = pd.DataFrame(cat_targs, columns=cat_names)\n", + "\n", + "from sklearn.metrics import balanced_accuracy_score, f1_score\n", + "\n", + "accuracy = pd.DataFrame.from_dict({c:[balanced_accuracy_score(cat_targs[c], cat_preds[c])] for c in cat_preds.columns})" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "id": "ZMqj8zcKEdPJ" + }, + "outputs": [], + "source": [ + "f1 = pd.DataFrame.from_dict({c:[f1_score(cat_targs[c], cat_preds[c], average='weighted')] for c in cat_preds.columns})" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 143 + }, + "id": "vuBk02t5EdPL", + "outputId": "432517e9-1e98-44c8-e82b-00dba4c73947" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
(HTML rendering omitted; identical to the text/plain table below)\n
" + ], + "text/plain": [ + " MetricName workclass education marital-status occupation relationship \\\n", + "0 Accuracy 0.816 0.997 0.873 0.988 0.984 \n", + "0 F1 0.992 0.999 0.995 0.995 0.993 \n", + "\n", + " race education-num_na \n", + "0 0.960 0.963 \n", + "0 0.995 0.999 " + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for d,name in zip([accuracy, f1], ['Accuracy', 'F1']):\n", + " d = d.insert(0, 'MetricName', name)\n", + "pd.concat([accuracy, f1])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3GcWeUSqQXIG" + }, + "source": [ + "And check it's overall accuracy:" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 50 + }, + "id": "2rOtbaScEdPN", + "outputId": "c685c93a-9a39-4d72-fa6c-f42fae03aa63" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0 0.940\n", + "dtype: float64" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy.mean(axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u9K317948SvR" + }, + "source": [ + "## Predicting\n", + "\n", + "Now that we have our compressed representations, let's use them to train a new model" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "id": "HAM08-DsJsjh" + }, + "outputs": [], + "source": [ + "ys = df['salary'].to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "id": "RRbccF9zKBV5" + }, + "outputs": [], + "source": [ + "test_eq(len(outs), len(ys))" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "id": "OX0IPAG8KhbL" + }, + "outputs": [], + "source": [ + "df_outs = pd.DataFrame(columns=['salary'] + list(range(0,config['hidden_size'])))" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "id": "9KjjmrbfKqIN" + }, + "outputs": [], + "source": [ + "df_outs['salary'] = ys" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "id": "gfO5CknNKtXl" + }, + "outputs": [], + "source": [ + "df_outs[list(range(0,config['hidden_size']))] = outs" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "id": "v6mUmbz5Lkqb" + }, + "outputs": [], + "source": [ + "pd.options.mode.chained_assignment=None" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "id": "l7lAP34pMAvN" + }, + "outputs": [], + "source": [ + "splits = RandomSplitter()(range_of(df))" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "id": "FGezIL33MVXx" + }, + "outputs": [], + "source": [ + "df_outs[list(range(0,config['hidden_size']))] = df_outs[list(range(0,config['hidden_size']))].astype(np.float16)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "id": "M1MWujHOKwBf" + }, + "outputs": [], + "source": [ + "cont_names_ = list(range(0,config['hidden_size']))\n", + "to2 = TabularPandas(df_outs, procs = [Normalize], cont_names=cont_names_, splits=splits, y_names=['salary'], reduce_memory=False, \n", + " y_block=CategoryBlock())" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": { + "id": "G4AKuKfbLgwQ" + }, + "outputs": [], + "source": [ + "dls2 = to2.dataloaders(bs=1024)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": { + "id": "NoYN1wi3NaYE" + }, + "outputs": [], + "source": [ + "def accuracy(inp, targ, 
axis=-1):\n",
+    "    \"Compute accuracy with `targ` when `pred` is bs * n_classes\"\n",
+    "    pred,targ = flatten_check(inp.argmax(dim=axis), targ)\n",
+    "    return (pred == targ).float().mean()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 58,
+   "metadata": {
+    "id": "lhOMRBYCLAZy"
+   },
+   "outputs": [],
+   "source": [
+    "learn2 = tabular_learner(dls2, layers=[200,100], config={'ps':0.05}, metrics=[accuracy])"
+   ]
+  },
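+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick look at a batch from `dls2` (a sketch): with no categorical columns, the continuous block should be the `bs x 128` latent features."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: the batch should carry the 128 latent features as continuous inputs.\n",
+    "b2 = dls2.one_batch()\n",
+    "[t.shape for t in b2]"
+   ]
+  },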
+  {
+   "cell_type": "code",
+   "execution_count": 59,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 195
+    },
+    "id": "UuvlsOh4MlN_",
+    "outputId": "64b00442-2801-4bfe-a9ca-d823c3140590"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "epoch  train_loss  valid_loss  accuracy  time\n",
+       "0      0.394843    0.351318    0.831849  00:00\n",
+       "1      0.369111    0.346518    0.835381  00:00\n",
+       "2      0.358863    0.344359    0.839988  00:00\n",
+       "3      0.352922    0.345499    0.841063  00:00\n",
+       "4      0.348409    0.345977    0.837224  00:00\n",
+       "5      0.344794    0.344040    0.840141  00:00\n",
+       "6      0.342773    0.344962    0.839373  00:00\n",
+       "7      0.340786    0.344012    0.839834  00:00\n",
+       "8      0.339176    0.347567    0.838298  00:00\n",
+       "9      0.337396    0.348554    0.839220  00:00\n
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "learn2.fit(10, 1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "TabularAE.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}