-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVAE_test_script.py
55 lines (39 loc) · 880 Bytes
/
VAE_test_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# -*- coding: utf-8 -*-
from VAE.logVar_VAE import Vanilla_VAE
from VAE.Conv_VAE import Conv_VAE
from VAE.visualize import plotInOut, plotInOut_Conv
import torch
import torchvision
import torchvision.transforms as transforms
from datasets.MNIST import load_MNIST, test_MNIST
testloader = test_MNIST()
# NOTE(review): trainloader is not used anywhere below; it is kept only so the
# script's behaviour is unchanged. Drop it if loading the training set is not
# needed for the plots.
trainloader = load_MNIST(1)

#%% NORMAL VAE
# Load the fully-connected VAE checkpoint saved under results/<name>.
name = 'test1'
savepath = 'results/' + name

# The model is constructed with use_cuda=False, so map every stored tensor to
# the CPU. Without map_location, a checkpoint saved from a GPU run raises when
# loaded on a machine without CUDA.
state_dict = torch.load(savepath, map_location='cpu')

# Architecture sizes; these must match the ones used when the checkpoint was
# saved, otherwise load_state_dict raises on the shape mismatch.
x_dim = 28*28  # flattened 28x28 MNIST image
h1_dim = 500
h2_dim = 200
z_dim = 10

vae1 = Vanilla_VAE(x_dim, h1_dim, h2_dim, z_dim, use_cuda=False)
vae1.load_state_dict(state_dict)
vae1.eval()  # inference mode for the plotting below

#%%
# NOTE(review): presumably plots test images next to their reconstructions
# (inferred from the helper's name only) -- confirm in VAE.visualize.
plotInOut(testloader, vae1)

#%% CONV VAE
# Alternative: visualise the convolutional VAE instead. Uncomment to use.
#
#name = 'conv1'
#savepath = 'results/' + name
#
#state_dict = torch.load(savepath, map_location='cpu')
#
#h_dim = 500
#z_dim = 10
#
#vae1 = Conv_VAE(h_dim, z_dim, use_cuda=False)
#vae1.load_state_dict(state_dict)
#vae1.eval()
#
##%%
#
#plotInOut_Conv(testloader, vae1)