Skip to content

Commit

Permalink
Support specifying a torch range
Browse files Browse the repository at this point in the history
  • Loading branch information
oraluben authored Feb 25, 2025
1 parent c670ad8 commit 4272b38
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@ def _make_version_file(version, sha):


def _get_pytorch_version():
if "PYTORCH_VERSION" in os.environ:
return f"torch=={os.environ['PYTORCH_VERSION']}"
return "torch"
pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch")
if version_pin := os.getenv("PYTORCH_VERSION"):
pytorch_dep += "==" + version_pin
elif (version_pin_ge := os.getenv("PYTORCH_VERSION_GE")) and (version_pin_lt := os.getenv("PYTORCH_VERSION_LT")):
pytorch_dep += f">={version_pin_ge},<{version_pin_lt}"
return pytorch_dep


class clean(distutils.command.clean.clean):
Expand Down

0 comments on commit 4272b38

Please sign in to comment.