Skip to content

Commit

Permalink
feat: add architecture flexibility options to UViT
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Sep 8, 2024
1 parent 6b9b4a4 commit 7dbd273
Show file tree
Hide file tree
Showing 7 changed files with 914 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
name: Upload Python Package

on:
push:
branches: [ "main" ]
release:
types: [published]

Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ datacache
gcsfuse.yml
*.csv
*.tsv
*.parquet
*.parquet
*.arrow
214 changes: 212 additions & 2 deletions datasets/dataset preparations.ipynb

Large diffs are not rendered by default.

602 changes: 602 additions & 0 deletions datasets/datasets/laion2B-en-aesthetic-4.2_37M/dataset_info.json

Large diffs are not rendered by default.

61 changes: 61 additions & 0 deletions datasets/datasets/laion2B-en-aesthetic-4.2_37M/state.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{
"_data_files": [
{
"filename": "data-00000-of-00017.arrow"
},
{
"filename": "data-00001-of-00017.arrow"
},
{
"filename": "data-00002-of-00017.arrow"
},
{
"filename": "data-00003-of-00017.arrow"
},
{
"filename": "data-00004-of-00017.arrow"
},
{
"filename": "data-00005-of-00017.arrow"
},
{
"filename": "data-00006-of-00017.arrow"
},
{
"filename": "data-00007-of-00017.arrow"
},
{
"filename": "data-00008-of-00017.arrow"
},
{
"filename": "data-00009-of-00017.arrow"
},
{
"filename": "data-00010-of-00017.arrow"
},
{
"filename": "data-00011-of-00017.arrow"
},
{
"filename": "data-00012-of-00017.arrow"
},
{
"filename": "data-00013-of-00017.arrow"
},
{
"filename": "data-00014-of-00017.arrow"
},
{
"filename": "data-00015-of-00017.arrow"
},
{
"filename": "data-00016-of-00017.arrow"
}
],
"_fingerprint": "9e2180a190e4d3ae",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": null,
"_output_all_columns": false,
"_split": "train"
}
43 changes: 34 additions & 9 deletions flaxdiff/models/simple_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Callable, Any, Optional, Tuple
from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
from .attention import TransformerBlock
from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init, ResidualBlock
import einops
from flax.typing import Dtype, PrecisionLike
from functools import partial
Expand Down Expand Up @@ -68,6 +68,7 @@ class UViT(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
kernel_init: Callable = partial(kernel_init, 1.0)
add_residualblock_output: bool = False

def setup(self):
if self.norm_groups > 0:
Expand All @@ -80,6 +81,8 @@ def __call__(self, x, temb, textcontext=None):
# Time embedding
temb = FourierEmbedding(features=self.emb_features)(temb)
temb = TimeProjection(features=self.emb_features)(temb)

original_img = x

# Patch embedding
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
Expand Down Expand Up @@ -141,14 +144,36 @@ def __call__(self, x, temb, textcontext=None):
x = x[:, 1 + num_text_tokens:, :]
x = unpatchify(x, channels=self.output_channels)
# print(f'Shape of x after final dense layer: {x.shape}')
x = nn.Conv(
features=self.output_channels,
kernel_size=(3, 3),

if self.add_residualblock_output:
# Concatenate the original image
x = jnp.concatenate([original_img, x], axis=-1)

x = ResidualBlock(
"conv",
name="final_residual",
features=64,
kernel_init=self.kernel_init(1.0),
kernel_size=(3,3),
strides=(1, 1),
activation=self.activation,
norm_groups=self.norm_groups,
dtype=self.dtype,
precision=self.precision,
named_norms=False
)(x, temb)

x = self.norm()(x)
x = self.activation(x)

x = ConvLayer(
"conv",
features=self.output_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding='SAME',
dtype=self.dtype,
precision=self.precision,
kernel_init=kernel_init(0.0),
# activation=jax.nn.mish
kernel_init=self.kernel_init(0.0),
dtype=self.dtype,
precision=self.precision
)(x)

return x
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
version='0.1.27',
version='0.1.28',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 7dbd273

Please sign in to comment.