diff --git a/pyproject.toml b/pyproject.toml index 8c918d8..d827c85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ packages = [ {include = "scgen"}, ] readme = "README.md" -version = "2.1.1" +version = "2.1.2" [tool.poetry.dependencies] adjustText = "*" diff --git a/scgen/_scgen.py b/scgen/_scgen.py index 428f85e..9d8a57f 100644 --- a/scgen/_scgen.py +++ b/scgen/_scgen.py @@ -46,6 +46,8 @@ class SCGEN(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): >>> adata.obsm["X_scgen"] = vae.get_latent_representation() """ + _module_cls = SCGENVAE + def __init__( self, adata: AnnData, @@ -57,14 +59,24 @@ def __init__( ): super().__init__(adata) - self.module = SCGENVAE( - n_input=self.summary_stats.n_vars, - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - dropout_rate=dropout_rate, - **model_kwargs, - ) + self._module_kwargs = { + "n_hidden": n_hidden, + "n_latent": n_latent, + "n_layers": n_layers, + "dropout_rate": dropout_rate, + **model_kwargs } + if self._module_init_on_train: + self.module = None + else: + self.module = SCGENVAE( + n_input=self.summary_stats.n_vars, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + dropout_rate=dropout_rate, + **model_kwargs, + ) self._model_summary_string = ( "SCGEN Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: " "{}" diff --git a/scgen/_scgenvae.py b/scgen/_scgenvae.py index 113d8cb..310bd72 100644 --- a/scgen/_scgenvae.py +++ b/scgen/_scgenvae.py @@ -45,6 +45,7 @@ def __init__( use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", kl_weight: float = 0.00005, + **kwargs # needed when initialized with a datamodule ): super().__init__() self.n_layers = n_layers