Skip to content

Commit

Permalink
Merge pull request #59 from valentingol/dev
Browse files Browse the repository at this point in the history
🆙 Update to 2.5.0
  • Loading branch information
valentingol authored Dec 2, 2022
2 parents c4264af + 7a04ef5 commit 3d6eb53
Show file tree
Hide file tree
Showing 18 changed files with 818 additions and 43 deletions.
5 changes: 4 additions & 1 deletion CONFIG_SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ it is in model configuration (see below).**
- `prefetch_factor` (default 2): number of batch to load with CPU while
the GPU is running

- `persistant_workers` (default False): whether copying worker at the end of epoc
- `persistant_workers` (default False): whether copying worker at the end of epoch
(False) or keep them (True)

- `pin_memory` (default False): whether to pin memory on GPU for faster transfer
Expand All @@ -84,6 +84,9 @@ it is in model configuration (see below).**

- `pixel_size_cond` (default 6): size of the pixels to condition with

- `pixel_classes_cond` (default None): list of classes of the pixels to
condition with. If None, all classes are used.

## Model

Default configuration in `gan_facies/configs/default/model.yaml`.
Expand Down
6 changes: 4 additions & 2 deletions gan_facies/apps/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ def test(config: ConfigType) -> None:
with torch.no_grad():
n_pixels = config.data.n_pixels_cond
pixel_size = config.data.pixel_size_cond
pixel_classes = config.data.pixel_classes_cond
data_size = config.model.data_size
pixel_maps = generate_pixel_maps(
batch_size=batch_size, n_classes=n_classes,
n_pixels=n_pixels, pixel_size=pixel_size, data_size=data_size,
n_pixels=n_pixels, pixel_size=pixel_size,
pixel_classes=pixel_classes, data_size=data_size,
device=device)
colored_pixel_maps = colorize_pixel_map(pixel_maps)
images, attn_list = generator.generate(z_input, pixel_maps,
Expand Down Expand Up @@ -115,7 +117,7 @@ def test(config: ConfigType) -> None:
compute_save_indicators(data_loader, config)
# Compute and print metrics
metrics = evaluate(gen=generator, config=config, training=False, step=step,
save_json=False, save_csv=True)
save_json=False, save_csv=True)[:2]
print(f'MACs: {macs / 1e9:.2f}G')
print("Metrics w.r.t training set:")
print_metrics(metrics)
Expand Down
436 changes: 436 additions & 0 deletions gan_facies/apps/post_process.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions gan_facies/configs/default/data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ pin_memory: False # True to pin memory on GPU memory for faster transfer
n_pixels_cond: null
# pixel_size_cond: only used for conditional models, size of the pixels to sample
pixel_size_cond: 6
# pixel_classes_cond: list of classes of conditioning pixels (None for all classes)
pixel_classes_cond: null
4 changes: 4 additions & 0 deletions gan_facies/configs/exp/usagan128.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ data.train_batch_size: 64
data.test_batch_size: 64

model.data_size: 128
model.sym_attn: False

training.total_step: 1000000000
training.sample_step: 200
Expand All @@ -20,3 +21,6 @@ wandb.use_wandb: False

data.n_pixels_cond: 10
data.pixel_size_cond: 2

metrics.connectivity: 2
metrics.unit_component_size: 8
2 changes: 1 addition & 1 deletion gan_facies/gan/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def compute_metrics(self, gen: Module) -> MetricsType:
metrics = metr.evaluate(gen=gen, config=self.config, training=True,
step=self.step + 1,
indicators_path=self.indicators_path,
save_json=False, save_csv=True)
save_json=False, save_csv=True)[:2]
print()
return metrics

Expand Down
2 changes: 2 additions & 0 deletions gan_facies/gan/cond_sagan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,12 @@ def create_pixel_maps(self) -> None:
"""Generate pixel maps for evaluation."""
n_pixels = self.config.data.n_pixels_cond
pixel_size = self.config.data.pixel_size_cond
pixel_cls = self.config.data.pixel_classes_cond
data_size = self.config.model.data_size
self.fixed_pixel_maps = generate_pixel_maps(batch_size=self.batch_size,
n_classes=self.n_classes,
n_pixels=n_pixels,
pixel_size=pixel_size,
pixel_classes=pixel_cls,
data_size=data_size,
device="cuda:0")
34 changes: 24 additions & 10 deletions gan_facies/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from gan_facies.utils.configs import ConfigType

IndicatorsList = List[Dict[str, List[float]]]
EvalOutType = Tuple[Dict[str, float], Dict[str, float], np.ndarray,
List[Dict[str, List[float]]]]


