diff --git a/util/misc.py b/util/misc.py index 6d4d076..bcf7282 100644 --- a/util/misc.py +++ b/util/misc.py @@ -27,7 +27,7 @@ # needed due to empty tensor bug in pytorch and torchvision 0.5 import torchvision -if float(torchvision.__version__[:3]) < 0.5: +if float(torchvision.__version__.split('.')[1]) < 5: import math from torchvision.ops.misc import _NewEmptyTensorOp def _check_size_scale_factor(dim, size, scale_factor): @@ -54,7 +54,7 @@ def _output_size(dim, input, size, scale_factor): return [ int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) ] -elif float(torchvision.__version__[:3]) < 0.7: +elif float(torchvision.__version__.split('.')[1]) < 7: from torchvision.ops import _new_empty_tensor from torchvision.ops.misc import _output_size @@ -487,7 +487,7 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne This will eventually be supported natively by PyTorch, and this class can go away. """ - if float(torchvision.__version__[:3]) < 0.7: + if float(torchvision.__version__.split('.')[1]) < 7: if input.numel() > 0: return torch.nn.functional.interpolate( input, size, scale_factor, mode, align_corners @@ -495,7 +495,7 @@ class can go away. output_shape = _output_size(2, input, size, scale_factor) output_shape = list(input.shape[:-2]) + list(output_shape) - if float(torchvision.__version__[:3]) < 0.5: + if float(torchvision.__version__.split('.')[1]) < 5: return _NewEmptyTensorOp.apply(input, output_shape) return _new_empty_tensor(input, output_shape) else: