style(examples): Reformat code using black
sangstar committed Jan 31, 2024
1 parent 365e68f commit 71e7da3
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions examples/hf_serialization.py
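For context, a reformat like the one in this diff is what Black produces when run over the file. Below is a minimal sketch using Black's Python API; the 79-character line length is an assumption inferred from how the lines in the diff are wrapped, not something stated in the commit.

# Sketch only: reproduce a Black-style reformat of the example file.
# Assumes the `black` package is installed; the line length of 79 is an
# inference from the wrapped lines in the diff, not the repository's config.
from pathlib import Path

import black

path = Path("examples/hf_serialization.py")
formatted = black.format_str(path.read_text(), mode=black.Mode(line_length=79))
path.write_text(formatted)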
@@ -40,7 +40,7 @@
 
 
 def check_file_exists(
-        file: str,
+    file: str,
 ):
     """
     Check if file exists and is not empty. If the file is found locally,
@@ -55,18 +55,18 @@ def check_file_exists(
             return True
     else:
         with stream_io.open_stream(file, "rb") as f:
-            if f.read(1) == b'':
+            if f.read(1) == b"":
                 return False
             else:
                 return True
 
 
 def serialize_model(
-        model: torch.nn.Module,
-        config: Optional[Union[ConfigMixin, AutoConfig, dict]],
-        model_directory: str,
-        model_prefix: str = None,
-        force: bool = False,
+    model: torch.nn.Module,
+    config: Optional[Union[ConfigMixin, AutoConfig, dict]],
+    model_directory: str,
+    model_prefix: str = None,
+    force: bool = False,
 ):
     """
     Remove the tensors from a PyTorch model, convert them to NumPy
@@ -91,8 +91,10 @@ def serialize_model(
         model_prefix = "model"
 
     dir_prefix = f"{model_directory}/{model_prefix}"
-    config_file_exists, weights_file_exists = (check_file_exists(f"{dir_prefix}-config.json"),
-                                               check_file_exists(f"{dir_prefix}.tensors"))
+    config_file_exists, weights_file_exists = (
+        check_file_exists(f"{dir_prefix}-config.json"),
+        check_file_exists(f"{dir_prefix}.tensors"),
+    )
     if config is None:
         config = model
     if config is not None:
@@ -114,16 +116,16 @@
 
 
 def load_model(
-        path_uri: str,
-        model_class: Union[
-            Type[PreTrainedModel], Type[ModelMixin], Type[ConfigMixin]
-        ],
-        config_class: Optional[
-            Union[Type[PretrainedConfig], Type[ConfigMixin], Type[AutoConfig]]
-        ] = None,
-        model_prefix: Optional[str] = "model",
-        device: torch.device = utils.get_device(),
-        dtype: Optional[str] = None,
+    path_uri: str,
+    model_class: Union[
+        Type[PreTrainedModel], Type[ModelMixin], Type[ConfigMixin]
+    ],
+    config_class: Optional[
+        Union[Type[PretrainedConfig], Type[ConfigMixin], Type[AutoConfig]]
+    ] = None,
+    model_prefix: Optional[str] = "model",
+    device: torch.device = utils.get_device(),
+    dtype: Optional[str] = None,
 ) -> torch.nn.Module:
     """
     Given a path prefix, load the model with a custom extension
@@ -212,7 +214,12 @@ def df_main(args: argparse.Namespace) -> None:
     logger.info("GPU: " + utils.get_gpu_name())
     logger.info("PYTHON USED RAM: " + utils.get_mem_usage())
 
-    serialize_model(pipeline.text_encoder.eval(), pipeline.text_encoder.config, output_prefix, "encoder")
+    serialize_model(
+        pipeline.text_encoder.eval(),
+        pipeline.text_encoder.config,
+        output_prefix,
+        "encoder",
+    )
     serialize_model(pipeline.vae.eval(), None, output_prefix, "vae")
     serialize_model(pipeline.unet.eval(), None, output_prefix, "unet")
 
@@ -258,7 +265,9 @@ def hf_main(args):
 
     # May not be necessary if users can be assumed won't accidentally
     # include trailing slashes in their output_prefix
-    output_prefix = output_prefix[:-1] if output_prefix[-1] == "/" else output_prefix
+    output_prefix = (
+        output_prefix[:-1] if output_prefix[-1] == "/" else output_prefix
+    )
 
     print("MODEL PATH:", args.input_directory)
     print("OUTPUT PREFIX:", output_prefix)
@@ -276,11 +285,7 @@ def hf_main(args):
     logger.info("GPU: " + utils.get_gpu_name())
     logger.info("PYTHON USED RAM: " + utils.get_mem_usage())
 
-    serialize_model(model,
-                    model_config,
-                    output_prefix,
-                    None,
-                    args.force)
+    serialize_model(model, model_config, output_prefix, None, args.force)
 
     if args.validate:
         # Not sure if this part is needed as, although I doubt it,
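For orientation, here is a hedged usage sketch of the two helpers touched by this reformat, serialize_model and check_file_exists, following the call pattern visible in hf_main and the dir_prefix logic in the diff above. The model name, output directory, and import path (running from the examples/ directory) are illustrative assumptions, not part of the commit.

# Hypothetical usage of the reformatted helpers; assumes this script runs from
# the examples/ directory so hf_serialization is importable. The model name
# and output directory are placeholders chosen for illustration.
import os

from transformers import AutoConfig, AutoModelForCausalLM

from hf_serialization import check_file_exists, serialize_model

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
config = AutoConfig.from_pretrained("gpt2")

os.makedirs("/tmp/tensorized", exist_ok=True)

# Intended to produce /tmp/tensorized/model.tensors and
# /tmp/tensorized/model-config.json, per the dir_prefix logic in the diff.
serialize_model(model, config, "/tmp/tensorized", "model")

# Returns True once the serialized weights exist and are non-empty.
print(check_file_exists("/tmp/tensorized/model.tensors"))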