Skip to content

Commit

Permalink
Merge pull request #61 from valentingol/dev
Browse files Browse the repository at this point in the history
🆙 Upgrade to 2.5.1
  • Loading branch information
valentingol authored Dec 16, 2022
2 parents e1b5934 + c70c0d1 commit 60f8381
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 21 deletions.
20 changes: 14 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,16 @@ DCGAN architecture).

### Installation

Install the module and dependencies in a virtual environment with Python 3.7-3.10:
Install the module and dependencies in a virtual environment with Python 3.7-3.10.
Using Linux or WSL is recommended to use [JAX](https://github.com/google/jax)
for faster metrics computation but it is not mandatory.

```bash
pip install -e .
pip install -r requirements.txt
# for developers only:
# dependencies without JAX (windows users):
pip install -r requirements-nojax.txt
# for developers:
pip install -r requirements-dev.txt
```

Expand All @@ -90,9 +94,6 @@ A small dataset is available by default in this repository. It contains 2000
synthesized images representing some channels and 3 kind of facies and was
generated in the [GANSim project](https://github.com/SuihongSong/GeoModeling_GANSim-2D_Condition_to_Well_Facies_and_Global_Features)
(under [MIT license](./assets/third_party_licenses/GANSim%20MIT%20LICENSE)).
More synthesized data are available
[here](https://zenodo.org/record/3993791#.X1FQuMhKhaR). **If you use this dataset
in your work, please cite the original authors.**

You can simply run a train on the default dataset with unconditional SAGAN
model using the following command in `gan_facies` folder:
Expand All @@ -104,6 +105,13 @@ python gan_facies/apps/train.py
You can see the progress of the training in the terminal and the resulted
images and trained networks in the `res` folder.

More data are available in this repository. Simply extract `datasets/datasets.zip`
in the `datasets` folder. It contains a bigger version of the GANSim dataset as
well as [Rongier et al. (2016)](https://hal.archives-ouvertes.fr/hal-01351694/) dataset
and [Stanford VI-E dataset](https://github.com/SCRFpublic/Stanford-VI-E/tree/master/Facies).

**If you use this dataset in your work, please cite the original authors.**

## Use your own dataset

Of course, you can use your own dataset. Simply drop it in the `datasets` folder.
Expand Down Expand Up @@ -148,7 +156,7 @@ configuration, then set the generator learning rate to 0.001 and the generator
random input dimension to 64:

```bash
python gan_facies/apps/train.py --config gan_facies/configs/exp/my_config.yaml--training.g_lr=0.001\
python gan_facies/apps/train.py --config gan_facies/configs/exp/my_config.yaml --training.g_lr=0.001\
--model.z_dim=64
```

Expand Down
Binary file added datasets/datasets.zip
Binary file not shown.
4 changes: 2 additions & 2 deletions gan_facies/apps/post_process.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@
"source": [
"### Pipeline\n",
"\n",
"Indicator cipplings are \"superposable\" with each other."
"Indicator clipping are \"superposable\" with each other."
]
},
{
Expand Down Expand Up @@ -422,7 +422,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
"version": "3.8.16 (default, Dec 7 2022, 01:12:06) \n[GCC 11.3.0]"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
31 changes: 20 additions & 11 deletions gan_facies/metrics/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@
"""
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from skimage.measure import label, regionprops

# Force Jax using CPU (using Jax with GPU will enter in conflict
# with Pytorch that should use all VRAM available)
jax.config.update('jax_platform_name', 'cpu')
try:
import jax
import jax.numpy as jnp
JAX_AVAILABLE = True
except ImportError:
jnp = np
JAX_AVAILABLE = False

if JAX_AVAILABLE:
# Force Jax using CPU (using Jax with GPU will enter in conflict
# with Pytorch that should use all VRAM available)
jax.config.update('jax_platform_name', 'cpu')

PropertiesType = Dict[str, np.ndarray]

Expand Down Expand Up @@ -149,9 +156,8 @@ def get_perimeter(components: np.ndarray, neighbors: np.ndarray,
class_id: int) -> np.ndarray:
"""Compute perimeter or surface area."""

@jax.jit
def perimeter_component_jit(components: jnp.ndarray, i: jnp.ndarray,
mask_ext: jnp.ndarray) -> jnp.ndarray:
def perimeter_component(components: jnp.ndarray, i: jnp.ndarray,
mask_ext: jnp.ndarray) -> jnp.ndarray:
"""Compute perimeter of component i and stock it in perimeters."""
axis = tuple(range(1, neighbors.ndim))
mask_compo_i = jnp.expand_dims(components == i, axis=-1)
Expand All @@ -161,6 +167,9 @@ def perimeter_component_jit(components: jnp.ndarray, i: jnp.ndarray,
perimeter = jnp.sum(index, axis=axis, dtype=jnp.int32)
return perimeter

if JAX_AVAILABLE:
perimeter_component = jax.jit(perimeter_component)

# connect_1: neighbors with 1-connectivity
# 2D: 4-neighbors, 3D: 6-neighbors
connect_1 = (neighbors[..., :4]
Expand All @@ -176,9 +185,9 @@ def perimeter_component_jit(components: jnp.ndarray, i: jnp.ndarray,
components = jnp.array(components, dtype=jnp.int32)
mask_ext = jnp.array(mask_ext, dtype=jnp.uint8)
for i in range(1, jnp.max(components) + 1):
perimeter = perimeter_component_jit(components,
jnp.array(i, dtype=jnp.uint8),
mask_ext)
perimeter = perimeter_component(components,
jnp.array(i, dtype=jnp.uint8),
mask_ext)
perimeters.append(perimeter)
perimeters_np = np.array(perimeters).T
return perimeters_np # shape (n_imgs, max_n_components), dtype=np.int32
Expand Down
13 changes: 13 additions & 0 deletions requirements-nojax.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# clearml==1.6.3 # if you want to use clearml
einops==0.4.1
matplotlib==3.5.2
opencv-python==4.6.0.66
pandas==1.4.3
Pillow==9.2.0
pytorch-ignite==0.4.9
rich==12.5.1
scikit_image==0.19.3
thop
torch==1.8.1
# wandb==0.12.21 # if you want to use wandb
yaecs==1.0.1
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
"""Setup of gan-face-editing."""

import os
import pathlib

from setuptools import setup

HERE = pathlib.Path(__file__).parent

# The text of the README file
README = (HERE / "README.md").read_text()
README = os.path.join(HERE, "README.md")

# Installation
config = {
'name': 'gan-facies',
'version': '2.5.0',
'version': '2.5.1',
'description': 'Facies modeling with GAN.',
'long_description': README,
'long_description_content_type': 'text/markdown',
Expand Down

0 comments on commit 60f8381

Please sign in to comment.