Skip to content

Commit

Permalink
Unit test for all supported pretrained networks and all supported layers
Browse files Browse the repository at this point in the history
  • Loading branch information
ArashAkbarinia committed Dec 17, 2023
1 parent 7b2cdad commit 62b3a1b
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions tests/models/pretrained_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,30 @@

import pytest
import torch
from torchvision import models as torch_models

from osculari.models import readout, available_layers
from osculari.models import readout, available_layers, available_models
from osculari.models import pretrained_models


def all_imagenet_networks_layers():
"""All pretrained ImageNet networks and supported layers."""
for net_name in torch_models.list_models(module=torch_models):
def all_networks_layers():
"""All supported pretrained networks and supported layers."""
for net_name in available_models(flatten=True):
for layer in available_layers(net_name):
yield net_name, layer


@pytest.mark.parametrize("net_name,layer", all_imagenet_networks_layers())
@pytest.mark.parametrize("net_name,layer", all_networks_layers())
def test_imagenet_models(net_name, layer):
img_size = 224
expected_sizes = {
'clip_RN50x4': 288,
'clip_RN50x16': 384,
'clip_RN50x64': 448,
'clip_ViT-L/14@336px': 336,
}
img_size = expected_sizes[net_name] if net_name in expected_sizes else 224
x1 = torch.randn(2, 3, img_size, img_size)
x2 = torch.randn(2, 3, img_size, img_size)
weights = None
weights = 'none'
readout_kwargs = {
'architecture': net_name, 'img_size': img_size,
'weights': weights,
Expand Down

0 comments on commit 62b3a1b

Please sign in to comment.