Skip to content

Commit

Permalink
AutoEncoder architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyasi2002 committed Dec 20, 2023
1 parent 020781f commit 615450b
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
import torch.nn.functional as F
import torch.nn as nn

class ConvAutoencoder_GELU(nn.Module):
def __init__(self, device, z_dim=10):
super().__init__()
self.device = device

self.encoder_conv2D = nn.Sequential(
nn.Conv2d(1, 8, 3, stride=2, padding=1),
nn.GELU(),
nn.Conv2d(8, 16, 3, stride=2, padding=1),
nn.BatchNorm2d(16),
nn.GELU(),
nn.Conv2d(16, 32, 3, stride=2, padding=0),
nn.GELU()
)

## Flatten Layer
self.flatten = nn.Flatten(start_dim=1)

## Linear Section
self.encoder_linear = nn.Sequential(
nn.Linear(3 * 3 * 32, 128),
nn.GELU(),
nn.Linear(128, z_dim),
)

self.decoder_linear = nn.Sequential(
nn.Linear(z_dim, 128),
nn.GELU(),
nn.Linear(128, 3 * 3 * 32),
nn.GELU(),
)

self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

self.decoder_convt2d = nn.Sequential(
nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=1),
nn.BatchNorm2d(16),
nn.GELU(),
nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(8),
nn.GELU(),
nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
)


def forward(self, x):
## Encode the image to latent space
x = self.encoder_conv2D(x)
x = self.flatten(x)
x = self.encoder_linear(x)

## Add random gaussian noise to latent encoding
noise = torch.randn(x.shape).to(self.device)
x = x + noise

## Decode the latent encoding back to reconstructed image
x = self.decoder_linear(x)
x = self.unflatten(x)
x = self.decoder_convt2d(x)
x = torch.sigmoid(x)

return x

0 comments on commit 615450b

Please sign in to comment.