def wasserstein_distances(data1: Union[np.ndarray, IndicatorsList],
Expand Down Expand Up @@ -211,7 +213,7 @@ def compute_save_indicators(data_loader: DistributedDataLoader,

def evaluate(gen: nn.Module, config: ConfigType, training: bool, step: int,
indicators_path: Optional[str] = None, save_json: bool = False,
save_csv: bool = False, n_images: int = 1024) -> MetricsType:
save_csv: bool = False, n_images: int = 1024) -> EvalOutType:
"""Compute metrics from generator output.
Compute Wasserstein distances between generated images indicators
Expand Down Expand Up @@ -245,14 +247,18 @@ def evaluate(gen: nn.Module, config: ConfigType, training: bool, step: int,
Lower values means better similarity.
other_metrics : Dict[str, float]
Other metrics.
data_gen : np.ndarray
Generated images used to compute metrics.
indicators : IndicatorsList
List of indicators for the generated images.
"""
gen.eval()
device = next(gen.parameters()).device
batch_size = (config.data.train_batch_size
if training else config.data.test_batch_size)
print(" -> Generating images for metrics calculation:", end='\r')
# Generate more than n_images images to compute metrics
data_gen = []
data_gen_list = []
with torch.no_grad():
n_wrongs, total_pixels = 0, 0
for k in range(max(1, int(np.ceil(n_images // batch_size)))):
Expand All @@ -267,11 +273,13 @@ def evaluate(gen: nn.Module, config: ConfigType, training: bool, step: int,
if 'cond' in config.model.architecture:
n_pixels = config.data.n_pixels_cond
pixel_size = config.data.pixel_size_cond
pixel_classes = config.data.pixel_classes_cond
data_size = config.model.data_size
pixel_maps = generate_pixel_maps(batch_size=batch_size,
n_classes=gen.n_classes,
n_pixels=n_pixels,
pixel_size=pixel_size,
pixel_classes=pixel_classes,
data_size=data_size,
device=device)
out = gen(z_input, pixel_maps=pixel_maps)
Expand All @@ -290,8 +298,9 @@ def evaluate(gen: nn.Module, config: ConfigType, training: bool, step: int,
else:
out = gen(z_input)
out = torch.argmax(out, dim=1).detach().cpu().numpy()
data_gen_list.append(out)
del out
torch.cuda.empty_cache() # Free GPU memory
data_gen.append(out)
print(
" -> Generating images for metrics calculation: "
f"{(k + 1)*batch_size} images", end='\r')
Expand All @@ -300,12 +309,12 @@ def evaluate(gen: nn.Module, config: ConfigType, training: bool, step: int,
else:
other_metrics = {}
print()
data_gen_arr = np.vstack(data_gen)
data_gen_arr = data_gen_arr.astype(np.uint8)
data_gen = np.vstack(data_gen_list)
data_gen = data_gen.astype(np.uint8)
print(" -> Computing indicators...")

indicators_list_ref = get_reference_indicators(config, indicators_path,
data_gen_arr)
data_gen)

metrics_save_dir = osp.join(config.output_dir, config.run_name, 'metrics')
os.makedirs(metrics_save_dir, exist_ok=True)
Expand All @@ -324,12 +333,17 @@ def evaluate(gen: nn.Module, config: ConfigType, training: bool, step: int,
f'test_metrics_step_{step}')

# Compute metrics and save boxes locally if needed
w_dists = wasserstein_distances(data_gen_arr, indicators_list_ref,
save_boxes_path=save_boxes_path,
**config.metrics)[0]
unit_component_size = config.metrics.unit_component_size
connectivity = config.metrics.connectivity
w_dists, inds = wasserstein_distances(
data_gen, indicators_list_ref,
save_boxes_path=save_boxes_path,
connectivity=connectivity,
unit_component_size=unit_component_size)
indicators = inds[0] # indicators for generated images
save_metrics((w_dists, other_metrics), metrics_save_path,
save_json=save_json, save_csv=save_csv)
return w_dists, other_metrics
return w_dists, other_metrics, data_gen, indicators


def print_metrics(metrics: MetricsType, step: Optional[int] = None) -> None:
Expand Down
25 changes: 20 additions & 5 deletions gan_facies/utils/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utilities for evaluation."""
import random
from typing import List, Union
from typing import List, Optional, Union

import numpy as np
import torch
Expand All @@ -9,11 +9,14 @@
import gan_facies.data.process as proc


def generate_pixel_maps(batch_size: int, n_classes: int,
def generate_pixel_maps(batch_size: int,
n_classes: int,
n_pixels: Union[int, List[int]],
pixel_size: int, data_size: int,
pixel_size: int,
pixel_classes: Optional[List],
data_size: int,
device: torch.device = "cpu") -> torch.Tensor:
"""Generate pixel maps and eventually color them.
"""Generate random pixel maps for conditioning.
Parameters
----------
Expand All @@ -27,6 +30,9 @@ def generate_pixel_maps(batch_size: int, n_classes: int,
will be sampled uniformly between the two values.
pixel_size: int
Size of the pixels to sample.
pixel_classes: list or None
If list, the class of the pixels to sample. If None, all classes
will be eventually sampled.
data_size : int
Size of the data.
device : torch.device, optional
Expand Down Expand Up @@ -63,7 +69,16 @@ def generate_pixel_maps(batch_size: int, n_classes: int,
pixels_h = [i*pixel_size + k for k in grid_h for i, _ in pixels_idx]
pixels_w = [j*pixel_size + k for k in grid_w for _, j in pixels_idx]
# Randomly sample classes and copy them on all big-pixels
classes = torch.randint(0, n_classes, (n_pixels_int, ), device=device)
if isinstance(pixel_classes, list):
classes_np = np.random.choice(pixel_classes, size=n_pixels_int,
replace=True)
classes = torch.from_numpy(classes_np).to(device)
elif pixel_classes is None:
classes = torch.randint(0, n_classes, (n_pixels_int, ),
device=device)
else:
raise ValueError("pixel_classes must be list or None, found "
f"type {type(pixel_classes)}.")
classes = repeat(classes, "n -> (h w n)", h=pixel_size, w=pixel_size)
pixel_maps[i_batch, classes, pixels_h, pixels_w] = 1.0
# Class 0 here is actually the mask of the pixels to keep
Expand Down
Loading

0 comments on commit 3d6eb53

Please sign in to comment.