Skip to content

Commit

Permalink
Add postprocess script to log data
Browse files Browse the repository at this point in the history
  • Loading branch information
mmattamala committed Feb 7, 2024
1 parent b617745 commit a9518b7
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 72 deletions.
4 changes: 2 additions & 2 deletions tests/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_feature_extractor():
device = "cuda" if torch.cuda.is_available() else "cpu"
segmentation_types = ["none", "grid", "slic", "random", "stego"]
feature_types = ["dino", "dinov2", "stego"]
backbone_types = ["vit_small", "vit_base", "vit_small_reg", "vit_base_reg"]
backbone_types = ["vit_small", "vit_base"] # , "vit_small_reg", "vit_base_reg"]

for seg_type, feat_type, back_type in itertools.product(segmentation_types, feature_types, backbone_types):
if seg_type == "stego" and feat_type != "stego":
Expand All @@ -39,7 +39,7 @@ def test_feature_extractor():

ax[0].imshow(transform(img).permute(0, 2, 3, 1)[0].cpu())
ax[0].set_title("Image")
ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("inferno"))
ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("gray"))
ax[1].set_title("Segmentation")
plt.tight_layout()

Expand Down
Loading

0 comments on commit a9518b7

Please sign in to comment.