Skip to content

Commit

Permalink
Merge pull request #60 from valentingol/dev
Browse files Browse the repository at this point in the history
🚀 pixel_classes_cond is always list
  • Loading branch information
valentingol authored Dec 9, 2022
2 parents 3d6eb53 + c74eeb1 commit e1b5934
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions CONFIG_SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ it is in model configuration (see below).**

- `pixel_size_cond` (default 6): size of the pixels to condition with

- `pixel_classes_cond` (default None): list of classes of the pixels to
condition with. If None, all classes are used.
- `pixel_classes_cond` (default []): list of classes of the pixels to
condition with. If empty, all classes are used.

## Model

Expand Down
4 changes: 2 additions & 2 deletions gan_facies/configs/default/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ pin_memory: False # True to pin memory on GPU memory for faster transfer
n_pixels_cond: null
# pixel_size_cond: only used for conditional models, size of the pixels to sample
pixel_size_cond: 6
# pixel_classes_cond: list of classes of conditioning pixels (None for all classes)
pixel_classes_cond: null
# pixel_classes_cond: list of classes of conditioning pixels (empty for all classes)
pixel_classes_cond: []
22 changes: 11 additions & 11 deletions gan_facies/utils/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utilities for evaluation."""
import random
from typing import List, Optional, Union
from typing import List, Union

import numpy as np
import torch
Expand All @@ -13,7 +13,7 @@ def generate_pixel_maps(batch_size: int,
n_classes: int,
n_pixels: Union[int, List[int]],
pixel_size: int,
pixel_classes: Optional[List],
pixel_classes: List,
data_size: int,
device: torch.device = "cpu") -> torch.Tensor:
"""Generate random pixel maps for conditioning.
Expand All @@ -30,8 +30,8 @@ def generate_pixel_maps(batch_size: int,
will be sampled uniformly between the two values.
pixel_size: int
Size of the pixels to sample.
pixel_classes: list or None
If list, the class of the pixels to sample. If None, all classes
pixel_classes: list
If list, the class of the pixels to sample. If empty, all classes
will be eventually sampled.
data_size : int
Size of the data.
Expand All @@ -53,6 +53,9 @@ def generate_pixel_maps(batch_size: int,
f"found list of lenght {len(n_pixels)}.")
raise ValueError("n_pixels must be int or tuple of 2 ints, "
f"found type {type(n_pixels)}.")
if not isinstance(pixel_classes, list):
raise ValueError("pixel_classes must be list or None, found "
f"type {type(pixel_classes)}.")
pixel_maps = torch.zeros((batch_size, n_classes, data_size, data_size),
device=device, dtype=torch.float32)
for i_batch in range(batch_size):
Expand All @@ -69,16 +72,13 @@ def generate_pixel_maps(batch_size: int,
pixels_h = [i*pixel_size + k for k in grid_h for i, _ in pixels_idx]
pixels_w = [j*pixel_size + k for k in grid_w for _, j in pixels_idx]
# Randomly sample classes and copy them on all big-pixels
if isinstance(pixel_classes, list):
classes_np = np.random.choice(pixel_classes, size=n_pixels_int,
replace=True)
classes = torch.from_numpy(classes_np).to(device)
elif pixel_classes is None:
if pixel_classes == []:
classes = torch.randint(0, n_classes, (n_pixels_int, ),
device=device)
else:
raise ValueError("pixel_classes must be list or None, found "
f"type {type(pixel_classes)}.")
classes_np = np.random.choice(pixel_classes, size=n_pixels_int,
replace=True)
classes = torch.from_numpy(classes_np).to(device)
classes = repeat(classes, "n -> (h w n)", h=pixel_size, w=pixel_size)
pixel_maps[i_batch, classes, pixels_h, pixels_w] = 1.0
# Class 0 here is actually the mask of the pixels to keep
Expand Down
12 changes: 6 additions & 6 deletions tests/utils/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

def test_generate_pixel_maps() -> None:
"""Test generate_pixel_maps."""
# Case n_pixels is int, pixel_classes is None
# Case n_pixels is int, pixel_classes is empty
pixel_maps = generate_pixel_maps(batch_size=2, n_classes=3, n_pixels=5,
pixel_size=3, pixel_classes=None,
pixel_size=3, pixel_classes=[],
data_size=32)
check.equal(pixel_maps.shape, (2, 3, 32, 32))
check.equal(pixel_maps.dtype, torch.float32)
Expand All @@ -32,7 +32,7 @@ def test_generate_pixel_maps() -> None:
check.is_true(((torch.sum(pixel_maps[:, 1:], dim=1) == 0.0)
| (torch.sum(pixel_maps[:, 1:], dim=1) == 1.0)).all())

# Case n_pixels is list of length 2, pixel_classes is not None
# Case n_pixels is list of length 2, pixel_classes is not empty
pixel_maps = generate_pixel_maps(batch_size=2, n_classes=3,
n_pixels=[5, 10], pixel_classes=[0, 2],
pixel_size=1,
Expand All @@ -44,11 +44,11 @@ def test_generate_pixel_maps() -> None:
# Case invalid n_pixels
with check.raises(ValueError):
generate_pixel_maps(batch_size=2, n_classes=3, n_pixels=[5, 10, 15],
pixel_size=1, pixel_classes=None, data_size=32)
pixel_size=1, pixel_classes=[], data_size=32)
with check.raises(ValueError):
generate_pixel_maps(batch_size=2, n_classes=3,
n_pixels='3', pixel_size=1, # type: ignore
pixel_classes=None, data_size=32)
pixel_classes=[], data_size=32)
# Case invalid pixel_classes
with check.raises(ValueError):
generate_pixel_maps(batch_size=2, n_classes=3,
Expand All @@ -66,7 +66,7 @@ def test_colorize_pixel_map(mocker: MockerFixture) -> None:
return_value=np.random.randint(0, 256, (34 * 5, 34 * 5, 3),
dtype=np.uint8))
pixel_maps = generate_pixel_maps(batch_size=25, n_classes=3, n_pixels=5,
pixel_size=3, pixel_classes=None,
pixel_size=3, pixel_classes=[],
data_size=32)
color_pixel_maps = colorize_pixel_map(pixel_maps)
check.is_instance(color_pixel_maps, np.ndarray)
Expand Down

0 comments on commit e1b5934

Please sign in to comment.