[BugFix] torch 2.0 compatibility fix
ghstack-source-id: 90dd8ba898d215bd09cb810ed88c1f301c4ae77b
Pull Request resolved: #2475
vmoens committed Oct 9, 2024
1 parent d5c93fb commit 4790d3b
Showing 3 changed files with 24 additions and 13 deletions.
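Every changed call site follows one pattern: the inline torch.compiler.is_dynamo_compiling() lookup is replaced by a module-level alias resolved once at import time, because torch 2.0 predates the public torch.compiler namespace. A minimal standalone sketch of the shim and a call site (assuming, as the commit does, that torch._dynamo.is_compiling is available on torch 2.0):

# Prefer the public API; fall back to the private spelling on torch 2.0.
try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling

# Call sites are then version-agnostic:
if is_dynamo_compiling():
    pass  # taken while TorchDynamo traces this code
else:
    pass  # plain eager execution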
9 changes: 7 additions & 2 deletions torchrl/__init__.py
@@ -27,6 +27,11 @@
 except ImportError:
     __version__ = None
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _init_extension()
 
 try:
@@ -69,7 +74,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = _InverseTransform(self)
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
     return inv
 
@@ -84,7 +89,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = ComposeTransform([p.inv for p in reversed(self.parts)])
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
             inv._inv = weakref.ref(self)
         else:
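Note a second detail shared by both hunks above: the inverse transform is still constructed under compile, but the weakref memoization is skipped. The commit does not say why; a plausible reading (an assumption, not stated in the diff) is that creating weakref.ref objects during Dynamo tracing is unsupported, so the cache is only populated eagerly. The pattern in isolation, with illustrative names and assuming the shim above is in scope:

import weakref

class HasLazyInverse:
    def __init__(self, make_inverse):
        self._make_inverse = make_inverse  # hypothetical factory
        self._inv = None  # weakref to the cached inverse, if any

    @property
    def inv(self):
        inv = self._inv() if self._inv is not None else None
        if inv is None:
            inv = self._make_inverse(self)
            if not is_dynamo_compiling():
                # Cache only in eager mode; skip weakref creation
                # while a compiled graph is being traced.
                self._inv = weakref.ref(inv)
        return inv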
21 changes: 11 additions & 10 deletions torchrl/modules/distributions/continuous.py
@@ -36,6 +36,11 @@
 # speeds up distribution construction
 D.Distribution.set_default_validate_args(False)
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 
 class IndependentNormal(D.Independent):
     """Implements a Normal distribution with location scaling.
@@ -112,7 +117,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -320,7 +325,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
                 inv._inv = weakref.ref(self)
         return inv
@@ -334,7 +339,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -432,15 +437,13 @@ def __init__(
         self.high = high
 
         if safe_tanh:
-            if torch.compiler.is_dynamo_compiling():
+            if is_dynamo_compiling():
                 _err_compile_safetanh()
             t = SafeTanhTransform()
         else:
             t = D.TanhTransform()
         # t = D.TanhTransform()
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             t = _PatchedComposeTransform(
                 [
                     t,
@@ -467,9 +470,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
         # loc must be rescaled if tanh_loc
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
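The last two hunks only swap the compile check and collapse the condition onto one line, but the guard they touch is worth spelling out: eagerly, the affine rescaling of loc is skipped when the support is the default (-1, 1) interval, where it would be a no-op; under compile it runs unconditionally, which (by inference; the commit gives no rationale) keeps that branch out of the traced graph. Stripped to its shape, again assuming the shim is in scope:

import torch

def rescale_loc(loc: torch.Tensor, low: torch.Tensor, high: torch.Tensor,
                non_trivial: bool) -> torch.Tensor:
    # While tracing, always rescale so the graph has a single path;
    # eagerly, skip the rescaling when it would be a no-op
    # (low == -1, high == 1 gives (high - low) / 2 + low == 0).
    if is_dynamo_compiling() or non_trivial:
        loc = loc + (high - low) / 2 + low
    return loc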
7 changes: 6 additions & 1 deletion torchrl/objectives/utils.py
@@ -26,6 +26,11 @@
     raise err_ft from err
 from torchrl.envs.utils import step_mdp
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _GAMMA_LMBDA_DEPREC_ERROR = (
     "Passing gamma / lambda parameters through the loss constructor "
     "is a deprecated feature. To customize your value function, "
@@ -460,7 +465,7 @@ def _cache_values(func):
 
     @functools.wraps(func)
    def new_func(self, netname=None):
-        if torch.compiler.is_dynamo_compiling():
+        if is_dynamo_compiling():
            if netname is not None:
                return func(self, netname)
            else:
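The decorator patched here appears to memoize per-network values on the loss module, and under compile it bypasses the cache and calls the wrapped function directly. A sketch of that overall shape (the real _cache_values in torchrl differs in details this hunk does not show; the cache key below is illustrative, and the shim from the top of the commit is assumed in scope):

import functools

def _cache_values(func):
    cache_name = "_cache_" + func.__name__  # illustrative attribute name

    @functools.wraps(func)
    def new_func(self, netname=None):
        if is_dynamo_compiling():
            # While tracing, recompute instead of reading or writing the
            # cache, so no stale state gets baked into the graph.
            if netname is not None:
                return func(self, netname)
            return func(self)
        cache = self.__dict__.setdefault(cache_name, {})
        if netname not in cache:
            cache[netname] = func(self, netname) if netname is not None else func(self)
        return cache[netname]

    return new_func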
