Skip to content

Commit

Permalink
Implemented MMD vae.
Browse files Browse the repository at this point in the history
  • Loading branch information
EtienneT committed Feb 17, 2021
1 parent a34ccff commit 4da49c3
Show file tree
Hide file tree
Showing 3 changed files with 1,980 additions and 54 deletions.
96 changes: 42 additions & 54 deletions TabularAE.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -17,26 +17,26 @@
"output_type": "stream",
"text": [
"Name: fastai\n",
"Version: 2.0.15\n",
"Version: 2.0.8\n",
"Summary: fastai simplifies training fast and accurate neural nets using modern best practices\n",
"Home-page: https://github.com/fastai/fastai/tree/master/\n",
"Home-page: https://github.com/fastai/fastai\n",
"Author: Jeremy Howard, Sylvain Gugger, and contributors\n",
"Author-email: info@fast.ai\n",
"License: Apache Software License 2.0\n",
"Location: /usr/local/lib/python3.6/dist-packages\n",
"Requires: torchvision, scikit-learn, packaging, pip, pyyaml, pillow, torch, fastprogress, requests, spacy, scipy, pandas, fastcore, matplotlib\n",
"Location: /mnt/c/work/ml/fastai\n",
"Requires: pip, packaging, fastcore, torchvision, matplotlib, pandas, requests, pyyaml, fastprogress, pillow, scikit-learn, scipy, spacy, pandas, torch\n",
"Required-by: \n",
"---\n",
"Name: fastcore\n",
"Version: 1.0.16\n",
"Version: 1.0.1\n",
"Summary: Python supercharged for fastai development\n",
"Home-page: https://github.com/fastai/fastcore/tree/master/\n",
"Home-page: https://github.com/fastai/fastcore\n",
"Author: Jeremy Howard and Sylvain Gugger\n",
"Author-email: infos@fast.ai\n",
"License: Apache Software License 2.0\n",
"Location: /usr/local/lib/python3.6/dist-packages\n",
"Requires: packaging, pip\n",
"Required-by: fastai\n"
"Location: /home/etienne/miniconda3/envs/fastai/lib/python3.8/site-packages\n",
"Requires: packaging, wheel, pip, numpy, dataclasses\n",
"Required-by: nbdev, fastai2, fastai\n"
]
}
],
Expand Down Expand Up @@ -507,7 +507,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 18,
"metadata": {
"id": "0moqukV_EdOs"
},
Expand Down Expand Up @@ -563,7 +563,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 19,
"metadata": {
"id": "24aYC4fk4Qev"
},
Expand Down Expand Up @@ -606,7 +606,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 20,
"metadata": {
"id": "zbsDGXPX_VUV"
},
Expand All @@ -617,14 +617,14 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"b = dls.one_batch()\n",
"# b = dls.one_batch()\n",
"\n",
"out = learn.model(*b[:2])\n",
"loss = loss_func(out, *b[-2:])"
"# out = learn.model(*b[:2])\n",
"# loss = loss_func(out, *b[-2:])"
]
},
{
Expand All @@ -638,7 +638,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 22,
"metadata": {
"id": "ThV0ch4x7R58"
},
Expand Down Expand Up @@ -666,7 +666,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 23,
"metadata": {
"id": "i7xmhROaEdOw"
},
Expand All @@ -687,7 +687,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -712,57 +712,45 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>8.551754</td>\n",
" <td>3.598067</td>\n",
" <td>00:00</td>\n",
" <td>7.849534</td>\n",
" <td>3.430095</td>\n",
" <td>00:03</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>4.080617</td>\n",
" <td>1.116776</td>\n",
" <td>00:00</td>\n",
" <td>3.788902</td>\n",
" <td>1.078850</td>\n",
" <td>00:03</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2.433498</td>\n",
" <td>0.573103</td>\n",
" <td>00:00</td>\n",
" <td>2.278958</td>\n",
" <td>0.546157</td>\n",
" <td>00:03</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.632088</td>\n",
" <td>0.242973</td>\n",
" <td>00:00</td>\n",
" <td>1.537396</td>\n",
" <td>0.235419</td>\n",
" <td>00:03</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.184522</td>\n",
" <td>0.177599</td>\n",
" <td>00:00</td>\n",
" <td>1.130815</td>\n",
" <td>0.188008</td>\n",
" <td>00:03</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.918635</td>\n",
" <td>0.137939</td>\n",
" <td>00:00</td>\n",
" <td>0.889883</td>\n",
" <td>0.135979</td>\n",
" <td>00:03</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>0.758113</td>\n",
" <td>0.131483</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.649586</td>\n",
" <td>0.102631</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.585502</td>\n",
" <td>0.111842</td>\n",
" <td>00:01</td>\n",
" <td>0.736168</td>\n",
" <td>0.142726</td>\n",
" <td>00:03</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -778,7 +766,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"No improvement since epoch 7: early stopping\n"
"No improvement since epoch 5: early stopping\n"
]
}
],
Expand Down
File renamed without changes.
Loading

0 comments on commit 4da49c3

Please sign in to comment.