Skip to content

Commit

Permalink
Add warning if jsbeutifier not installed, set default for h5 in infer…
Browse files Browse the repository at this point in the history
…ence, fix import (#237)

Co-authored-by: Mohammad Amin Nabian <m.a.nabiyan@gmail.com>
  • Loading branch information
daviddpruitt and mnabian authored Nov 17, 2023
1 parent 4028d0c commit b615801
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
7 changes: 7 additions & 0 deletions modulus/experimental/sfno/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
parser.add_argument("--cuda_graph_mode", default='none', type=str, choices=["none", "fwdbwd", "step"], help="Specify which parts to capture under cuda graph")
parser.add_argument("--enable_benchy", action='store_true')
parser.add_argument("--disable_ddp", action='store_true')
parser.add_argument("--enable_odirect", action='store_true')
parser.add_argument("--enable_nhwc", action='store_true')
parser.add_argument("--checkpointing_level", default=0, type=int, help="How aggressively checkpointing is used")
parser.add_argument("--epsilon_factor", default = 0, type = float)
Expand All @@ -58,6 +59,11 @@
# parse
args = parser.parse_args()

# check whether the right h5py package is installed
odirect_env_var_name = "ENABLE_H5PY_ODIRECT"
if args.enable_odirect and os.environ.get(odirect_env_var_name, "False").lower() != "true":
raise RuntimeError(f"Error, {odirect_env_var_name} h5py with MPI support is not installed. Please refer to README for instructions on how to install it.")

# parse parameters
params = YParams(os.path.abspath(args.yaml_config), args.config)
params['epsilon_factor'] = args.epsilon_factor
Expand Down Expand Up @@ -122,6 +128,7 @@
params['cuda_graph_mode'] = args.cuda_graph_mode
params['enable_benchy'] = args.enable_benchy
params['disable_ddp'] = args.disable_ddp
params['enable_odirect'] = args.enable_odirect
params['enable_nhwc'] = args.enable_nhwc
params['checkpointing'] = args.checkpointing_level
params['enable_synthetic_data'] = args.enable_synthetic_data
Expand Down
20 changes: 15 additions & 5 deletions modulus/experimental/sfno/networks/model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@

import logging

import warnings

try:
import jsbeautifier
use_jsbeautifier = True
except ImportError:
raise ImportError('jsbeautifier is not installed. Please install it with "pip install jsbeautifier"')
warnings.warn('jsbeautifier is not installed. Please install it with "pip install jsbeautifier"')
use_jsbeautifier = False

class LocalPackage:
"""
Expand Down Expand Up @@ -101,11 +105,15 @@ def save_model_package(params):
"""
# save out the current state of the parameters, make it human readable
config_path = os.path.join(params.experiment_dir, "config.json")
jsopts = jsbeautifier.default_options()
jsopts.indent_size = 2

msg = json.dumps(params.to_dict())
if use_jsbeautifier:
jsopts = jsbeautifier.default_options()
jsopts.indent_size = 2

msg = jsbeautifier.beautify(msg, jsopts)

with open(config_path, "w") as f:
msg = jsbeautifier.beautify(json.dumps(params.to_dict()), jsopts)
f.write(msg)

if hasattr(params, "add_orography") and params.add_orography:
Expand All @@ -131,7 +139,9 @@ def save_model_package(params):
"entrypoint": {"name": "networks.model_package:load_time_loop"},
}
with open(os.path.join(params.experiment_dir, "metadata.json"), "w") as f:
msg = jsbeautifier.beautify(json.dumps(fcn_mip_data), jsopts)
msg = json.dumps(fcn_mip_data)
if use_jsbeautifier:
msg = jsbeautifier.beautify(msg, jsopts)
f.write(msg)


Expand Down
1 change: 1 addition & 0 deletions modulus/experimental/sfno/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import io
import numpy as np
import os
import concurrent.futures as cf
from PIL import Image
from moviepy.editor import ImageSequenceClip
Expand Down

0 comments on commit b615801

Please sign in to comment.