From 7a13cc343fa9ec40974933638720893fb7f86d5f Mon Sep 17 00:00:00 2001
From: Ashish Kumar Singh
Date: Tue, 27 Aug 2024 12:13:12 -0400
Subject: [PATCH] feat: parameterized a few attention options in ViT

---
 flaxdiff/models/simple_vit.py | 9 ++++++---
 setup.py                      | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/flaxdiff/models/simple_vit.py b/flaxdiff/models/simple_vit.py
index b63d0ed..cf1d3c2 100644
--- a/flaxdiff/models/simple_vit.py
+++ b/flaxdiff/models/simple_vit.py
@@ -58,6 +58,9 @@ class UViT(nn.Module):
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
     use_projection: bool = False
+    use_flash_attention: bool = False
+    use_self_and_cross: bool = False
+    force_fp32_for_softmax: bool = True
     activation:Callable = jax.nn.swish
     norm_groups:int=8
     dtype: Optional[Dtype] = None
@@ -102,7 +105,7 @@ def __call__(self, x, temb, textcontext=None):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)
@@ -110,7 +113,7 @@ def __call__(self, x, temb, textcontext=None):
         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True,
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
                              kernel_init=self.kernel_init())(x)
 
@@ -121,7 +124,7 @@ def __call__(self, x, temb, textcontext=None):
                                  dtype=self.dtype, precision=self.precision)(skip)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(skip)
 
diff --git a/setup.py b/setup.py
index 1e88bf2..b462991 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.23',
+    version='0.1.24',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
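
For reference, a minimal usage sketch of the newly exposed flags (use_flash_attention, use_self_and_cross, force_fp32_for_softmax). Only these three keyword arguments and the __call__(x, temb, textcontext=None) signature come from this patch; the other constructor fields (num_layers, num_heads, emb_features) are merely referenced in the diff hunks, so the values and input shapes below are illustrative assumptions rather than documented defaults.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_vit import UViT

# Construct the UViT with the flags exposed by this patch; field names other than
# the three new flags appear in the diff context, but the values here are assumptions.
model = UViT(
    emb_features=512,
    num_layers=12,
    num_heads=8,
    use_flash_attention=False,     # new: forwarded to every TransformerBlock
    use_self_and_cross=True,       # new: previously hard-coded per block
    force_fp32_for_softmax=True,   # new: previously hard-coded to True
)

# Initialize parameters with the standard Flax init call; shapes are placeholders only.
rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 64, 64, 3))          # image input (assumed NHWC layout)
temb = jnp.ones((1,))                 # diffusion timestep input
textcontext = jnp.ones((1, 77, 768))  # text conditioning (assumed shape)
params = model.init(rng, x, temb, textcontext)

Note that before this change the middle block was hard-coded to use_self_and_cross=True while the down and up blocks used False; with a single shared flag defaulting to False, callers who want the old middle-block behaviour now have to opt in explicitly.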