Skip to content

Commit

Permalink
[Metadat utils] fix: json lines ordering. (#7744)
Browse files Browse the repository at this point in the history
fix: json lines ordering.
  • Loading branch information
sayakpaul authored Apr 23, 2024
1 parent fc9fecc commit 5a69227
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions utils/update_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import pandas as pd
from datasets import Dataset
from huggingface_hub import upload_folder
from huggingface_hub import hf_hub_download, upload_folder

from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
Expand All @@ -39,6 +39,9 @@
)


PIPELINE_TAG_JSON = "pipeline_tags.json"


def get_supported_pipeline_table() -> dict:
"""
Generates a dictionary containing the supported auto classes for each pipeline type,
Expand All @@ -57,8 +60,8 @@ def get_supported_pipeline_table() -> dict:
(class_name.__name__, "image-to-image", "AutoPipelineForInpainting")
for _, class_name in AUTO_INPAINT_PIPELINES_MAPPING.items()
]
all_supported_pipeline_classes.sort(key=lambda x: x[0])
all_supported_pipeline_classes = list(set(all_supported_pipeline_classes))
all_supported_pipeline_classes.sort(key=lambda x: x[0])

data = {}
data["pipeline_class"] = [sample[0] for sample in all_supported_pipeline_classes]
Expand All @@ -79,8 +82,24 @@ def update_metadata(commit_sha: str):
pipelines_table = pd.DataFrame(pipelines_table)
pipelines_dataset = Dataset.from_pandas(pipelines_table)

hub_pipeline_tags_json = hf_hub_download(
repo_id="huggingface/diffusers-metadata",
filename=PIPELINE_TAG_JSON,
repo_type="dataset",
)
with open(hub_pipeline_tags_json) as f:
hub_pipeline_tags_json = f.read()

with tempfile.TemporaryDirectory() as tmp_dir:
pipelines_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
pipelines_dataset.to_json(os.path.join(tmp_dir, PIPELINE_TAG_JSON))

with open(os.path.join(tmp_dir, PIPELINE_TAG_JSON)) as f:
pipeline_tags_json = f.read()

hub_pipeline_tags_equal = hub_pipeline_tags_json == pipeline_tags_json
if hub_pipeline_tags_equal:
print("No updates, not pushing the metadata files.")
return

if commit_sha is not None:
commit_message = (
Expand Down

0 comments on commit 5a69227

Please sign in to comment.