diff --git a/README.md b/README.md index 33b2eb2..49a41d8 100644 --- a/README.md +++ b/README.md @@ -75,12 +75,16 @@ DCGAN architecture). ### Installation -Install the module and dependencies in a virtual environment with Python 3.7-3.10: +Install the module and dependencies in a virtual environment with Python 3.7-3.10. +Using Linux or WSL is recommended to use [JAX](https://github.com/google/jax) +for faster metrics computation but it is not mandatory. ```bash pip install -e . pip install -r requirements.txt -# for developers only: +# dependencies without JAX (windows users): +pip install -r requirements-nojax.txt +# for developers: pip install -r requirements-dev.txt ``` @@ -90,9 +94,6 @@ A small dataset is available by default in this repository. It contains 2000 synthesized images representing some channels and 3 kind of facies and was generated in the [GANSim project](https://github.com/SuihongSong/GeoModeling_GANSim-2D_Condition_to_Well_Facies_and_Global_Features) (under [MIT license](./assets/third_party_licenses/GANSim%20MIT%20LICENSE)). -More synthesized data are available -[here](https://zenodo.org/record/3993791#.X1FQuMhKhaR). **If you use this dataset -in your work, please cite the original authors.** You can simply run a train on the default dataset with unconditional SAGAN model using the following command in `gan_facies` folder: @@ -104,6 +105,13 @@ python gan_facies/apps/train.py You can see the progress of the training in the terminal and the resulted images and trained networks in the `res` folder. +More data are available in this repository. Simply extract `datasets/datasets.zip` +in the `datasets` folder. It contains a bigger version of the GANSim dataset as +well as [Rongier et al. (2016)](https://hal.archives-ouvertes.fr/hal-01351694/) dataset +and [Stanford VI-E dataset](https://github.com/SCRFpublic/Stanford-VI-E/tree/master/Facies). + +**If you use this dataset in your work, please cite the original authors.** + ## Use your own dataset Of course, you can use your own dataset. Simply drop it in the `datasets` folder. @@ -148,7 +156,7 @@ configuration, then set the generator learning rate to 0.001 and the generator random input dimension to 64: ```bash -python gan_facies/apps/train.py --config gan_facies/configs/exp/my_config.yaml--training.g_lr=0.001\ +python gan_facies/apps/train.py --config gan_facies/configs/exp/my_config.yaml --training.g_lr=0.001\ --model.z_dim=64 ``` diff --git a/datasets/datasets.zip b/datasets/datasets.zip new file mode 100644 index 0000000..9cfba40 Binary files /dev/null and b/datasets/datasets.zip differ diff --git a/gan_facies/apps/post_process.ipynb b/gan_facies/apps/post_process.ipynb index 72efa6d..dc79b3b 100644 --- a/gan_facies/apps/post_process.ipynb +++ b/gan_facies/apps/post_process.ipynb @@ -272,7 +272,7 @@ "source": [ "### Pipeline\n", "\n", - "Indicator cipplings are \"superposable\" with each other." + "Indicator clipping are \"superposable\" with each other." ] }, { @@ -422,7 +422,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15" + "version": "3.8.16 (default, Dec 7 2022, 01:12:06) \n[GCC 11.3.0]" }, "orig_nbformat": 4, "vscode": { diff --git a/gan_facies/metrics/components.py b/gan_facies/metrics/components.py index 32b1dcb..886f05e 100644 --- a/gan_facies/metrics/components.py +++ b/gan_facies/metrics/components.py @@ -4,14 +4,21 @@ """ from typing import Dict, List, Tuple -import jax -import jax.numpy as jnp import numpy as np from skimage.measure import label, regionprops -# Force Jax using CPU (using Jax with GPU will enter in conflict -# with Pytorch that should use all VRAM available) -jax.config.update('jax_platform_name', 'cpu') +try: + import jax + import jax.numpy as jnp + JAX_AVAILABLE = True +except ImportError: + jnp = np + JAX_AVAILABLE = False + +if JAX_AVAILABLE: + # Force Jax using CPU (using Jax with GPU will enter in conflict + # with Pytorch that should use all VRAM available) + jax.config.update('jax_platform_name', 'cpu') PropertiesType = Dict[str, np.ndarray] @@ -149,9 +156,8 @@ def get_perimeter(components: np.ndarray, neighbors: np.ndarray, class_id: int) -> np.ndarray: """Compute perimeter or surface area.""" - @jax.jit - def perimeter_component_jit(components: jnp.ndarray, i: jnp.ndarray, - mask_ext: jnp.ndarray) -> jnp.ndarray: + def perimeter_component(components: jnp.ndarray, i: jnp.ndarray, + mask_ext: jnp.ndarray) -> jnp.ndarray: """Compute perimeter of component i and stock it in perimeters.""" axis = tuple(range(1, neighbors.ndim)) mask_compo_i = jnp.expand_dims(components == i, axis=-1) @@ -161,6 +167,9 @@ def perimeter_component_jit(components: jnp.ndarray, i: jnp.ndarray, perimeter = jnp.sum(index, axis=axis, dtype=jnp.int32) return perimeter + if JAX_AVAILABLE: + perimeter_component = jax.jit(perimeter_component) + # connect_1: neighbors with 1-connectivity # 2D: 4-neighbors, 3D: 6-neighbors connect_1 = (neighbors[..., :4] @@ -176,9 +185,9 @@ def perimeter_component_jit(components: jnp.ndarray, i: jnp.ndarray, components = jnp.array(components, dtype=jnp.int32) mask_ext = jnp.array(mask_ext, dtype=jnp.uint8) for i in range(1, jnp.max(components) + 1): - perimeter = perimeter_component_jit(components, - jnp.array(i, dtype=jnp.uint8), - mask_ext) + perimeter = perimeter_component(components, + jnp.array(i, dtype=jnp.uint8), + mask_ext) perimeters.append(perimeter) perimeters_np = np.array(perimeters).T return perimeters_np # shape (n_imgs, max_n_components), dtype=np.int32 diff --git a/requirements-nojax.txt b/requirements-nojax.txt new file mode 100644 index 0000000..d3af878 --- /dev/null +++ b/requirements-nojax.txt @@ -0,0 +1,13 @@ +# clearml==1.6.3 # if you want to use clearml +einops==0.4.1 +matplotlib==3.5.2 +opencv-python==4.6.0.66 +pandas==1.4.3 +Pillow==9.2.0 +pytorch-ignite==0.4.9 +rich==12.5.1 +scikit_image==0.19.3 +thop +torch==1.8.1 +# wandb==0.12.21 # if you want to use wandb +yaecs==1.0.1 diff --git a/setup.py b/setup.py index 46f66e9..fe2d15f 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ """Setup of gan-face-editing.""" +import os import pathlib from setuptools import setup @@ -7,12 +8,12 @@ HERE = pathlib.Path(__file__).parent # The text of the README file -README = (HERE / "README.md").read_text() +README = os.path.join(HERE, "README.md") # Installation config = { 'name': 'gan-facies', - 'version': '2.5.0', + 'version': '2.5.1', 'description': 'Facies modeling with GAN.', 'long_description': README, 'long_description_content_type': 'text/markdown',