Skip to content

Commit

Permalink
v0.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
mbarbetti committed Nov 10, 2023
1 parent 0578b0b commit 37d96dc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/pidgan/players/generators/Generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from tensorflow import keras

LEAKY_ALPHA = 0.1
SEED = 42


class Generator(keras.Model):
Expand Down Expand Up @@ -106,7 +105,7 @@ def summary(self, **kwargs) -> None:
self._seq.summary(**kwargs)

def generate(self, x, seed=None, return_latent_sample=False) -> tf.Tensor:
tf.random.set_seed(seed=SEED)
tf.random.set_seed(seed=seed)
x, latent_sample = self._prepare_input(x, seed=seed)
out = self._seq(x)
if return_latent_sample:
Expand Down
2 changes: 1 addition & 1 deletion src/pidgan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.1.1"
9 changes: 6 additions & 3 deletions tests/players/generators/test_Generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ def test_model_eval(model, sample_weight):


def test_model_generate(model):
out = model.generate(x, seed=42)
comparison = out.numpy() == model.generate(x, seed=42).numpy()
no_seed_out = model.generate(x, seed=None)
comparison = no_seed_out.numpy() != model.generate(x, seed=None).numpy()
assert comparison.all()
comparison = out.numpy() != model.generate(x, seed=24).numpy()
seed_out = model.generate(x, seed=42)
comparison = seed_out.numpy() == model.generate(x, seed=42).numpy()
assert comparison.all()
comparison = seed_out.numpy() != model.generate(x, seed=24).numpy()
assert comparison.any()


Expand Down

0 comments on commit 37d96dc

Please sign in to comment.